[1]:
from experiments.partial import build_partial, plot_partial, pretrain_partial
from pathlib import Path
import numpy as np

from alfi.datasets import DrosophilaSpatialTranscriptomics, HomogeneousReactionDiffusion
from alfi.trainers import PartialPreEstimator
from alfi.plot import plot_spatiotemporal_data
from alfi.plot.misc import plot_variational_dist
from alfi.utilities.torch import spline_interpolate_gradient, softplus
from alfi.utilities.data import dros_ground_truth
from alfi.models import TrainMode

from matplotlib import pyplot as plt
import torch
[2]:
drosophila = True
if drosophila:
    gene = 'gt'
    dataset = DrosophilaSpatialTranscriptomics(
        gene=gene, data_dir='../../../data', scale=True, disc=1)
    params = dict(lengthscale=[20, 10],
                  **dros_ground_truth(gene),
                  parameter_grad=False,
                  warm_epochs=-1,
                  natural=True,
                  zero_mean=True,
                  clamp=True)
    disc = dataset.disc
else:
    data = 'toy-spatial'
    dataset = HomogeneousReactionDiffusion(data_dir='../../../data')
    params = dict(lengthscale=0.2,
                  sensitivity=1,
                  decay=0.1,
                  diffusion=0.01,
                  parameter_grad=False,
                  warm_epochs=-1,
                  natural=False,
                  clamp=False)
    disc = 1
model_name = '0savedmodel'
# model_name = 'epoch165'
# model_name = 'model_with_3'
lfm, trainer, plotter = build_partial(
    dataset,
    params)#,
    # reload=f'../../../experiments/{data}/partial/{model_name}')
lfm.gp_model.covar_module.lengthscale
tx torch.Size([2, 512])
x dp is set to tensor(1., dtype=torch.float64)
[2]:
tensor([[[20.0000, 10.0000]]], dtype=torch.float64, grad_fn=<SoftplusBackward>)
[3]:
plot_partial(dataset, lfm, trainer, plotter, Path('./'), params)
'cmunrm.otf' can not be subsetted into a Type 3 font. The entire font will be embedded in the output.
../../_images/notebooks_pde_pde_minimal_2_1.png
[4]:
pretrain_partial(dataset, lfm, trainer, params);
/Users/jacob/Documents/proj/torchcubicspline/torchcubicspline/interpolate.py:277: UserWarning: input value tensor is non-contiguous, this will lower the performance due to extra data copy when converting non-contiguous tensor to contiguous, please use contiguous input value tensor if possible (Triggered internally at  ../aten/src/ATen/native/BucketizationUtils.h:25.)
  index = torch.bucketize(t.detach(), self._t) - 1
num training 153
Epoch 001/080 - Loss: 0.33 (0.33 0.00) kernel: [[[19.97002649  9.97001707]]]
Epoch 011/080 - Loss: 0.29 (0.29 0.00) kernel: [[[19.83236085  9.88488189]]]
Epoch 021/080 - Loss: 0.26 (0.26 0.00) kernel: [[[19.84878729  9.88612257]]]
Epoch 031/080 - Loss: 0.22 (0.22 0.00) kernel: [[[19.84549775  9.88469033]]]
Epoch 041/080 - Loss: 0.18 (0.18 0.00) kernel: [[[19.81529531  9.8859996 ]]]
Epoch 051/080 - Loss: 0.15 (0.15 0.00) kernel: [[[19.80628765  9.84551196]]]
Epoch 061/080 - Loss: 0.11 (0.11 0.00) kernel: [[[19.95258065  9.70803974]]]
Epoch 071/080 - Loss: 0.07 (0.07 0.00) kernel: [[[20.17451815  9.52854874]]]
[5]:
orig_data = dataset.orig_data.squeeze().t()
num_t_orig = orig_data[:, 0].unique().shape[0]
num_x_orig = orig_data[:, 1].unique().shape[0]

tx = trainer.tx
num_t = tx[0, :].unique().shape[0]
num_x = tx[1, :].unique().shape[0]
y_target = trainer.y_target[0]
ind = lfm.inducing_points[0]
print(ind.shape)
plt.scatter(ind[:, 0], ind[:, 1])
print(trainer.train_mask.shape)
torch.Size([426, 2])
torch.Size([512])
../../_images/notebooks_pde_pde_minimal_4_1.png
[7]:
lfm.config.num_samples = 15
trainer.plot_outputs = True
lfm.set_mode(TrainMode.GRADIENT_MATCH)
trainer.train(10, report_interval=1);
Mean output variance: 0.03320894951911319
Test loss: 1.4794096017496434
prot Q2: 0.472
prot CA: 0.125
mrna Q2: 0.364
mrna CA: 0.260
Epoch 031/040 - Loss: 1.36 (0.45 0.45) kernel: [[[20.42225991  9.73779057]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03952920682363361
Test loss: 1.4994630875902755
prot Q2: 0.471
prot CA: 0.088
mrna Q2: 0.367
mrna CA: 0.260
Epoch 032/040 - Loss: 1.36 (0.45 0.46) kernel: [[[20.40797207  9.73344118]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.028518628460504944
Test loss: 1.4642968688948381
prot Q2: 0.472
prot CA: 0.096
mrna Q2: 0.363
mrna CA: 0.252
Epoch 033/040 - Loss: 1.34 (0.44 0.45) kernel: [[[20.39254882  9.72880339]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03655852999519911
Test loss: 1.5135292376463787
prot Q2: 0.442
prot CA: 0.092
mrna Q2: 0.366
mrna CA: 0.240
Epoch 034/040 - Loss: 1.38 (0.47 0.45) kernel: [[[20.37689017  9.72370622]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.0342211338421374
Test loss: 1.4491800478537793
prot Q2: 0.484
prot CA: 0.104
mrna Q2: 0.370
mrna CA: 0.229
Epoch 035/040 - Loss: 1.31 (0.43 0.45) kernel: [[[20.36080458  9.71814265]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03486932254055179
Test loss: 1.453298349301025
prot Q2: 0.477
prot CA: 0.113
mrna Q2: 0.372
mrna CA: 0.219
Epoch 036/040 - Loss: 1.33 (0.44 0.45) kernel: [[[20.34414317  9.71221889]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03088667803164573
Test loss: 1.4245052201165138
prot Q2: 0.497
prot CA: 0.115
mrna Q2: 0.376
mrna CA: 0.213
Epoch 037/040 - Loss: 1.30 (0.42 0.46) kernel: [[[20.32676365  9.70607156]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03064064191488642
Test loss: 1.4275427093133957
prot Q2: 0.495
prot CA: 0.111
mrna Q2: 0.379
mrna CA: 0.217
Epoch 038/040 - Loss: 1.31 (0.42 0.47) kernel: [[[20.30919435  9.69944014]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.03864007083503944
Test loss: 1.4195081004421124
prot Q2: 0.510
prot CA: 0.182
mrna Q2: 0.379
mrna CA: 0.211
Epoch 039/040 - Loss: 1.29 (0.41 0.47) kernel: [[[20.29108383  9.69284169]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
Mean output variance: 0.027265226956193447
Test loss: 1.4030025474400256
prot Q2: 0.512
prot CA: 0.135
mrna Q2: 0.380
mrna CA: 0.217
Epoch 040/040 - Loss: 1.29 (0.40 0.48) kernel: [[[20.27248088  9.68592694]]] s: 0.10020655891674712 dec: 0.10020655891674712 diff: 0.015900012764781094
../../_images/notebooks_pde_pde_minimal_5_1.png
[15]:
# for key in trainer.parameter_trace.keys():
#     params = torch.stack(trainer.parameter_trace[key])
#     for i in range(1, params.ndim):
#         params = params.mean(-1)
    # plt.figure()
    # plt.plot(params)
[softplus(param) for param in lfm.fenics_parameters]
[15]:
[tensor([[0.0970]], dtype=torch.float64),
 tensor([[0.0764]], dtype=torch.float64),
 tensor([[0.0015]], dtype=torch.float64)]
[15]:
from alfi.plot import tight_kwargs
plot_partial(dataset, lfm, trainer, plotter, Path('./'), params)

# plt.savefig(Path('./') / f'kinetics-{gene}.pdf', **tight_kwargs)
'cmunrm.otf' can not be subsetted into a Type 3 font. The entire font will be embedded in the output.
../../_images/notebooks_pde_pde_minimal_7_1.png
[ ]:
from alfi.utilities.torch import q2, cia
f = lfm(tx)
f_mean = f.mean.detach()
f_var = f.variance.detach()
y_target = trainer.y_target[0]
def cia(y_test, f_mean, f_var):
    return ((y_test >= (f_mean - 1 * f_var.sqrt())) &
            (y_test <= (f_mean + 1 * f_var.sqrt()))).double().mean()

print(f_mean.shape, y_target.shape, f_var.shape)
print('Q2', q2(y_target.squeeze(), f_mean.squeeze()))
print('CA', cia(y_target.squeeze(), f_mean.squeeze(), f_var.squeeze()))
[ ]:
gp = lfm.gp_model(tx.t())
lf_target = orig_data[trainer.t_sorted, 2]
f_mean = gp.mean.detach().view(num_t, num_x)[::disc].reshape(-1)
f_var = gp.variance.detach().view(num_t, num_x)[::disc].reshape(-1)

print('Q2', q2(lf_target.squeeze(), f_mean.squeeze()))
print('CA', cia(lf_target.squeeze(), f_mean.squeeze(), f_var.squeeze()))

[16]:
lfm.save('./kr-2005')
[ ]:
lfm2, trainer2, plotter2 = build_partial(
    dataset,
    params,
    reload='./kr-1205')
[ ]:
gp = lfm2.gp_model(tx.t())
lf_target = orig_data[trainer.t_sorted, 2]
f_mean = gp.mean.detach()
f_var = gp.variance.detach()

print('Q2', q2(lf_target.squeeze(), f_mean.squeeze()))
print('CA', cia(lf_target.squeeze(), f_mean.squeeze(), f_var.squeeze()))
[71]:
to_plot = list()
clims = list()
means = {'kr': False, 'kni': True, 'gt': False}
for gene in ['kr', 'kni', 'gt']:
    dataset = DrosophilaSpatialTranscriptomics(
        gene=gene, data_dir='../../../data', scale=True)
    params = dict(lengthscale=10,
                  **dros_params[gene],
                  parameter_grad=False,
                  warm_epochs=-1,
                  natural=False,
                  zero_mean=means[gene],
                  clamp=True)
    disc = dataset.disc
    lfm, trainer, plotter = build_partial(
        dataset,
        params,
        reload=f'../../../experiments/{gene}0')

    lfm.eval()
    tx = trainer.tx
    num_t = tx[0, :].unique().shape[0]
    num_x = tx[1, :].unique().shape[0]
    orig_data = dataset.orig_data.squeeze().t()
    num_t_orig = orig_data[:, 0].unique().shape[0]
    num_x_orig = orig_data[:, 1].unique().shape[0]
    disc = dataset.disc if hasattr(dataset, 'disc') else 1

    f = lfm(tx, step=disc)
    f_mean = f.mean.detach()
    f_var = f.variance.detach()
    y_target = trainer.y_target[0]
    ts = tx[0, :].unique().sort()[0].numpy()
    xs = tx[1, :].unique().sort()[0].numpy()
    extent = [ts[0], ts[-1], xs[0], xs[-1]]

    l_target = orig_data[trainer.t_sorted, 2]
    l = lfm.gp_model(tx.t())
    l_mean = l.mean.detach()
    to_plot.append([
            l_mean.view(num_t, num_x).t(),
            l_target.view(num_t_orig, num_x_orig).t(),
            f_mean.view(num_t_orig, num_x_orig).t(),
            y_target.view(num_t_orig, num_x_orig).detach().t(),
    ])

    clims.append(([(l_target.min(), l_target.max())] * 2 + [(y_target.min(), y_target.max())] * 2))
tx torch.Size([2, 512])
x dp is set to tensor(1., dtype=torch.float64)
tx torch.Size([2, 456])
x dp is set to tensor(1., dtype=torch.float64)
tx torch.Size([2, 512])
x dp is set to tensor(1., dtype=torch.float64)
[72]:
clim = [*clims[0], *clims[1], *clims[2]]
plots = [*to_plot[0], *to_plot[1], *to_plot[2]]

grid = plot_spatiotemporal_data(
    plots,
    extent,
    nrows=3,
    ncols=4,
    titles=[
        'Latent (Prediction)', 'Latent (Target)',
        'Output (Prediction)', 'Output (Target)',] + ['']*8,
    cticks=None,  # [0, 100, 200]
    clim=clim
)
plt.gca().get_figure().set_size_inches(15, 7)
for i in range(8):
    grid[i].set_xticks([])
    grid[i].set_xlabel(None)

for i in range(8, 12):
    grid[i].set_xticks([np.ceil(extent[0]), np.floor(extent[1])])
plt.savefig('dros_combined.pdf', **tight_kwargs)
'cmunrm.otf' can not be subsetted into a Type 3 font. The entire font will be embedded in the output.
../../_images/notebooks_pde_pde_minimal_14_1.png
[58]:

[58]:
<mpl_toolkits.axes_grid1.mpl_axes.Axes at 0x7fc687e41e10>
[46]:

[ ]: