LFO 2D reaction-diffusion
For this example, we use a dataset generated by the reaction-diffusion equation LFM.
[1]:
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from alfi.datasets import ReactionDiffusion, HomogeneousReactionDiffusion
from alfi.models import NeuralOperator, NeuralLFM
from alfi.trainers import NeuralOperatorTrainer
from alfi.plot import plot_spatiotemporal_data, tight_kwargs, Plotter2d
[11]:
ntest = 50
dataset = ReactionDiffusion('../../../data',
nn_format=True,
max_n=4000,
ntest=ntest,
sub=1)
batch_size = 50
train_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=True)
if ntest > 0:
test_loader = DataLoader(dataset.test_data, batch_size=ntest, shuffle=True)
i = 3
tx = dataset.data[i][0]
lf = dataset.data[i][1]
ts = tx[:, :, 1].unique().sort()[0].numpy()
xs = tx[:, :, 2].unique().sort()[0].numpy()
extent = [ts[0], ts[-1], xs[0], xs[-1]]
plot_spatiotemporal_data(
[
tx[:, :, 0].t(),
lf[:, :, 0].t(),
],
extent, nrows=1, ncols=2
)
[11]:
<mpl_toolkits.axes_grid1.axes_grid.ImageGrid at 0x1e787cc8048>
[12]:
dataset = HomogeneousReactionDiffusion('../../../data',
one_fixed_sample=False,
highres=True,
nn_format=True,
sub=1, ntest=0)
high_res_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=False)
dataset = HomogeneousReactionDiffusion('../../../data',
one_fixed_sample=False,
highres=True,
nn_format=True,
sub=2, ntest=0)
subsampled_loader = DataLoader(dataset.train_data, batch_size=batch_size, shuffle=False)
i = torch.randint(50, torch.Size([1]))[0]
tx = dataset.data[i][0]
plot_spatiotemporal_data(
[torch.tensor(tx[:, :, 0])],
extent, nrows=1, ncols=1
)
block_dim = 2
learning_rate = 1e-3
modes = 10
width = 38
in_channels = 3
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\ipykernel_launcher.py:18: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
[13]:
model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
# r_dim = z_dim = 16
# model = NeuralLFM(block_dim, in_channels,
# modes, width, r_dim, z_dim)
print(model.count_params())
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
trainer = NeuralOperatorTrainer(model, [optimizer], train_loader, test_loader)
print(train_loader.dataset[0][0].shape)
25487735
torch.Size([41, 41, 3])
[5]:
state_dict = torch.load('../nn/saved_model1505.pt')
model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
print(model.count_params())
model.load_state_dict(state_dict)
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
trainer = NeuralOperatorTrainer(model, [optimizer], train_loader, test_loader)
---------------------------------------------------------------------------
FileNotFoundError Traceback (most recent call last)
<ipython-input-5-8f76473f04dc> in <module>
----> 1 state_dict = torch.load('../nn/saved_model1505.pt')
2 model = NeuralOperator(block_dim, in_channels, 2, modes, width, num_layers=4)
3 print(model.count_params())
4 model.load_state_dict(state_dict)
5 optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, weight_decay=1e-4)
~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in load(f, map_location, pickle_module, **pickle_load_args)
577 pickle_load_args['encoding'] = 'utf-8'
578
--> 579 with _open_file_like(f, 'rb') as opened_file:
580 if _is_zipfile(opened_file):
581 # The zipfile reader is going to advance the current file position.
~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in _open_file_like(name_or_buffer, mode)
228 def _open_file_like(name_or_buffer, mode):
229 if _is_path(name_or_buffer):
--> 230 return _open_file(name_or_buffer, mode)
231 else:
232 if 'w' in mode:
~\miniconda3\envs\wishart\lib\site-packages\torch\serialization.py in __init__(self, name, mode)
209 class _open_file(_opener):
210 def __init__(self, name, mode):
--> 211 super(_open_file, self).__init__(open(name, mode))
212
213 def __exit__(self, *args):
FileNotFoundError: [Errno 2] No such file or directory: '../nn/saved_model1505.pt'
[14]:
trainer.train(10);
Epoch 001/010 - Loss: 3.06 (1.65 1.46 1.44 0.16 0.02)
Epoch 002/010 - Loss: 1.59 (1.62 1.40 1.38 0.02 0.02)
Epoch 003/010 - Loss: 1.49 (1.50 1.32 1.21 0.02 0.03)
Epoch 004/010 - Loss: 1.19 (1.26 1.04 0.91 0.02 0.03)
Epoch 005/010 - Loss: 0.85 (0.99 0.71 0.70 0.01 0.03)
Epoch 006/010 - Loss: 0.67 (0.94 0.53 0.61 0.01 0.03)
Epoch 007/010 - Loss: 0.55 (0.82 0.44 0.51 0.01 0.03)
Epoch 008/010 - Loss: 0.49 (0.82 0.38 0.47 0.01 0.04)
Epoch 009/010 - Loss: 0.46 (0.74 0.36 0.44 0.01 0.03)
Epoch 010/010 - Loss: 0.39 (0.68 0.30 0.40 0.01 0.03)
[45]:
import time
def show_result(model, loader, plot_uncertainty=False, plotter=None, index=0):
x, y, params = next(iter(loader))
print(x.shape)
t0 = time.time()
p_y_pred, params_out = model.predict_f(x[0:1])
tend = time.time()
times.append((tend - t0)/60)
print('t', (tend - t0)/60)
mean = p_y_pred.mean[index]
std = p_y_pred.variance.sqrt()[index]
num_t = x.shape[1]
num_x = x.shape[2]
x = x[index][...,0]
y = y[index]
if not plot_uncertainty:
plot_spatiotemporal_data(
[
mean.detach().view(num_t, num_x).t(),
y.view(num_t, num_x).t(), #num_t
x.view(num_t, num_x).t()
],
extent, nrows=1, ncols=3,
titles=['Latent (Prediction)', 'Latent (Target)', 'Test input'],
cticks=None, # [0, 100, 200]
clim=[(y.min(), y.max())] * 2 + [(x.min(), x.max())],
)
else:
a = np.zeros((41, 41, 4))
a[:, :, 0] = mean
a[:, :, 0] = (a[:, :, 0] - a[:, :, 0].min()) / (a[:, :, 0].max() - a[:, :, 0].min())
a[:, :, 3] = std
a[:, :, 3] = 1-(a[:, :, 3] - a[:, :, 3].min()) / (a[:, :, 3].max() - a[:, :, 3].min())
plt.imshow(a, origin='lower', extent=extent)
plt.colorbar()
plotter.plot_double_bar(
params_out[0:1],
['l1', 'l2', 'sens', 'decay', 'diff'],
ground_truths=params[0:1],
figsize=(5, 3)
)
out = mean.squeeze()
y_target = y.squeeze()
print(params_out[0].detach(), params[0])
print(F.mse_loss(out, y_target))
# from lafomo.utilities.torch import smse, q2
# print(y.shape, f_mean_test.shape)
# print(smse(y_target, f_mean_test).shape)
[46]:
times = list()
[52]:
plotter = Plotter2d(model, np.array(['']))
show_result(model, test_loader, plotter=plotter)
torch.Size([50, 41, 41, 3])
t 0.00033342440923055013
tensor([ 0.3696, 0.3898, 0.7117, 0.2007, -0.0008]) tensor([0.3000, 0.4000, 1.0000, 0.2650, 0.0208])
tensor(0.1168)
[55]:
print(np.array(times).mean(), np.array(times).std())
0.00032507777214050296 1.5959767695305815e-05
[10]:
i = 10
show_result(model, subsampled_loader, plotter=plotter, index=i)
show_result(model, high_res_loader, plotter=plotter, index=i)
tensor([ 0.5178, 0.6491, 1.7789, 0.3636, -0.0308]) tensor([0.3000, 0.3000, 1.0000, 0.1000, 0.0100])
tensor(0.4199)
tensor([ 1.8254, 2.4872, 7.0254, 1.3237, -0.1865]) tensor([0.3000, 0.3000, 1.0000, 0.1000, 0.0100])
tensor(0.4132)
[206]:
print(trainer.data_loader.dataset.__len__())
3950
[199]:
from lafomo.datasets import DrosophilaSpatialTranscriptomics
from scipy.interpolate import interp1d
weird = True
if weird:
from lafomo.utilities.data import generate_neural_dataset_2d
temp = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data', scale=True, scale_tx=True, nn_format=False)
params = torch.tensor([-1.]*5).unsqueeze(0)
train, test = generate_neural_dataset_2d(temp.orig_data.unsqueeze(0), params, 1, 0)
train[0][0][..., 0] /= 6
print('train', train[0][0][..., 0].shape)
class Fuckery():
def __init__(self):
self.train_data = train
dataset = Fuckery()
else:
dataset = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data', scale=True, scale_tx=True, nn_format=True)
x = interp1d(np.linspace(0, 1, 8), dataset.train_data[0][0], axis=0)(np.linspace(0, 1, 64))
x1 = interp1d(np.linspace(0, 1, 8), dataset.train_data[0][1], axis=0)(np.linspace(0, 1, 64))
dataset.train_data[0] = (
torch.tensor(x, dtype=torch.float32),#dataset.train_data[0][0],#.permute(1, 0, 2),
torch.tensor(x1, dtype=torch.float32),
dataset.train_data[0][2]
)
dros_loader = DataLoader(dataset.train_data)
show_result(model, dros_loader, plotter=plotter)
C:\Users\Jacob\Documents\proj\lafomo\lafomo\utilities\data.py:44: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
grid = torch.tensor(grid.reshape(1, s1, s2, 2), dtype=torch.float)
C:\Users\Jacob\Documents\proj\lafomo\lafomo\utilities\data.py:48: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
data = torch.tensor(data.reshape(data.shape[0], s1, s2, 4), dtype=torch.float)
params torch.Size([1, 5])
train torch.Size([8, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
tensor([0.3168, 0.3087, 1.3749, 0.2380, 0.0128]) tensor([-1., -1., -1., -1., -1.])
tensor(0.3743)
[201]:
# CALCULATING SMSE
import time
t0 = time.time()
x, y, params = next(iter(dros_loader))
p_y_pred, params_out = model.predict_f(x)
t1 = time.time()
mean = p_y_pred.mean[0]
std = p_y_pred.variance.sqrt()[0]
y = y.squeeze()
from alfi.utilities.torch import smse, q2
print(smse(y.view(-1), mean.view(-1)).mean())
print(q2(y.view(-1), mean.view(-1)).mean())
print(mean.shape, y.shape)
print(t1 - t0)
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
torch.Size([1, 64, 64]) torch.Size([1, 64, 64])
tensor(0.3647)
tensor(0.6352)
torch.Size([64, 64]) torch.Size([64, 64])
0.033008575439453125
[56]:
torch.save(model.state_dict(), './saved_model0406.pt')
[167]:
# 0, 3, 6, 8, 48
i = 48
# i = torch.randint(50, torch.Size([1]))[0]
print(i)
x_sub, y_sub, params_sub = subsampled_loader.dataset[i]
x, y, params = high_res_loader.dataset[i]
out, _ = model(x.unsqueeze(0))
out_sub, _ = model(x_sub.unsqueeze(0))
num_t = x.shape[0]
num_x = x.shape[1]
num_t_sub = x_sub.shape[0]
num_x_sub = x_sub.shape[1]
plot_spatiotemporal_data(
[
x[..., 0].view(num_t, num_x).t(),
out_sub[0, ..., 0].detach().view(num_t_sub, num_x_sub).t(),
out[0, ..., 0].detach().view(num_t, num_x).t(),
y[..., 0].view(num_t, num_x).t(),
# y_sub[..., 0].view(num_t_sub, num_x_sub).t(),
],
extent, nrows=1, ncols=4, figsize=(12, 4),
clim=[(x[...,0].min(), x[...,0].max())] + [(out[0].min(), out[0].max())] * 3,
titles=['Test input', 'Low-resolution Output', 'Super-resolution Output', 'Target Output']
)
plt.tight_layout()
out = out.squeeze()
y_target = y.squeeze()
plt.savefig('toy48.pdf', **tight_kwargs)
48
[19]:
dataset = HomogeneousReactionDiffusion('../../../data', one_fixed_sample=True, nn_format=True, ntest=0)
toy_loader = DataLoader(dataset.train_data)
show_result(model, toy_loader)
params torch.Size([1, 4])
torch.Size([1, 4])
torch.Size([1, 41, 41, 2])
41 41 torch.Size([1, 41, 41, 3]) torch.Size([41, 41])
tensor([0.1800, 0.2280, 0.1824, 0.0113]) tensor([0.3000, 0.3000, 0.1000, 0.0100])
tensor(0.1076)
[ ]:
[ ]: