LFO 2D reaction-diffusion

For this example, we train a neural operator on a dataset generated by the reaction-diffusion latent force model (LFM).

[1]:
import numpy as np
import torch

import torch.nn.functional as F
from torch.utils.data import DataLoader

from matplotlib import pyplot as plt

from alfi.datasets import ReactionDiffusion, HomogeneousReactionDiffusion
from alfi.models import NeuralOperator, NeuralLFM
from alfi.trainers import NeuralOperatorTrainer
from alfi.plot import plot_spatiotemporal_data, tight_kwargs, Plotter2d
[11]:
# Reaction-diffusion dataset used to train and test the neural operator.
ntest = 50

dataset = ReactionDiffusion('../../../data',
                            nn_format=True,
                            max_n=4000,
                            ntest=ntest,
                            sub=1)
batch_size = 50
train_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=True)
# Always bind `test_loader` so later cells (which pass it to the trainer) do not
# raise NameError when the notebook is re-run with ntest = 0.
# NOTE(review): confirm NeuralOperatorTrainer accepts test_loader=None.
test_loader = (DataLoader(dataset.test_data, batch_size=ntest, shuffle=True)
               if ntest > 0 else None)

# Visualise channel 0 of one sample's input (`tx`) and output (`lf`).
i = 3
tx = dataset.data[i][0]
lf = dataset.data[i][1]
# Input channels 1 and 2 presumably hold the time / space coordinates (ts / xs).
# torch.unique already returns its values in ascending order.
ts = tx[:, :, 1].unique().numpy()
xs = tx[:, :, 2].unique().numpy()
extent = [ts[0], ts[-1], xs[0], xs[-1]]

plot_spatiotemporal_data(
    [
        tx[:, :, 0].t(),
        lf[:, :, 0].t(),
    ],
    extent, nrows=1, ncols=2
);

[11]:
<mpl_toolkits.axes_grid1.axes_grid.ImageGrid at 0x1e787cc8048>
../../_images/notebooks_nn_lfo_2d_2_1.png
[12]:
# The same homogeneous reaction-diffusion data at full resolution (sub=1) and
# subsampled by 2 (sub=2), used for the super-resolution experiment below.
highres_dataset = HomogeneousReactionDiffusion('../../../data',
                                               one_fixed_sample=False,
                                               highres=True,
                                               nn_format=True,
                                               sub=1, ntest=0)
high_res_loader = DataLoader(highres_dataset.train_data, batch_size=batch_size, shuffle=False)

dataset = HomogeneousReactionDiffusion('../../../data',
                                       one_fixed_sample=False,
                                       highres=True,
                                       nn_format=True,
                                       sub=2, ntest=0)
subsampled_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=False)

# Plot a random subsampled sample; `extent` is reused from the first dataset cell.
# NOTE(review): the upper bound 50 assumes the dataset holds at least 50 samples.
i = torch.randint(50, torch.Size([1]))[0]
tx = dataset.data[i][0]

plot_spatiotemporal_data(
    # detach().clone() copies the tensor without the torch.tensor(tensor) warning.
    [tx[:, :, 0].detach().clone()],
    extent, nrows=1, ncols=1
)

# Model / optimiser hyperparameters.
block_dim = 2
learning_rate = 1e-3

modes = 10
width = 38
in_channels = 3

C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
../../_images/notebooks_nn_lfo_2d_3_1.png
[13]:
# Build the neural operator, its Adam optimiser and the trainer.
# The commented lines are the alternative NeuralLFM construction.
model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
# r_dim = z_dim = 16
# model = NeuralLFM(block_dim, in_channels,
#                   modes, width, r_dim, z_dim)

n_params = model.count_params()
print(n_params)

optimizer = torch.optim.Adam(
    model.parameters(),
    lr=learning_rate,
    weight_decay=1e-4,
)
trainer = NeuralOperatorTrainer(model, [optimizer], train_loader, test_loader)

first_sample = train_loader.dataset[0]
print(first_sample[0].shape)
25487735
torch.Size([41, 41, 3])
[5]:
from pathlib import Path

# Optionally warm-start from a previously saved checkpoint. The file is not part
# of the repository, so skip it (instead of raising FileNotFoundError) when it
# is absent; otherwise Restart & Run All stops at this cell.
checkpoint_path = Path('../nn/saved_model1505.pt')
if checkpoint_path.exists():
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    # map_location='cpu' lets a GPU-saved checkpoint load on a CPU-only machine.
    state_dict = torch.load(str(checkpoint_path), map_location='cpu')
    model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
    print(model.count_params())
    model.load_state_dict(state_dict)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
    trainer = NeuralOperatorTrainer(model, [optimizer], train_loader, test_loader)
else:
    print(f'Checkpoint {checkpoint_path} not found; keeping the current model.')
---------------------------------------------------------------------------
FileNotFoundError                         Traceback (most recent call last)
<ipython-input-5-8f76473f04dc> in <module>
----> 1 state_dict = torch.load('../nn/saved_model1505.pt')
      2 model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
      3 print(model.count_params())
      4 model.load_state_dict(state_dict)
      5 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)

~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
    577         pickle_load_args['encoding'] = 'utf-8'
    578
--> 579     with _open_file_like(f, 'rb') as opened_file:
    580         if _is_zipfile(opened_file):
    581             # The zipfile reader is going to advance the current file position.

~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in _open_file_like(name_or_buffer, mode)
    228 def _open_file_like(name_or_buffer, mode):
    229     if _is_path(name_or_buffer):
--> 230         return _open_file(name_or_buffer, mode)
    231     else:
    232         if 'w' in mode:

~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in __init__(self, name, mode)
    209 class _open_file(_opener):
    210     def __init__(self, name, mode):
--> 211         super(_open_file, self).__init__(open(name, mode))
    212
    213     def __exit__(self, *args):

FileNotFoundError: [Errno 2] No such file or directory: '../nn/saved_model1505.pt'
[14]:
# Train for a fixed number of epochs; ';' suppresses display of the return value.
n_epochs = 10
trainer.train(n_epochs);
Epoch 001/010 - Loss: 3.06 (1.65 1.46 1.44 0.16 0.02)
Epoch 002/010 - Loss: 1.59 (1.62 1.40 1.38 0.02 0.02)
Epoch 003/010 - Loss: 1.49 (1.50 1.32 1.21 0.02 0.03)
Epoch 004/010 - Loss: 1.19 (1.26 1.04 0.91 0.02 0.03)
Epoch 005/010 - Loss: 0.85 (0.99 0.71 0.70 0.01 0.03)
Epoch 006/010 - Loss: 0.67 (0.94 0.53 0.61 0.01 0.03)
Epoch 007/010 - Loss: 0.55 (0.82 0.44 0.51 0.01 0.03)
Epoch 008/010 - Loss: 0.49 (0.82 0.38 0.47 0.01 0.04)
Epoch 009/010 - Loss: 0.46 (0.74 0.36 0.44 0.01 0.03)
Epoch 010/010 - Loss: 0.39 (0.68 0.30 0.40 0.01 0.03)
[45]:
import time


def _plot_uncertainty_image(mean, std, num_t, num_x):
    """Draw the prediction as an RGBA image whose opacity encodes confidence.

    The red channel holds the min-max normalised predictive mean and the alpha
    channel holds ``1 - normalised std``, so uncertain regions fade out.
    """
    def _minmax(values):
        # A constant field would otherwise divide by zero and yield NaNs.
        span = values.max() - values.min()
        return (values - values.min()) / span if span > 0 else np.zeros_like(values)

    mean_np = mean.detach().cpu().reshape(num_t, num_x).numpy()
    std_np = std.detach().cpu().reshape(num_t, num_x).numpy()
    rgba = np.zeros((num_x, num_t, 4))
    # Transpose so time runs along the horizontal axis, matching `extent` and the
    # `.t()` used when plotting without uncertainty.
    rgba[:, :, 0] = _minmax(mean_np).T
    rgba[:, :, 3] = 1 - _minmax(std_np).T
    plt.imshow(rgba, origin='lower', extent=extent)
    plt.colorbar()


def show_result(model, loader, plot_uncertainty=False, plotter=None, index=0, timings=None):
    """Predict one sample of the first batch of `loader` and plot it against its target.

    Args:
        model: model exposing ``predict_f(x) -> (predictive distribution, params)``,
            where the distribution provides ``.mean`` and ``.variance``.
        loader: DataLoader yielding ``(x, y, params)`` batches.
        plot_uncertainty: if True, draw the mean with opacity given by the
            predictive std instead of the prediction / target / input panels.
        plotter: optional object with ``plot_double_bar``; the parameter bar
            chart is skipped when it is None.
        index: which sample of the first batch to evaluate (negative allowed).
        timings: list that receives the prediction wall time in minutes.
            Defaults to the notebook-level ``times`` list (created if missing).

    Raises:
        IndexError: if `index` is outside the first batch.
    """
    x, y, params = next(iter(loader))
    print(x.shape)
    n_samples = x.shape[0]
    if not -n_samples <= index < n_samples:
        raise IndexError(f'index {index} out of range for a batch of {n_samples} samples')
    index %= n_samples  # normalise negative indices so the slice below is non-empty
    if timings is None:
        timings = globals().setdefault('times', [])

    # Predict only the requested sample so prediction, target and parameters
    # all refer to the same item of the batch.
    t0 = time.perf_counter()
    p_y_pred, params_out = model.predict_f(x[index:index + 1])
    elapsed_min = (time.perf_counter() - t0) / 60
    timings.append(elapsed_min)
    print('t', elapsed_min)

    mean = p_y_pred.mean[0]
    std = p_y_pred.variance.sqrt()[0]
    num_t = x.shape[1]
    num_x = x.shape[2]
    x = x[index][..., 0]
    y = y[index]
    if not plot_uncertainty:
        plot_spatiotemporal_data(
            [
                mean.detach().view(num_t, num_x).t(),
                y.view(num_t, num_x).t(),
                x.view(num_t, num_x).t()
            ],
            extent, nrows=1, ncols=3,
            titles=['Latent (Prediction)', 'Latent (Target)', 'Test input'],
            cticks=None,
            clim=[(y.min(), y.max())] * 2 + [(x.min(), x.max())],
        )
    else:
        _plot_uncertainty_image(mean, std, num_t, num_x)

    if plotter is not None:
        plotter.plot_double_bar(
            params_out[0:1],
            ['l1', 'l2', 'sens', 'decay', 'diff'],
            ground_truths=params[index:index + 1],
            figsize=(5, 3)
        )
    out = mean.squeeze()
    y_target = y.squeeze()

    print(params_out[0].detach(), params[index])
    print(F.mse_loss(out, y_target))
[46]:
# Reset the log of prediction times (in minutes) that show_result appends to.
times = []
[52]:
# Plotter shared by the show_result calls below for the parameter bar chart.
plotter = Plotter2d(model, np.array(['']))

show_result(model=model, loader=test_loader, plotter=plotter)
torch.Size([50, 41, 41, 3])
t 0.00033342440923055013
tensor([ 0.3696,  0.3898,  0.7117,  0.2007, -0.0008]) tensor([0.3000, 0.4000, 1.0000, 0.2650, 0.0208])
tensor(0.1168)
../../_images/notebooks_nn_lfo_2d_9_1.png
../../_images/notebooks_nn_lfo_2d_9_2.png
[55]:
# Mean and (population) standard deviation of the recorded prediction times, in minutes.
print(np.mean(times), np.std(times))
0.00032507777214050296 1.5959767695305815e-05
[10]:
# Evaluate the same sample index at subsampled and then full resolution.
i = 10
for res_loader in (subsampled_loader, high_res_loader):
    show_result(model, res_loader, plotter=plotter, index=i)
tensor([ 0.5178,  0.6491,  1.7789,  0.3636, -0.0308]) tensor([0.3000, 0.3000, 1.0000, 0.1000, 0.0100])
tensor(0.4199)
tensor([ 1.8254,  2.4872,  7.0254,  1.3237, -0.1865]) tensor([0.3000, 0.3000, 1.0000, 0.1000, 0.0100])
tensor(0.4132)
../../_images/notebooks_nn_lfo_2d_11_1.png
../../_images/notebooks_nn_lfo_2d_11_2.png
../../_images/notebooks_nn_lfo_2d_11_3.png
../../_images/notebooks_nn_lfo_2d_11_4.png
[206]:
# Number of samples in the trainer's training set (use len(), not __len__()).
print(len(trainer.data_loader.dataset))
3950
[199]:
from types import SimpleNamespace

from scipy.interpolate import interp1d

# `lafomo` is the legacy name of this package; use `alfi` like the rest of the notebook.
from alfi.datasets import DrosophilaSpatialTranscriptomics
from alfi.utilities.data import generate_neural_dataset_2d

# Apply the trained operator to real Drosophila Kr gene data.
# When True, build the neural-operator dataset by hand from the raw
# (nn_format=False) data; otherwise let the dataset class format it.
build_from_orig_data = True
if build_from_orig_data:
    raw_dros = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data', scale=True, scale_tx=True, nn_format=False)
    # There are no ground-truth parameters for real data, so use -1 placeholders.
    dummy_params = torch.tensor([-1.] * 5).unsqueeze(0)
    train, test = generate_neural_dataset_2d(raw_dros.orig_data.unsqueeze(0), dummy_params, 1, 0)
    # NOTE(review): undocumented rescaling of input channel 0 by 6 -- confirm intent.
    train[0][0][..., 0] /= 6
    print('train', train[0][0][..., 0].shape)
    # Minimal stand-in exposing the `train_data` attribute used below.
    dataset = SimpleNamespace(train_data=train)
else:
    dataset = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data', scale=True, scale_tx=True, nn_format=True)

# The sample has few points along axis 0 (8 here) and many along axis 1 (64 here).
# Linearly interpolate along axis 0 so the grid becomes square.
sample = dataset.train_data[0]
inputs, targets = sample[0], sample[1]
src_grid = np.linspace(0, 1, inputs.shape[0])
dst_grid = np.linspace(0, 1, inputs.shape[1])
x = interp1d(src_grid, inputs, axis=0)(dst_grid)
x1 = interp1d(src_grid, targets, axis=0)(dst_grid)
dataset.train_data[0] = (
    torch.tensor(x, dtype=torch.float32),
    torch.tensor(x1, dtype=torch.float32),
    sample[2],
)

dros_loader = DataLoader(dataset.train_data)
show_result(model, dros_loader, plotter=plotter)
C:\Users\Jacob\Documents\proj\lafomo\lafomo\utilities\data.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  grid = torch.tensor(grid.reshape(1, s1, s2, 2), dtype=torch.float)
C:\Users\Jacob\Documents\proj\lafomo\lafomo\utilities\data.py:48: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  data = torch.tensor(data.reshape(data.shape[0], s1, s2, 4), dtype=torch.float)
params torch.Size([1, 5])
train torch.Size([8, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
tensor([0.3168, 0.3087, 1.3749, 0.2380, 0.0128]) tensor([-1., -1., -1., -1., -1.])
tensor(0.3743)
../../_images/notebooks_nn_lfo_2d_13_2.png
../../_images/notebooks_nn_lfo_2d_13_3.png
[201]:
# Evaluate the Drosophila prediction: SMSE, Q2 and prediction wall time.
import time

from alfi.utilities.torch import smse, q2

x, y, params = next(iter(dros_loader))
# Time only the prediction (not data loading); perf_counter is meant for intervals.
t0 = time.perf_counter()
p_y_pred, params_out = model.predict_f(x)
t1 = time.perf_counter()
mean = p_y_pred.mean[0]
y = y.squeeze()
print(smse(y.view(-1), mean.view(-1)).mean())
print(q2(y.view(-1), mean.view(-1)).mean())
print(mean.shape, y.shape)
print('prediction time (s):', t1 - t0)
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
tensor(0.3647)
tensor(0.6352)
torch.Size([64, 64]) torch.Size([64, 64])
0.033008575439453125
[56]:
# Persist the trained weights; reload with model.load_state_dict(torch.load(save_path)).
save_path = './saved_model0406.pt'
torch.save(model.state_dict(), save_path)
[167]:
# Super-resolution: run the model on one sample at subsampled and full resolution.
# Hand-picked sample indices used for figures: 0, 3, 6, 8, 48
i = 48
# i = torch.randint(50, torch.Size([1]))[0]  # alternatively, a random sample
print(i)

x_sub, y_sub, params_sub = subsampled_loader.dataset[i]
x, y, params = high_res_loader.dataset[i]
# Inference only, so no autograd graph is needed.
with torch.no_grad():
    out = model(x.unsqueeze(0))[0]
    out_sub = model(x_sub.unsqueeze(0))[0]

num_t, num_x = x.shape[0], x.shape[1]
num_t_sub, num_x_sub = x_sub.shape[0], x_sub.shape[1]

input_field = x[..., 0]
output_field = out[0, ..., 0]
plot_spatiotemporal_data(
    [
        input_field.view(num_t, num_x).t(),
        out_sub[0, ..., 0].view(num_t_sub, num_x_sub).t(),
        output_field.view(num_t, num_x).t(),
        y[..., 0].view(num_t, num_x).t(),
    ],
    extent, nrows=1, ncols=4, figsize=(12, 4),
    # The input uses its own colour range. All three outputs share the range of the
    # plotted channel (0) of the super-resolution output, not the range over all channels.
    clim=[(input_field.min(), input_field.max())] + [(output_field.min(), output_field.max())] * 3,
    titles=['Test input', 'Low-resolution Output', 'Super-resolution Output', 'Target Output']
)
plt.tight_layout()
plt.savefig(f'toy{int(i)}.pdf', **tight_kwargs)
48
../../_images/notebooks_nn_lfo_2d_16_1.png
[19]:
# A single fixed toy sample from the homogeneous reaction-diffusion dataset.
dataset = HomogeneousReactionDiffusion('../../../data', one_fixed_sample=True, nn_format=True, ntest=0)

toy_loader = DataLoader(dataset.train_data)
# Pass the shared plotter so show_result can draw the parameter bar chart.
show_result(model, toy_loader, plotter=plotter)
params torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 41, 41, 2])
41 41 torch.Size([1, 41, 41, 3]) torch.Size([41, 41])
tensor([0.1800, 0.2280, 0.1824, 0.0113]) tensor([0.3000, 0.3000, 0.1000, 0.0100])
tensor(0.1076)
../../_images/notebooks_nn_lfo_2d_17_1.png
[ ]:

[ ]: