[1]:
import torch
from torch.nn import Parameter
from gpytorch.distributions import MultitaskMultivariateNormal, MultivariateNormal

from alfi.models import OrdinaryLFM, generate_multioutput_rbf_gp
from alfi.trainers import VariationalTrainer
from alfi.utilities.torch import softplus
from alfi.configuration import VariationalConfiguration
from alfi.datasets import ToyTranscriptomicGenerator
from alfi.plot import Plotter1d

from matplotlib import pyplot as plt

import numpy as np

[2]:
    # def predict_f(self, t_predict):
    #     # Sample from the latent distribution
    #     q_f = super().predict_f(t_predict)
    #     f = q_f.sample(torch.Size([500])).permute(0, 2, 1)  # (S, I, T)
    #     print(f.shape)
    #     # This is a hack to wrap the latent function with the nonlinearity. Note we use the same variance.
    #     f = torch.mean(self.G(f), dim=0)[0].unsqueeze(0)
    #     print(f.shape, q_f.mean.shape, q_f.scale_tril.shape)
    #     batch_mvn = MultivariateNormal(f, q_f.covariance_matrix.unsqueeze(0))
    #     print(batch_mvn)
    #
    #     return MultitaskMultivariateNormal.from_batch_mvn(batch_mvn, task_dim=0)
[19]:
# Reproducibility: the dataset generator and the model initialisation later on
# (torch.rand / torch.randn in TranscriptionLFM) are stochastic, so seed every
# RNG the notebook may touch before anything random happens.
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Smallest positive normal float32. Any constant below this (e.g. a 1e-100
# "epsilon") is not representable and rounds to 0.0 when mixed with float32 data.
print(torch.finfo(torch.float32).tiny)

config = VariationalConfiguration(
    num_samples=70,
    initial_conditions=False,  # TODO: revisit this setting
)

num_latents = 3
num_outputs = 30
dataset = ToyTranscriptomicGenerator(num_outputs, num_latents, softplus=True)
dataset.generate_single()
dataset.t_observed
1.1754943508222875e-38
tensor([0.0000, 1.0000, 2.0000, 3.0000, 4.0000, 5.0000, 6.0000, 7.0000, 8.0000,
        9.0000])

Now let’s train a model on this dataset and check whether it recovers the same results.

[21]:
class TranscriptionLFM(OrdinaryLFM):
    """ODE latent force model of transcription.

    Each output's state ``h`` evolves as

        dh/dt = basal_rate + f(t) - decay_rate * h

    where ``f`` is read from ``self.f`` (see :meth:`odefunc`) and :meth:`mix`
    maps latent functions to per-output activations. All kinetic and mixing
    parameters are learnable and randomly initialised.

    NOTE(review): ``decay_rate`` / ``basal_rate`` are unconstrained Parameters,
    so nothing here keeps them positive during training.
    """

    def __init__(self, num_outputs, gp_model, config: VariationalConfiguration):
        super().__init__(num_outputs, gp_model, config)
        num_latents = gp_model.variational_strategy.num_tasks
        # Per-output kinetic rates, initialised uniformly in [0.1, 1.1).
        self.decay_rate = Parameter(0.1 + torch.rand(torch.Size([self.num_outputs, 1]), dtype=torch.float32))
        self.basal_rate = Parameter(0.1 + torch.rand(torch.Size([self.num_outputs, 1]), dtype=torch.float32))
        # Mixing weights (num_outputs x num_latents) and per-output bias used by mix().
        self.weight = Parameter(0.5 * torch.randn(torch.Size([self.num_outputs, num_latents]), dtype=torch.float32))
        self.weight_bias = Parameter(torch.randn(torch.Size([self.num_outputs, 1]), dtype=torch.float32))

    def initial_state(self):
        """Zero-input steady state of the ODE: basal_rate / decay_rate."""
        return self.basal_rate / self.decay_rate

    def odefunc(self, t, h):
        """Right-hand side of the transcription ODE.

        Args:
            t: current solver time.
            h: current state, of shape (num_samples, num_outputs, 1).

        Returns:
            dh/dt = basal_rate + f - decay_rate * h, broadcast against ``h``.
        """
        self.nfe += 1  # presumably the function-evaluation counter kept by OrdinaryLFM
        decay = self.decay_rate * h

        # self.f is indexed along its last (time) axis at the current time index;
        # presumably it holds the mixed latent input set by OrdinaryLFM — confirm.
        f = self.f[:, :, self.t_index].unsqueeze(2)
        dh = self.basal_rate + f - decay

        # Advance the time index only when the solver steps forward in time.
        if t > self.last_t:
            self.t_index += 1
        self.last_t = t
        return dh

    def mix(self, f):
        """Map latent functions to per-output activations in [0, 1].

        Computes ``sigmoid(weight @ log(softplus(f)) + weight_bias)``: a
        log-linear interaction of the positive latents followed by a sigmoid.
        Assumes ``f`` has num_latents on its second-to-last axis — TODO confirm.
        """
        f = softplus(f)
        # softplus can underflow to exactly 0 for very negative inputs. A floor
        # representable in f's dtype keeps log() finite. A literal like 1e-100
        # rounds to 0.0 in float32 and would have no effect.
        eps = torch.finfo(f.dtype).tiny
        interactions = torch.matmul(self.weight, torch.log(f + eps)) + self.weight_bias
        return torch.sigmoid(interactions)  # TF activation function (sigmoid)
[30]:
# Inducing inputs: num_inducing evenly spaced points on [0, 12], one copy per
# latent GP, giving shape (num_latents, num_inducing, 1), i.e. (I x m x 1).
num_inducing = 12
inducing_points = (
    torch.linspace(0, 12, num_inducing)
    .repeat(num_latents, 1)
    .unsqueeze(-1)
)

# Dense prediction grid; extends past both the inducing range and the observed times.
t_predict = torch.linspace(0, 13, 80, dtype=torch.float32)

gp_model = generate_multioutput_rbf_gp(
    num_latents,
    inducing_points,
    initial_lengthscale=1,
    gp_kwargs={'natural': False},
)
lfm = TranscriptionLFM(num_outputs=num_outputs, gp_model=gp_model, config=config)
plotter = Plotter1d(lfm, np.arange(num_outputs))

optimizer = torch.optim.Adam(lfm.parameters(), lr=1e-2)
trainer = VariationalTrainer(lfm, [optimizer], dataset)
[54]:
# Train the full schedule in a single call. The recorded log reports epochs
# 1001-1100 of 1100, which suggests the displayed results came from re-running a
# 100-epoch cell repeatedly. A fresh "Restart & Run All" would not reproduce that.
num_epochs = 1100
times = trainer.train(num_epochs, report_interval=10, step_size=5e-1)
C:\Users\Jacob\miniconda3\envs\wishart\lib\site-packages\gpytorch\distributions\multivariate_normal.py:230: NumericalWarning: Negative variance values detected. This is likely due to numerical instabilities. Rounding negative variances up to 1e-06.
  NumericalWarning,
Epoch 1001/1100 - Loss: -82.46 (-84.33 1.87) kernel: [1.5689356 1.7532138 1.2205318]
Epoch 1011/1100 - Loss: -83.12 (-85.00 1.88) kernel: [1.5706058 1.7531269 1.2204925]
Epoch 1021/1100 - Loss: -83.18 (-85.06 1.88) kernel: [1.5693965 1.7527    1.2198731]
Epoch 1031/1100 - Loss: -83.47 (-85.36 1.89) kernel: [1.571539  1.7524838 1.2200401]
Epoch 1041/1100 - Loss: -83.99 (-85.90 1.91) kernel: [1.5717928 1.7539551 1.2210684]
Epoch 1051/1100 - Loss: -84.22 (-86.13 1.91) kernel: [1.5707824 1.7538483 1.2202107]
Epoch 1061/1100 - Loss: -84.64 (-86.55 1.91) kernel: [1.5726596 1.7529309 1.2205873]
Epoch 1071/1100 - Loss: -85.09 (-87.01 1.92) kernel: [1.5756844 1.7545503 1.2216972]
Epoch 1081/1100 - Loss: -85.13 (-87.06 1.92) kernel: [1.5742133 1.7510046 1.2214695]
Epoch 1091/1100 - Loss: -85.06 (-86.99 1.93) kernel: [1.5766728 1.7536612 1.2214465]
[55]:
lfm.eval()

# Predictive distributions on the dense grid for the outputs (q_m) and the
# latent functions (q_f). predict_m takes an integration step size.
q_m = lfm.predict_m(t_predict, step_size=1e-1)
q_f = lfm.predict_f(t_predict)

plotter.plot_gp(q_m, t_predict, replicate=0,
                t_scatter=dataset.t_observed,
                y_scatter=dataset.m_observed, num_samples=0)
plotter.plot_gp(q_f, t_predict,
                t_scatter=dataset.t_observed,
                y_scatter=dataset.f_observed, num_samples=5)
plt.title('Latent')

# The dataset's high-resolution latents, one panel per latent, for visual
# comparison with the inferred latents above. The panel count comes from the
# data instead of a hard-coded 3. squeeze=False keeps `axes` 2-D even for one panel.
fig, axes = plt.subplots(ncols=dataset.f_observed_highres.shape[1], squeeze=False)
for i, ax in enumerate(axes[0]):
    ax.plot(dataset.t_observed_highres, dataset.f_observed_highres[0, i])
    ax.set_title(f'Latent {i}')
dataset.f_observed.shape
[55]:
torch.Size([1, 3, 10])
../../_images/notebooks_nonlinear_multitf_7_1.png
../../_images/notebooks_nonlinear_multitf_7_2.png
../../_images/notebooks_nonlinear_multitf_7_3.png
[39]:
# Learned RBF lengthscales of the latent GPs (one entry per latent; the repr
# shows a softplus-transformed value carrying a grad_fn).
# NOTE(review): this cell ran as [39], before the later training cell ([54]), so
# the value shown below is stale. It appears to differ from the "kernel" values in
# the training log. Re-run after training to refresh.
gp_model.covar_module.lengthscale
[39]:
tensor([[[1.3871]],

        [[1.6633]],

        [[0.7071]]], grad_fn=<SoftplusBackward>)