Variational Inference: Cascaded ODE
In this example, we use a cascaded system of ODEs.
The required dataset is small and is provided in preprocessed form; below it is loaded with `P53Data` from the `../../../data` directory.
[1]:
import torch
import numpy as np
from gpytorch.optim import NGD
from torch.optim import Adam
from torch.nn import Parameter
from matplotlib import pyplot as plt
from alfi.datasets import P53Data
from alfi.configuration import VariationalConfiguration
from alfi.models import OrdinaryLFM, generate_multioutput_rbf_gp
from alfi.plot import Plotter1d, Colours
from alfi.trainers import VariationalTrainer
Let’s start by importing our dataset…
[2]:
dataset = P53Data(replicate=0, data_dir='../../../data')
num_genes = 5  # number of observed outputs (genes)
num_tfs = 1  # number of latent forces

# Quick look at the observed data of every gene. Iterate over num_genes rather
# than a hard-coded 5 so the preview stays consistent with the model size.
# NOTE(review): dataset[i] presumably yields (times, values) for gene i — confirm.
plt.figure(figsize=(4, 2))
for i in range(num_genes):
    plt.plot(dataset[i][1])
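We use the cascaded ordinary differential equation (ODE) model

$$\frac{\mathrm{d}y_j(t)}{\mathrm{d}t} = b_j + s_j f(t) - d_j y_j(t)$$

$$\frac{\mathrm{d}f(t)}{\mathrm{d}t} = g(t) - \lambda f(t)$$

$$g(t) \sim \mathcal{GP}\big(0, k(t, t')\big)$$

where $b_j$, $s_j$ and $d_j$ are the basal rate, sensitivity and decay rate of gene $j$, and $\lambda$ is the protein decay rate of $f$.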
Since this is an ODE, we inherit from the OrdinaryLFM class.
[3]:
from gpytorch.constraints import Positive
class TranscriptionLFM(OrdinaryLFM):
    """Cascaded transcriptional ODE driven by a GP latent force.

    The ODE state has two channels per output ``j``: channel 0 holds ``f`` and
    channel 1 holds ``y_j``, evolving as

        df/dt   = g(t) - lambda * f(t)
        dy_j/dt = b_j + s_j * f(t) - d_j * y_j(t)

    where ``g`` is read from ``self.f`` at the current time index. The kinetic
    parameters are stored unconstrained as ``raw_*`` Parameters and mapped to
    positive values through ``self.positivity`` (a ``Positive`` constraint).
    """

    def __init__(self, num_outputs, gp_model, config: VariationalConfiguration, **kwargs):
        super().__init__(num_outputs, gp_model, config, **kwargs)
        self.positivity = Positive()
        # Per-output parameters are (num_outputs, 1) so they broadcast over the
        # per-output state slices used in odefunc.
        per_output = torch.Size([self.num_outputs, 1])
        self.raw_decay = Parameter(0.1 + torch.rand(per_output, dtype=torch.float64))
        self.raw_basal = Parameter(torch.rand(per_output, dtype=torch.float64))
        self.raw_sensitivity = Parameter(0.2 + torch.rand(per_output, dtype=torch.float64))
        # A single protein decay rate shared by all outputs.
        self.raw_protein_decay = Parameter(0.1 + torch.rand(torch.Size([1, 1]), dtype=torch.float64))

    def _set_positive(self, raw_param, value):
        """Store the positive ``value`` into ``raw_param`` in place.

        ``value`` is mapped to raw space with ``self.positivity.inverse_transform``
        and copied into the existing Parameter (broadcasting as needed). Copying
        instead of rebinding is required: ``nn.Module`` refuses to replace a
        registered Parameter with a plain Tensor, and rebinding would also
        detach the Parameter from any optimizer already holding it.
        """
        value = torch.as_tensor(value, dtype=raw_param.dtype, device=raw_param.device)
        with torch.no_grad():
            raw_param.copy_(self.positivity.inverse_transform(value))

    @property
    def decay_rate(self):
        """Positive per-output decay rates ``d``."""
        return self.positivity.transform(self.raw_decay)

    @decay_rate.setter
    def decay_rate(self, value):
        self._set_positive(self.raw_decay, value)

    @property
    def protein_decay_rate(self):
        """Positive protein decay rate ``lambda`` (shared, shape (1, 1))."""
        return self.positivity.transform(self.raw_protein_decay)

    @protein_decay_rate.setter
    def protein_decay_rate(self, value):
        self._set_positive(self.raw_protein_decay, value)

    @property
    def basal_rate(self):
        """Positive per-output basal rates ``b``."""
        return self.positivity.transform(self.raw_basal)

    @basal_rate.setter
    def basal_rate(self, value):
        self._set_positive(self.raw_basal, value)

    @property
    def sensitivity(self):
        """Positive per-output sensitivities ``s``."""
        return self.positivity.transform(self.raw_sensitivity)

    @sensitivity.setter
    def sensitivity(self, value):
        # Uses the same positivity constraint as the getter (the original
        # referenced a nonexistent ``decay_constraint`` attribute).
        self._set_positive(self.raw_sensitivity, value)

    def initial_state(self):
        """Return the initial ODE state of shape (num_outputs, 2).

        Channel 0 (``f``) starts at zero; channel 1 (``y``) starts at ``b / d``,
        the steady state of the output equation when ``f = 0``.
        """
        h0 = torch.cat([
            torch.zeros_like(self.basal_rate),
            self.basal_rate / self.decay_rate,
        ], -1)
        return h0

    def odefunc(self, t, h):
        """Right-hand side of the cascaded ODE.

        ``h`` is 3-D with last dimension 2 (channel 0 = f, channel 1 = y),
        presumably (num_samples, num_outputs, 2). Returns dh/dt with the same
        layout.

        ``self.f``, ``self.t_index``, ``self.last_t`` and ``self.nfe`` are not
        set in this class; presumably ``OrdinaryLFM`` maintains them — confirm.
        """
        self.nfe += 1
        f = h[:, :, 0].unsqueeze(-1)
        y = h[:, :, 1].unsqueeze(-1)
        # Latent force value at the current position along the time grid.
        g = self.f[:, :, self.t_index].unsqueeze(-1)
        df = g - self.protein_decay_rate * f
        dy = self.basal_rate + self.sensitivity * f - self.decay_rate * y
        # Step to the next latent-force time index whenever the solver
        # advances in time.
        if t > self.last_t:
            self.t_index += 1
        self.last_t = t
        return torch.cat([df, dy], -1)

    def decode(self, h_out):
        """Map solver output to observations by keeping the ``y`` channel (index 1)."""
        return h_out[:, :, 1]
[4]:
# Variational configuration: 80 Monte-Carlo samples, the dataset's
# preprocessing variance, and no learned initial conditions.
config = VariationalConfiguration(
    initial_conditions=False,
    num_samples=80,
    preprocessing_variance=dataset.variance,
)

# Inducing inputs: num_inducing evenly spaced times on [0, 12], one copy per
# latent force, arranged as (num_tfs, num_inducing, 1).
num_inducing = 12
inducing_points = (
    torch.linspace(0, 12, num_inducing)
    .repeat(num_tfs, 1)
    .view(num_tfs, num_inducing, 1)
)

# Prediction grid and the step size handed to the solver/trainer.
t_predict = torch.linspace(0, 14, 80, dtype=torch.float32)
step_size = 5e-1
# NOTE(review): presumably the number of observed time points — confirm.
num_training = dataset.m_observed.shape[-1]

# Build the GP (optionally with natural variational parameters), the ODE
# model on top of it, and the plotting helper.
use_natural = True
gp_model = generate_multioutput_rbf_gp(num_tfs, inducing_points, gp_kwargs={'natural': use_natural})
lfm = TranscriptionLFM(num_genes, gp_model, config, num_training_points=num_training)
plotter = Plotter1d(lfm, dataset.gene_names, style='seaborn')
[5]:
class P53ConstrainedTrainer(VariationalTrainer):
    """Variational trainer that pins the kinetics of output index 3 after each epoch.

    After every epoch the sensitivity of output 3 is reset to 1.0 and its decay
    rate to 0.8 (written in raw form via ``lfm.positivity.inverse_transform``),
    and then the base class's ``after_epoch`` hook runs.
    """

    def after_epoch(self):
        model = self.lfm
        to_raw = model.positivity.inverse_transform
        with torch.no_grad():
            pinned_sensitivity = to_raw(torch.tensor(1.))
            pinned_decay = to_raw(torch.tensor(0.8))
            model.raw_sensitivity[3] = pinned_sensitivity
            model.raw_decay[3] = pinned_decay
        super().after_epoch()
# Raw (unconstrained) parameters whose values the trainer records.
track_parameters = [
    'raw_basal',
    'raw_decay',
    'raw_sensitivity',
    'gp_model.covar_module.raw_lengthscale',
]

if not use_natural:
    # A single Adam optimiser over every model parameter.
    optimizers = [Adam(lfm.parameters(), lr=0.03)]
else:
    # Natural-gradient descent for the variational distribution and Adam for
    # all non-variational parameters.
    variational_optimizer = NGD(lfm.variational_parameters(), num_data=num_training, lr=0.1)
    parameter_optimizer = Adam(lfm.nonvariational_parameters(), lr=0.02)
    optimizers = [variational_optimizer, parameter_optimizer]

trainer = P53ConstrainedTrainer(lfm, optimizers, dataset, track_parameters=track_parameters)
Outputs prior to training:
[6]:
# Kinetic parameters before training: take the last recorded raw value of each
# tracked parameter and map it to a positive rate for the bar chart.
titles = ['Basal rates', 'Sensitivities', 'Decay rates']
kinetics = np.array([
    lfm.positivity.transform(trainer.parameter_trace[name][-1].squeeze()).numpy()
    for name in ('raw_basal', 'raw_sensitivity', 'raw_decay')
])
plotter.plot_double_bar(kinetics, titles=titles, ground_truths=dataset.params_ground_truth())

# Predictive distributions of the outputs and of the latent force from the
# untrained model.
q_m = lfm.predict_m(t_predict, step_size=1e-1)
q_f = lfm.predict_f(t_predict)
plotter.plot_gp(
    q_m, t_predict, replicate=0,
    t_scatter=dataset.t_observed,
    y_scatter=dataset.m_observed,
    num_samples=0,
)
plotter.plot_gp(q_f, t_predict, ylim=(-1, 3))
plt.title('Latent')
[6]:
Text(0.5, 1.0, 'Latent')
[7]:
# Put the model in training mode and run 400 epochs, reporting every 10.
lfm.train()
step_size = 5e-1  # step size passed to the trainer (presumably the ODE solver step)
# Bind the return value so the notebook does not echo it as the cell output.
# In the recorded run it was a list of (timestamp, loss) pairs, one per epoch.
loss_history = trainer.train(400, report_interval=10, step_size=step_size)
Epoch 001/400 - Loss: 7.81 (7.81 0.00) kernel: [0.6831972]
Epoch 011/400 - Loss: 6.77 (6.54 0.23) kernel: [0.6247893]
Epoch 021/400 - Loss: 6.60 (6.27 0.33) kernel: [0.6334702]
Epoch 031/400 - Loss: 6.20 (5.79 0.41) kernel: [0.6627324]
Epoch 041/400 - Loss: 5.98 (5.49 0.49) kernel: [0.7031785]
Epoch 051/400 - Loss: 5.57 (4.96 0.61) kernel: [0.7404556]
Epoch 061/400 - Loss: 5.30 (4.60 0.70) kernel: [0.75663084]
Epoch 071/400 - Loss: 5.00 (4.20 0.80) kernel: [0.7715795]
Epoch 081/400 - Loss: 4.65 (3.75 0.91) kernel: [0.79156274]
Epoch 091/400 - Loss: 4.35 (3.34 1.02) kernel: [0.8082584]
Epoch 101/400 - Loss: 3.99 (2.87 1.13) kernel: [0.8171648]
Epoch 111/400 - Loss: 3.69 (2.47 1.22) kernel: [0.8188295]
Epoch 121/400 - Loss: 3.33 (2.01 1.32) kernel: [0.83573043]
Epoch 131/400 - Loss: 3.02 (1.59 1.43) kernel: [0.8459381]
Epoch 141/400 - Loss: 2.75 (1.23 1.52) kernel: [0.85402286]
Epoch 151/400 - Loss: 2.45 (0.84 1.61) kernel: [0.8643153]
Epoch 161/400 - Loss: 2.28 (0.58 1.71) kernel: [0.85946304]
Epoch 171/400 - Loss: 2.09 (0.29 1.80) kernel: [0.87005776]
Epoch 181/400 - Loss: 1.85 (-0.03 1.88) kernel: [0.8876577]
Epoch 191/400 - Loss: 1.76 (-0.19 1.94) kernel: [0.8924065]
Epoch 201/400 - Loss: 1.59 (-0.42 2.01) kernel: [0.9023677]
Epoch 211/400 - Loss: 1.58 (-0.51 2.08) kernel: [0.90342885]
Epoch 221/400 - Loss: 1.55 (-0.56 2.11) kernel: [0.9137664]
Epoch 231/400 - Loss: 1.47 (-0.69 2.16) kernel: [0.9276629]
Epoch 241/400 - Loss: 1.44 (-0.76 2.19) kernel: [0.9313885]
Epoch 251/400 - Loss: 1.49 (-0.73 2.22) kernel: [0.9431207]
Epoch 261/400 - Loss: 1.35 (-0.88 2.23) kernel: [0.93753386]
Epoch 271/400 - Loss: 1.40 (-0.86 2.26) kernel: [0.94270945]
Epoch 281/400 - Loss: 1.46 (-0.81 2.27) kernel: [0.9521061]
Epoch 291/400 - Loss: 1.33 (-0.96 2.29) kernel: [0.94980067]
Epoch 301/400 - Loss: 1.39 (-0.91 2.30) kernel: [0.9627767]
Epoch 311/400 - Loss: 1.33 (-1.00 2.32) kernel: [0.9621694]
Epoch 321/400 - Loss: 1.37 (-0.96 2.33) kernel: [0.96069795]
Epoch 331/400 - Loss: 1.33 (-1.01 2.33) kernel: [0.962077]
Epoch 341/400 - Loss: 1.28 (-1.06 2.34) kernel: [0.97064215]
Epoch 351/400 - Loss: 1.34 (-1.01 2.35) kernel: [0.97254443]
Epoch 361/400 - Loss: 1.38 (-1.00 2.38) kernel: [0.9724194]
Epoch 371/400 - Loss: 1.28 (-1.11 2.39) kernel: [0.9802332]
Epoch 381/400 - Loss: 1.27 (-1.14 2.40) kernel: [0.98273104]
Epoch 391/400 - Loss: 1.25 (-1.15 2.40) kernel: [0.9832086]
[7]:
[(1622792396.2409189, 7.810469442678664),
(1622792396.2859287, 7.722534062386705),
(1622792396.3259373, 7.398306377674728),
(1622792396.370948, 7.257347347480059),
(1622792396.415958, 7.119780946068084),
(1622792396.457942, 7.064290477408223),
(1622792396.500944, 7.026900966042142),
(1622792396.551956, 6.996700383673343),
(1622792396.5999668, 6.899742292973044),
(1622792396.6399758, 6.911148677267815),
(1622792396.6879861, 6.769309378743254),
(1622792396.7349975, 6.794865152364443),
(1622792396.8160148, 6.76424181000604),
(1622792396.8870313, 6.696304387691425),
(1622792396.9280405, 6.709920429917044),
(1622792396.9700496, 6.612977246812174),
(1622792397.0150592, 6.585465976874482),
(1622792397.0550685, 6.573126315856527),
(1622792397.0970783, 6.5649836912677975),
(1622792397.1390877, 6.5576576595796086),
(1622792397.1870983, 6.596777220322733),
(1622792397.2271068, 6.50644800822066),
(1622792397.2751179, 6.491191412634177),
(1622792397.315127, 6.48330820329089),
(1622792397.3591363, 6.37825715525253),
(1622792397.3971455, 6.373820345980527),
(1622792397.4391553, 6.333896034426532),
(1622792397.4851646, 6.349614311224442),
(1622792397.5301778, 6.328920212166866),
(1622792397.5771883, 6.251250411884913),
(1622792397.627199, 6.198825037434884),
(1622792397.6742098, 6.225188279257334),
(1622792397.715219, 6.176983763000798),
(1622792397.7582288, 6.12173835718629),
(1622792397.8052394, 6.10957779932243),
(1622792397.8482487, 6.156112609852071),
(1622792397.8922584, 6.006385037739532),
(1622792397.9372692, 6.001296989927484),
(1622792397.9792788, 5.9460588092688536),
(1622792398.023288, 5.928102923660155),
(1622792398.0702984, 5.976421754788883),
(1622792398.1143086, 5.978031444192469),
(1622792398.1583183, 5.840013156998244),
(1622792398.2023282, 5.880331772474916),
(1622792398.2463388, 5.767639950364697),
(1622792398.2953484, 5.871817890083739),
(1622792398.3433597, 5.749826779673855),
(1622792398.3813684, 5.702740358331703),
(1622792398.4233778, 5.710641499637438),
(1622792398.4683878, 5.702217585589966),
(1622792398.5124223, 5.5732923375305825),
(1622792398.5579855, 5.588537567460422),
(1622792398.5995228, 5.52160358283785),
(1622792398.6465476, 5.535704507428586),
(1622792398.69459, 5.515822612457824),
(1622792398.7346234, 5.481936631722472),
(1622792398.7756708, 5.500847144142474),
(1622792398.8206992, 5.378803157303145),
(1622792398.8667357, 5.3372900770679195),
(1622792398.9122746, 5.345692844809461),
(1622792398.9530225, 5.302138276526713),
(1622792398.9940312, 5.275547192204573),
(1622792399.0400417, 5.215067407977587),
(1622792399.079051, 5.212116137391296),
(1622792399.1230607, 5.189692553415449),
(1622792399.1680782, 5.119261263791048),
(1622792399.214081, 5.1467588113315745),
(1622792399.2590911, 5.022098330148928),
(1622792399.301102, 5.018753236005936),
(1622792399.3461108, 5.024220009161199),
(1622792399.392121, 4.997020054218018),
(1622792399.437131, 4.945596068261978),
(1622792399.481141, 4.932723018306411),
(1622792399.5271509, 4.942201310742497),
(1622792399.5701609, 4.816319284965871),
(1622792399.6181715, 4.739373545629615),
(1622792399.6651824, 4.789211613166458),
(1622792399.7121925, 4.806235728931799),
(1622792399.7592037, 4.7011913675387955),
(1622792399.8032131, 4.676069934919278),
(1622792399.8482234, 4.653602875861772),
(1622792399.8892329, 4.626095580451826),
(1622792399.9792528, 4.560761622529058),
(1622792400.029263, 4.546357760152443),
(1622792400.0792754, 4.465599188393799),
(1622792400.1262856, 4.504438059005176),
(1622792400.1732962, 4.55634971841312),
(1622792400.219307, 4.370042605849587),
(1622792400.262316, 4.463026303123067),
(1622792400.3093266, 4.339514638656357),
(1622792400.3533363, 4.353435829107653),
(1622792400.3983464, 4.2935560165902755),
(1622792400.4373553, 4.2551166033076875),
(1622792400.4783647, 4.275973176145521),
(1622792400.5223742, 4.194263024589574),
(1622792400.5713856, 4.173688599972174),
(1622792400.614395, 4.1939800177132724),
(1622792400.6624055, 4.145986604643372),
(1622792400.7074156, 4.047420720845542),
(1622792400.7554266, 4.056022204910594),
(1622792400.8024366, 3.9935388624048445),
(1622792400.8524482, 4.047170548958951),
(1622792400.900459, 3.9034971168234565),
(1622792400.9374673, 3.861862816755632),
(1622792400.983478, 3.9196889307795564),
(1622792401.023486, 3.8225266854760775),
(1622792401.0644956, 3.765592954180119),
(1622792401.1085052, 3.8210189652225526),
(1622792401.1525152, 3.7507762979036605),
(1622792401.1955252, 3.6952725977906007),
(1622792401.2385347, 3.691293495764154),
(1622792401.2785437, 3.646284458368437),
(1622792401.326554, 3.5880057338630835),
(1622792401.3685637, 3.567872284661067),
(1622792401.4055717, 3.5729036711274254),
(1622792401.4425802, 3.5102972511826906),
(1622792401.4865901, 3.520263886594748),
(1622792401.5316007, 3.4646226023819082),
(1622792401.5796108, 3.479272735935482),
(1622792401.61962, 3.3687251798345748),
(1622792401.6636298, 3.329549047089673),
(1622792401.7036386, 3.330858112637696),
(1622792401.7456484, 3.367511284535597),
(1622792401.7896585, 3.255614373581845),
(1622792401.8336678, 3.257312797582638),
(1622792401.8766775, 3.250458302561756),
(1622792401.9206874, 3.2433332194695774),
(1622792401.9656973, 3.1981645741445206),
(1622792402.0107074, 3.1274381957776445),
(1622792402.0537171, 3.1141382300886233),
(1622792402.0997276, 3.0177243974656243),
(1622792402.1487386, 3.0456876683596787),
(1622792402.195749, 3.0230344813221164),
(1622792402.2477605, 2.966101777263199),
(1622792402.2947712, 2.928255278886037),
(1622792402.3437817, 2.9834528919312064),
(1622792402.3927932, 2.9134416506546197),
(1622792402.4358027, 2.8052101972028307),
(1622792402.4788125, 2.836371481292547),
(1622792402.5228221, 2.708779128605089),
(1622792402.5718331, 2.7453748383591234),
(1622792402.6318467, 2.742326652517084),
(1622792402.7148657, 2.707078422912713),
(1622792402.7688775, 2.717595317190858),
(1622792402.821889, 2.682461292123943),
(1622792402.8709004, 2.63864546324101),
(1622792402.9159105, 2.617359500567666),
(1622792402.9619226, 2.635344825598976),
(1622792403.016935, 2.508751872121775),
(1622792403.0669532, 2.496646873817456),
(1622792403.1189773, 2.4502709383971357),
(1622792403.169989, 2.429224321937066),
(1622792403.2210073, 2.4877366931448863),
(1622792403.271022, 2.380741495561961),
(1622792403.3150356, 2.450460363053271),
(1622792403.3607378, 2.3749833579371407),
(1622792403.4117491, 2.3482944943201316),
(1622792403.456759, 2.314039226467266),
(1622792403.499769, 2.3279575231591974),
(1622792403.5437777, 2.2791369203075282),
(1622792403.5915976, 2.2841304289609905),
(1622792403.6386085, 2.3303726816214194),
(1622792403.687619, 2.1643503435576283),
(1622792403.7346299, 2.2377756499350774),
(1622792403.7816436, 2.1818628440672003),
(1622792403.8286538, 2.23553896955875),
(1622792403.8776693, 2.185416980530934),
(1622792403.9185817, 2.1468807426554415),
(1622792403.9625785, 2.044671921650984),
(1622792404.005581, 2.019868885708958),
(1622792404.0485868, 2.085030926830768),
(1622792404.091585, 2.013209409269687),
(1622792404.1315796, 2.063094329773466),
(1622792404.174422, 2.0335695820781203),
(1622792404.2205825, 1.955220352902154),
(1622792404.266575, 1.9694197268612192),
(1622792404.3135931, 1.9436064577409142),
(1622792404.3605978, 1.9293849445986389),
(1622792404.4050012, 1.9405601930058567),
(1622792404.4530122, 1.907231061180713),
(1622792404.4960215, 1.8470633546270827),
(1622792404.538031, 1.8371063426447225),
(1622792404.5780401, 1.887100677868052),
(1622792404.6200497, 1.8453825951236082),
(1622792404.66706, 1.843914168550808),
(1622792404.7130706, 1.749024969885976),
(1622792404.7540796, 1.7926374004627001),
(1622792404.7950885, 1.7707712084983662),
(1622792404.8390982, 1.7657462925086411),
(1622792404.8821082, 1.7203397154454443),
(1622792404.9231176, 1.7550330317297245),
(1622792404.9651275, 1.8014258333639397),
(1622792405.0041354, 1.753788294338959),
(1622792405.0551467, 1.6098042217222936),
(1622792405.1091588, 1.7565213493050829),
(1622792405.1491678, 1.6164422056212593),
(1622792405.1951783, 1.6047030462829315),
(1622792405.2411888, 1.6871075706736978),
(1622792405.287199, 1.6957417154421657),
(1622792405.3332093, 1.665213689678046),
(1622792405.3772197, 1.5902755935278723),
(1622792405.4192293, 1.6419948871321526),
(1622792405.4652395, 1.7059216498609913),
(1622792405.5192506, 1.619671274741496),
(1622792405.5742633, 1.585051029271406),
(1622792405.6212738, 1.7542632045960191),
(1622792405.6632833, 1.5994076320237858),
(1622792405.7052925, 1.6157299996312835),
(1622792405.7443013, 1.6571872656239133),
(1622792405.7873116, 1.611670225920378),
(1622792405.8293202, 1.5750391343170462),
(1622792405.87233, 1.5598591411640383),
(1622792405.9203408, 1.6009370342120541),
(1622792405.9613507, 1.5056866514789187),
(1622792406.00536, 1.521219812843933),
(1622792406.0503705, 1.509855911497243),
(1622792406.0893788, 1.5268011266566557),
(1622792406.133388, 1.517462199836627),
(1622792406.1713972, 1.5330228656797824),
(1622792406.2224085, 1.4994344850579355),
(1622792406.2704191, 1.5546180431055552),
(1622792406.322431, 1.5572758204297363),
(1622792406.371442, 1.4654591737967788),
(1622792406.4194531, 1.551648555623078),
(1622792406.4594617, 1.5161239272739737),
(1622792406.5034716, 1.507695205214195),
(1622792406.5484815, 1.4916692396187652),
(1622792406.5894904, 1.520439818251527),
(1622792406.6355011, 1.4878439388617926),
(1622792406.6845117, 1.4837815903769973),
(1622792406.7345233, 1.470777222978365),
(1622792406.7815337, 1.5035259634860454),
(1622792406.830545, 1.4800877254525218),
(1622792406.897559, 1.393652841727036),
(1622792406.9505713, 1.4636246573124003),
(1622792406.9985824, 1.3674659023123137),
(1622792407.0435927, 1.4044623908468614),
(1622792407.0926034, 1.5385931728072313),
(1622792407.139614, 1.4761902035019254),
(1622792407.1856244, 1.511570938508206),
(1622792407.2276337, 1.4352019690325961),
(1622792407.2716434, 1.4470581071829907),
(1622792407.3086517, 1.4351508699578575),
(1622792407.3526618, 1.4468198300556967),
(1622792407.3976715, 1.4037910274418395),
(1622792407.4426818, 1.3504410785968366),
(1622792407.4806905, 1.4728579242523638),
(1622792407.5216997, 1.5350199754355225),
(1622792407.563709, 1.4042340440803236),
(1622792407.6067195, 1.4116634431256847),
(1622792407.6467268, 1.491173670311345),
(1622792407.6907365, 1.41533450884296),
(1622792407.731747, 1.4296386308916773),
(1622792407.7757566, 1.4184775780099619),
(1622792407.8217673, 1.4638782508973136),
(1622792407.8657765, 1.4544299782690517),
(1622792407.909787, 1.3867874421942932),
(1622792407.9547973, 1.3762121090203392),
(1622792408.0028071, 1.363342060977767),
(1622792408.0478175, 1.39820636924627),
(1622792408.0988288, 1.347111077618097),
(1622792408.1378376, 1.4210920649348346),
(1622792408.1768463, 1.3601145561631076),
(1622792408.2148588, 1.4371988990908422),
(1622792408.2578642, 1.3811829798564563),
(1622792408.3058753, 1.4128144428462472),
(1622792408.3538861, 1.3951030527317536),
(1622792408.3998966, 1.4353300465540395),
(1622792408.4439054, 1.3689912858077977),
(1622792408.4859154, 1.4996317714652148),
(1622792408.5309258, 1.4023544620317443),
(1622792408.5779357, 1.4223985387026508),
(1622792408.6169448, 1.3403842473063998),
(1622792408.661955, 1.327651446946489),
(1622792408.7089655, 1.3606335205151754),
(1622792408.7549758, 1.3551011501814871),
(1622792408.8019862, 1.342416579968305),
(1622792408.8529975, 1.4277249708893551),
(1622792408.8960078, 1.3652063597163808),
(1622792408.9360166, 1.3896875096011851),
(1622792408.9810264, 1.461973489293837),
(1622792409.0270367, 1.3552811814594188),
(1622792409.0710466, 1.4507618329102154),
(1622792409.1150563, 1.3447757889421856),
(1622792409.1540656, 1.3221074514948645),
(1622792409.1990752, 1.4457273871214218),
(1622792409.2490866, 1.4054765722124754),
(1622792409.2980974, 1.3677928396768895),
(1622792409.3401072, 1.3314222059239427),
(1622792409.3841166, 1.3522082111552434),
(1622792409.436128, 1.3280482582780684),
(1622792409.4821389, 1.336597460357555),
(1622792409.5231478, 1.3806130298130528),
(1622792409.5671575, 1.4739302777535706),
(1622792409.6111674, 1.3101451570025153),
(1622792409.653177, 1.360406591688501),
(1622792409.6981874, 1.2934755286399977),
(1622792409.7441976, 1.3336078572448562),
(1622792409.7922084, 1.3637843069202822),
(1622792409.8412194, 1.2720613751186216),
(1622792409.8912306, 1.3883163328664514),
(1622792409.9402413, 1.3686020958544578),
(1622792409.9852517, 1.3807423760661464),
(1622792410.0272608, 1.3336120520080756),
(1622792410.0772717, 1.3870431241198864),
(1622792410.125283, 1.3640096130491965),
(1622792410.1722937, 1.3744702554069148),
(1622792410.215303, 1.4072365009721095),
(1622792410.2563126, 1.3643074459689943),
(1622792410.2953207, 1.370000380062255),
(1622792410.3453324, 1.3252084628646914),
(1622792410.3893428, 1.2961688156277564),
(1622792410.4313512, 1.2925723078131433),
(1622792410.4793622, 1.3510152670253623),
(1622792410.5253727, 1.3270231610606205),
(1622792410.5683823, 1.3892599969249178),
(1622792410.6163929, 1.354077446081817),
(1622792410.6634033, 1.3394753606401535),
(1622792410.7054129, 1.3349436337769935),
(1622792410.749422, 1.3972823240473566),
(1622792410.7964332, 1.3700049176399294),
(1622792410.8384428, 1.3336954938801484),
(1622792410.8914545, 1.305353708865597),
(1622792410.9574695, 1.4509153376444819),
(1622792411.0164826, 1.319343635849295),
(1622792411.0614924, 1.3108632430699239),
(1622792411.103502, 1.3453691918974533),
(1622792411.1505127, 1.305185281501575),
(1622792411.1975224, 1.30755693249121),
(1622792411.2375329, 1.3034640263754056),
(1622792411.2815418, 1.3263642527678001),
(1622792411.3425553, 1.3450011877996113),
(1622792411.4125714, 1.31000033263009),
(1622792411.4605823, 1.3145376536955296),
(1622792411.5165944, 1.3566229789978506),
(1622792411.5626051, 1.40645459280146),
(1622792411.6046143, 1.418742223107511),
(1622792411.647624, 1.2872054664256833),
(1622792411.6956344, 1.2974690846826935),
(1622792411.7366436, 1.3081964216070419),
(1622792411.7806535, 1.2847995802958156),
(1622792411.8266642, 1.2667147120134061),
(1622792411.871674, 1.3438459695948484),
(1622792411.9186845, 1.3365877197715137),
(1622792411.9666953, 1.3448870119424852),
(1622792412.0117054, 1.2425419761174792),
(1622792412.0597165, 1.285616144830837),
(1622792412.0997252, 1.3770289577968164),
(1622792412.1467357, 1.296103328431034),
(1622792412.1887453, 1.2449703459434167),
(1622792412.2357554, 1.343651544839467),
(1622792412.280766, 1.2626080495697916),
(1622792412.3267763, 1.3062090990541058),
(1622792412.3707857, 1.3435344863760212),
(1622792412.410795, 1.3246337089632452),
(1622792412.4518113, 1.2778671711491048),
(1622792412.4958138, 1.3130521128994141),
(1622792412.5418248, 1.285391594629331),
(1622792412.5828338, 1.3603052196599426),
(1622792412.6208422, 1.363727411684868),
(1622792412.659851, 1.3808199612277092),
(1622792412.70386, 1.2768024650854686),
(1622792412.750871, 1.2898625050972439),
(1622792412.7888796, 1.3651422649398868),
(1622792412.8328893, 1.2747978532964568),
(1622792412.8799005, 1.2580876827415988),
(1622792412.9279106, 1.286051130312631),
(1622792412.9739208, 1.3084449890811232),
(1622792413.0179307, 1.3172923682265802),
(1622792413.0559397, 1.3475561107611742),
(1622792413.0999496, 1.276188655339684),
(1622792413.14596, 1.3194223613981342),
(1622792413.1889694, 1.3036770623824134),
(1622792413.2269778, 1.2983364165111402),
(1622792413.2699876, 1.2798270614999432),
(1622792413.3169982, 1.3396883394421824),
(1622792413.356006, 1.3063873500520542),
(1622792413.394015, 1.3170352739714744),
(1622792413.4360247, 1.236997136231875),
(1622792413.474033, 1.2840675837945636),
(1622792413.5150423, 1.2667029673045371),
(1622792413.5590522, 1.2486533032819263),
(1622792413.6040623, 1.2432947936936167),
(1622792413.652073, 1.2320725637724277),
(1622792413.6950824, 1.23025617910125),
(1622792413.7350917, 1.2374983655509129),
(1622792413.7791018, 1.2509548476072676),
(1622792413.8251119, 1.2172973421624065),
(1622792413.8721218, 1.2677252774404186),
(1622792413.9141316, 1.2363287613736793),
(1622792413.9621427, 1.2468976867223913),
(1622792414.0091534, 1.3249285309953858),
(1622792414.050162, 1.2535808323214),
(1622792414.0921729, 1.189377058635781),
(1622792414.1361816, 1.3026121572698008),
(1622792414.181192, 1.2854512193402192),
(1622792414.222201, 1.2561926838542377),
(1622792414.2592094, 1.2920772417033697),
(1622792414.2962172, 1.290760137227305),
(1622792414.3352258, 1.2490235288217144)]
[10]:
# Predictions from the trained model on a slightly extended time grid.
tight_kwargs = dict(bbox_inches='tight', pad_inches=0)
t_predict = torch.linspace(0, 15, 80, dtype=torch.float32)
lfm.eval()
q_m = lfm.predict_m(t_predict, step_size=step_size)
q_f = lfm.predict_f(t_predict)
plotter.plot_losses(trainer, last_x=200)

# 2 x 3 grid: genes 0-2 fill the top row, genes 3-4 the first two bottom
# panels, and the latent force goes in the bottom-right panel.
fig, axes = plt.subplots(nrows=2, ncols=3, figsize=(10, 4))
for i in range(num_genes):
    row, col = divmod(i, 3)
    ax = axes[row, col]
    # only_plot_index restricts the panel to output i; without it the loop
    # variable only affected the panel position and title.
    plotter.plot_gp(q_m, t_predict, replicate=0, ax=ax,
                    color=Colours.line_color, shade_color=Colours.shade_color,
                    t_scatter=dataset.t_observed, y_scatter=dataset.m_observed,
                    num_samples=0, only_plot_index=i)
    ax.set_title(dataset.gene_names[i])

plotter.plot_gp(q_f, t_predict, ax=axes[1, 2],
                ylim=(-2, 3.2),
                num_samples=3,
                color=Colours.line2_color,
                shade_color=Colours.shade2_color)
axes[1, 2].set_title('Latent force (p53)')
plt.savefig('./combined.pdf', **tight_kwargs)
# Trained kinetic parameters against the ground truth: map the final recorded
# raw values to positive rates and draw them as paired bars.
titles = ['Basal rates', 'Sensitivities', 'Decay rates']
kinetics = np.array([
    lfm.positivity.transform(trainer.parameter_trace[name][-1].squeeze()).numpy()
    for name in ('raw_basal', 'raw_sensitivity', 'raw_decay')
])
# NOTE(review): yticks presumably holds one tick array per panel, in the same
# order as `titles` — confirm against Plotter1d.plot_double_bar.
plotter.plot_double_bar(
    kinetics,
    titles=titles,
    ground_truths=dataset.params_ground_truth(),
    figsize=(9, 3),
    yticks=[
        np.linspace(0, 0.12, 5),
        np.linspace(0, 1.2, 4),
        np.arange(0, 1.1, 0.2),
    ],
)
plt.tight_layout()
plt.savefig('./kinetics.pdf', **tight_kwargs)
[ ]: