LFO 1D example

For this example, we use the transcriptional regulation ODE system.

[1]:
import numpy as np
import torch

import torch.nn.functional as F
from torch.utils.data import DataLoader

from matplotlib import pyplot as plt

from alfi.datasets import ToyTranscriptomics, ReactionDiffusion, HomogeneousReactionDiffusion, P53Data
from alfi.models import NeuralOperator, NeuralLFM
from alfi.trainers import NeuralOperatorTrainer
from alfi.plot import Plotter1d, plot_spatiotemporal_data, tight_kwargs
from alfi.utilities.data import generate_neural_dataset_1d
[16]:
from torch.nn.functional import softplus
from alfi.utilities.data import context_target_split as cts
def show_result(model, loader, plotter=None):
    model.eval()
    x, y, params = next(iter(loader))
    num_outputs = params.shape[2]
    # x_context, y_context, _, _ = cts(x, y, x.shape[1], 0)
    # p_y_pred, params_out = model(x_context, y_context, x, y)
    p_y_pred, params_out = model(x)

    t = x[0, :, 0]
    fig, axes = plt.subplots(ncols=2, figsize=(8, 3))
    for i in range(1, num_outputs+1):
        axes[0].plot(t, x[0, :, i])#.shape, y.shape)
    axes[1].plot(t, y[0, :, 0], label='Target')

    mean = p_y_pred[0, :, 0].detach()
    std = softplus(p_y_pred[0, :, 1].detach()).sqrt()
    # mean = p_y_pred.mean[0].detach()
    # std = p_y_pred.variance[0].sqrt().detach()
    # axes[1].plot(t, mean, label='Prediction')
    # axes[1].fill_between(t, mean + std, mean - std, label='Target')
    axes[1].errorbar(t, mean, std, label='Prediction')
    axes[1].legend()
    print(params.shape)
    params = params.view(params.shape[0], -1)[0]
    params_out = params_out[0].detach()
    if plotter is None:
        plotter = Plotter1d(model, np.array(['']*num_outputs))
    plotter.plot_double_bar(params_out.view(3, num_outputs),
                            titles=['basal', 'sensitivity', 'decay'], ground_truths=params.view(3, 5))

[3]:
i = 3
width = 20
modes = 4
in_channels = 6
block_dim = 1
learning_rate = 1e-3
batch_size = 20

dataset = ToyTranscriptomics(data_dir='../../../data')
ntrain = len(dataset.train_data)
ntest = len(dataset.test_data)
train_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset.test_data, batch_size=ntest, shuffle=True)
num_outputs = 10

print(ntrain, ntest)
print(train_loader.dataset[0][0].shape)
2000 10
torch.Size([12, 6])
[4]:
model = NeuralOperator(block_dim, in_channels, 2, modes, width)
r_dim = z_dim = 9
print(model.count_params())

optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
trainer = NeuralOperatorTrainer(model, [optimizer], train_loader, test_loader)
46661
[5]:
trainer.train(50, report_interval=1);
Epoch 001/050 - Loss: 2.79 (2.22 1.19 0.98 0.16 0.27)
Epoch 002/050 - Loss: 1.86 (2.14 0.62 0.94 0.12 0.26)
Epoch 003/050 - Loss: 1.76 (2.02 0.53 0.88 0.12 0.26)
Epoch 004/050 - Loss: 1.73 (2.27 0.50 1.01 0.12 0.25)
Epoch 005/050 - Loss: 1.68 (2.73 0.46 1.25 0.12 0.24)
Epoch 006/050 - Loss: 1.67 (2.39 0.44 1.07 0.12 0.26)
Epoch 007/050 - Loss: 1.63 (2.61 0.40 1.17 0.12 0.26)
Epoch 008/050 - Loss: 1.62 (2.01 0.40 0.89 0.12 0.24)
Epoch 009/050 - Loss: 1.59 (2.34 0.38 1.04 0.12 0.26)
Epoch 010/050 - Loss: 1.56 (2.00 0.35 0.88 0.12 0.24)
Epoch 011/050 - Loss: 1.53 (2.12 0.32 0.94 0.12 0.25)
Epoch 012/050 - Loss: 1.53 (2.30 0.30 1.02 0.12 0.25)
Epoch 013/050 - Loss: 1.52 (1.73 0.31 0.75 0.12 0.24)
Epoch 014/050 - Loss: 1.49 (2.63 0.28 1.20 0.12 0.24)
Epoch 015/050 - Loss: 1.46 (2.41 0.26 1.08 0.12 0.25)
Epoch 016/050 - Loss: 1.46 (1.85 0.26 0.80 0.12 0.25)
Epoch 017/050 - Loss: 1.45 (1.95 0.25 0.85 0.12 0.25)
Epoch 018/050 - Loss: 1.42 (2.25 0.22 1.00 0.12 0.25)
Epoch 019/050 - Loss: 1.40 (1.97 0.20 0.87 0.12 0.23)
Epoch 020/050 - Loss: 1.41 (2.27 0.22 1.02 0.12 0.23)
Epoch 021/050 - Loss: 1.38 (2.10 0.19 0.93 0.12 0.24)
Epoch 022/050 - Loss: 1.38 (2.13 0.19 0.95 0.12 0.24)
Epoch 023/050 - Loss: 1.36 (1.99 0.17 0.87 0.12 0.25)
Epoch 024/050 - Loss: 1.36 (1.84 0.17 0.80 0.12 0.25)
Epoch 025/050 - Loss: 1.35 (2.18 0.16 0.95 0.12 0.27)
Epoch 026/050 - Loss: 1.31 (3.23 0.14 1.49 0.12 0.26)
Epoch 027/050 - Loss: 1.32 (2.31 0.14 1.03 0.12 0.25)
Epoch 028/050 - Loss: 1.32 (2.29 0.14 1.02 0.12 0.25)
Epoch 029/050 - Loss: 1.27 (3.13 0.10 1.44 0.12 0.25)
Epoch 030/050 - Loss: 1.28 (3.01 0.10 1.38 0.12 0.25)
Epoch 031/050 - Loss: 1.28 (2.12 0.10 0.93 0.12 0.26)
Epoch 032/050 - Loss: 1.24 (2.48 0.08 1.12 0.12 0.24)
Epoch 033/050 - Loss: 1.23 (1.87 0.07 0.81 0.12 0.24)
Epoch 034/050 - Loss: 1.24 (2.03 0.08 0.88 0.12 0.26)
Epoch 035/050 - Loss: 1.24 (2.33 0.08 1.04 0.12 0.24)
Epoch 036/050 - Loss: 1.22 (2.03 0.06 0.89 0.12 0.26)
Epoch 037/050 - Loss: 1.21 (2.45 0.05 1.10 0.12 0.26)
Epoch 038/050 - Loss: 1.21 (2.91 0.06 1.33 0.11 0.24)
Epoch 039/050 - Loss: 1.21 (2.68 0.06 1.21 0.12 0.27)
Epoch 040/050 - Loss: 1.21 (3.07 0.06 1.41 0.12 0.24)
Epoch 041/050 - Loss: 1.18 (2.94 0.03 1.34 0.12 0.25)
Epoch 042/050 - Loss: 1.17 (2.79 0.02 1.26 0.11 0.27)
Epoch 043/050 - Loss: 1.17 (2.29 0.02 1.01 0.11 0.26)
Epoch 044/050 - Loss: 1.16 (3.28 0.01 1.51 0.11 0.26)
Epoch 045/050 - Loss: 1.14 (4.70 -0.00 2.23 0.11 0.25)
Epoch 046/050 - Loss: 1.13 (2.86 -0.00 1.30 0.11 0.26)
Epoch 047/050 - Loss: 1.14 (3.31 0.00 1.53 0.11 0.25)
Epoch 048/050 - Loss: 1.14 (2.96 -0.01 1.36 0.11 0.24)
Epoch 049/050 - Loss: 1.11 (2.69 -0.02 1.22 0.11 0.26)
Epoch 050/050 - Loss: 1.13 (2.31 -0.02 1.03 0.11 0.25)
[14]:
show_result(model, test_loader)
torch.Size([10, 3, 5])
../../_images/notebooks_nn_lfo_1d_6_1.png
../../_images/notebooks_nn_lfo_1d_6_2.png
[17]:
dataset = P53Data(data_dir='../../../data', replicate=0)
plotter = Plotter1d(model, dataset.gene_names)
print(dataset.t_observed.shape, dataset.ups_t_observed.shape)
train, test = generate_neural_dataset_1d(dataset.ups_t_observed,
                                         [[dataset.ups_m_observed.unsqueeze(0),
                                           dataset.ups_f_observed.view(1, 1, 30)]], [dataset.params])
p53_loader = DataLoader(train, batch_size=1)
show_result(model, p53_loader, plotter)
torch.Size([7]) torch.Size([30])
torch.Size([1, 3, 5])
../../_images/notebooks_nn_lfo_1d_7_1.png
../../_images/notebooks_nn_lfo_1d_7_2.png
[8]:
show_result(model, subsampled_loader)
show_result(model, high_res_loader)
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-8-8dfca9af2371> in <module>
----> 1 show_result(model, subsampled_loader)
      2 show_result(model, high_res_loader)
      3

NameError: name 'subsampled_loader' is not defined
[ ]:
#torch.save(model.state_dict(), './saved_model3205.pt')
[24]:
from experiments.mae import get_datasets
from torch.nn.functional import l1_loss
from alfi.utilities.torch import inv_softplus, q2
datasets = get_datasets(data_dir='./../../..')

f_maes = list()
param_maes = list()
f_q2s = list()
num_plots = 10
for i in range(10):
    dataset = datasets[i]
    dataset.variance = 1e-5 * torch.ones(dataset.m_observed.shape[-1], dtype=torch.float32)

    num_genes = 5
    num_tfs = 1
    if i < num_plots:
        plt.figure(figsize=(4, 2))
        for j in range(5):
            plt.plot(dataset[j][1])
        plt.plot(dataset.f_observed[0, 0])
    t_end = dataset.t_observed[-1]

    ground_truths = [
                    dataset.lfm.basal_rate.detach().view(-1).numpy(),
                    dataset.lfm.sensitivity.detach().view(-1).numpy(),
                    dataset.lfm.decay_rate.detach().view(-1).numpy()
    ]

    highres = False
    if highres:
        t = dataset.t_observed_highres
        m = dataset.m_observed_highres
        f = dataset.f_observed_highres.view(1, 1, 111)
    else:
        t = dataset.t_observed
        m = dataset.m_observed
        f = dataset.f_observed.view(1, 1, 12)

    train, test = generate_neural_dataset_1d(
        t,
        [[m, f]],
        [torch.tensor(ground_truths)])
    mae_loader = DataLoader(train, batch_size=1)

    x, y, params = next(iter(mae_loader))
    p_y_pred, params_out = model(x)
    f_mean = p_y_pred[0, :, 0].detach()
    f_mae = l1_loss(f_mean, f[0,0])
    f_q2 = q2(f[0,0], f_mean).mean() #y_test, f_mean
    params_out = params_out.view(3, 5).detach()
    params_mae = (
        l1_loss(params_out[0], dataset.lfm.basal_rate) +
        l1_loss(params_out[1], dataset.lfm.sensitivity) +
        l1_loss(params_out[2], dataset.lfm.decay_rate)
    ).mean().item()

    f_maes.append(f_mae)
    f_q2s.append(f_q2)
    param_maes.append(params_mae)

    if i < num_plots:
        show_result(model, mae_loader)

f_maes = np.array(f_maes)
f_q2s = np.array(f_q2s)
param_maes = np.array(param_maes)
print('F mae: ', f_maes.mean(), f_maes.std())
print('F q2: ', f_q2s.mean(), f_q2s.std())
print('param mae:', param_maes.mean(), param_maes.std())
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\Documents\proj\lafomo\lafomo\plot\base_plotter.py:44: RuntimeWarning: More than 20 figures have been opened. Figures created through the pyplot interface (`matplotlib.pyplot.figure`) are retained until explicitly closed and may consume too much memory. (To control this warning, see the rcParam `figure.max_open_warning`).
  fig, axes = plt.subplots(ncols=num_plots, figsize=figsize)
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:53: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:54: UserWarning: Using a target size (torch.Size([5, 1])) that is different to the input size (torch.Size([5])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
torch.Size([1, 3, 5])
F mae:  0.1293553 0.054392833
F q2:  0.95558965 0.033103462
param mae: 0.6869515806436539 0.16563994315446798
../../_images/notebooks_nn_lfo_1d_10_2.png
../../_images/notebooks_nn_lfo_1d_10_3.png
../../_images/notebooks_nn_lfo_1d_10_4.png
../../_images/notebooks_nn_lfo_1d_10_5.png
../../_images/notebooks_nn_lfo_1d_10_6.png
../../_images/notebooks_nn_lfo_1d_10_7.png
../../_images/notebooks_nn_lfo_1d_10_8.png
../../_images/notebooks_nn_lfo_1d_10_9.png
../../_images/notebooks_nn_lfo_1d_10_10.png
../../_images/notebooks_nn_lfo_1d_10_11.png
../../_images/notebooks_nn_lfo_1d_10_12.png
../../_images/notebooks_nn_lfo_1d_10_13.png
../../_images/notebooks_nn_lfo_1d_10_14.png
../../_images/notebooks_nn_lfo_1d_10_15.png
../../_images/notebooks_nn_lfo_1d_10_16.png
../../_images/notebooks_nn_lfo_1d_10_17.png
../../_images/notebooks_nn_lfo_1d_10_18.png
../../_images/notebooks_nn_lfo_1d_10_19.png
../../_images/notebooks_nn_lfo_1d_10_20.png
../../_images/notebooks_nn_lfo_1d_10_21.png
../../_images/notebooks_nn_lfo_1d_10_22.png
../../_images/notebooks_nn_lfo_1d_10_23.png
../../_images/notebooks_nn_lfo_1d_10_24.png
../../_images/notebooks_nn_lfo_1d_10_25.png
../../_images/notebooks_nn_lfo_1d_10_26.png
../../_images/notebooks_nn_lfo_1d_10_27.png
../../_images/notebooks_nn_lfo_1d_10_28.png
../../_images/notebooks_nn_lfo_1d_10_29.png
../../_images/notebooks_nn_lfo_1d_10_30.png
../../_images/notebooks_nn_lfo_1d_10_31.png
[23]:
print(f_q2s)
[0.08028523 0.01363233 0.01098458 0.02117325 0.0634205  0.03634365
 0.0254864  0.10608732 0.01780646 0.03187532]
[56]:
data = next(iter(mae_loader))
print(data[0].shape, data[1].shape)
for i in range(5):
    plt.plot(data[0][0, :, i+1])
plt.plot(data[1][0, :, 0])
torch.Size([1, 12, 6]) torch.Size([1, 12, 1])
[56]:
[<matplotlib.lines.Line2D at 0x23308716ec8>]
../../_images/notebooks_nn_lfo_1d_12_2.png
[69]:
data = next(iter(test_loader))
print(data[0].shape, data[1].shape)
for i in range(5):
    plt.plot(data[0][0, :, i+1])
plt.plot(data[1][0, :, 0])


torch.Size([10, 12, 6]) torch.Size([10, 12, 1])
[69]:
[<matplotlib.lines.Line2D at 0x23308aafc88>]
../../_images/notebooks_nn_lfo_1d_13_2.png
[59]:

[ ]: