Navier-Stokes

[21]:
import numpy as np
import torch

from torch.nn import Parameter
from torch.optim import Adam
from gpytorch.optim import NGD
from gpytorch.constraints import Interval
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from os import path

from lafomo.utilities.torch import get_image
from lafomo.datasets import DrosophilaSpatialTranscriptomics
from lafomo.models import MultiOutputGP, PartialLFM
from lafomo.models.pdes import ReactionDiffusion
from lafomo.datasets import ToySpatialTranscriptomics, P53Data
from lafomo.configuration import VariationalConfiguration
from lafomo.plot import Plotter, plot_spatiotemporal_data
from lafomo.trainers import PDETrainer
from lafomo.utilities.torch import discretise, softplus
from lafomo.utilities.fenics import interval_mesh
[22]:
from fenics import *
from dolfin import *
from mshr import *
[31]:
# Simulation parameters. The reference run (final time 5.0 over 5000 steps) is
# shortened by a factor of REDUCTION for this demo; the step size dt is the same
# either way (5.0 / 5000 == 0.25 / 250).
REDUCTION = 20
T = 5.0 / REDUCTION            # final time
num_steps = 5000 // REDUCTION  # number of time steps
dt = T / num_steps             # time step size
mu = 0.001                     # dynamic viscosity
rho = 1                        # density
[32]:
# Generate the mesh: a 2.2 x 0.41 channel with a circular hole of radius 0.05
# centred at (0.2, 0.2), meshed by mshr at resolution 32.
channel = Rectangle(Point(0, 0), Point(2.2, 0.41))
# Named `obstacle` rather than `cylinder`: the boundary-condition cell binds
# `cylinder` to a boundary-marker string, so reusing the name for the geometry
# would make the global mean two different kinds of thing.
obstacle = Circle(Point(0.2, 0.2), 0.05)
domain = channel - obstacle
mesh = generate_mesh(domain, 32)
plot(mesh);  # trailing ';' suppresses the list-of-Line2D repr
[32]:
[<matplotlib.lines.Line2D at 0x7fa22a2336d0>,
 <matplotlib.lines.Line2D at 0x7fa22a233a10>]
../../_images/notebooks_pde_pde_stokes_4_1.png
[33]:
# Function spaces: continuous piecewise-quadratic vector field for the velocity,
# continuous piecewise-linear scalar field for the pressure.
VELOCITY_DEGREE = 2
PRESSURE_DEGREE = 1
V = VectorFunctionSpace(mesh, 'P', VELOCITY_DEGREE)
Q = FunctionSpace(mesh, 'P', PRESSURE_DEGREE)
[34]:
# Boundary markers, written as DOLFIN C++ expression strings in the point x.
inflow = 'near(x[0], 0)'
outflow = 'near(x[0], 2.2)'
walls = 'near(x[1], 0) || near(x[1], 0.41)'
cylinder = 'on_boundary && x[0]>0.1 && x[0]<0.3 && x[1]>0.1 && x[1]<0.3'

# Parabolic inlet velocity: zero at both walls, peak value 1.5 at mid-channel.
inflow_profile = ('4.0*1.5*x[1]*(0.41 - x[1]) / pow(0.41, 2)', '0')

# Velocity: prescribed profile at the inlet, no-slip on the walls and the cylinder.
bcu_inflow = DirichletBC(V, Expression(inflow_profile, degree=2), inflow)
bcu_walls, bcu_cylinder = [DirichletBC(V, Constant((0, 0)), marker)
                           for marker in (walls, cylinder)]
bcu = [bcu_inflow, bcu_walls, bcu_cylinder]

# Pressure: fixed to zero at the outlet.
bcp_outflow = DirichletBC(Q, Constant(0), outflow)
bcp = [bcp_outflow]

[35]:
# Trial and test functions for the velocity (V) and pressure (Q) spaces.
u = TrialFunction(V)
v = TestFunction(V)
p = TrialFunction(Q)
q = TestFunction(Q)

# Solutions at the previous (u_n, p_n) and current (u_, p_) time steps.
u_n = Function(V)
u_  = Function(V)
p_n = Function(Q)
p_  = Function(Q)

# Expressions used in the variational forms.
U   = 0.5*(u_n + u)      # average of previous and unknown velocity
n   = FacetNormal(mesh)
f   = Constant((0, 0))   # body force
k   = Constant(dt)

# Wrap the physical parameters only while they are still plain numbers, so that
# re-running this cell does not nest Constant(Constant(...)).
if not isinstance(mu, Constant):
    mu = Constant(mu)
if not isinstance(rho, Constant):
    rho = Constant(rho)
[36]:
def epsilon(u):
    """Return the symmetric part of the gradient of ``u`` (the strain-rate tensor)."""
    grad_u = nabla_grad(u)
    return sym(grad_u)

def sigma(u, p):
    """Return the Newtonian stress tensor 2*mu*epsilon(u) - p*I.

    Uses the module-level viscosity ``mu``; the identity has dimension ``len(u)``.
    """
    viscous_part = 2*mu*epsilon(u)
    pressure_part = p*Identity(len(u))
    return viscous_part - pressure_part
[37]:
# Variational forms for a three-step pressure-correction splitting scheme.
# Re-fetch the facet normal so this cell stays correct even if the global `n`
# was rebound (e.g. by a loop counter) after it was first defined.
n = FacetNormal(mesh)

# Step 1: tentative velocity, using the previous pressure p_n.
F1 = rho*dot((u - u_n) / k, v)*dx \
   + rho*dot(dot(u_n, nabla_grad(u_n)), v)*dx \
   + inner(sigma(U, p_n), epsilon(v))*dx \
   + dot(p_n*n, v)*ds - dot(mu*nabla_grad(U)*n, v)*ds \
   - dot(f, v)*dx
a1 = lhs(F1)
L1 = rhs(F1)

# Step 2: pressure update driven by the divergence of the tentative velocity.
a2 = dot(nabla_grad(p), nabla_grad(q))*dx
L2 = dot(nabla_grad(p_n), nabla_grad(q))*dx - (1/k)*div(u_)*q*dx

# Step 3: velocity correction using the pressure increment.
a3 = dot(u, v)*dx
L3 = dot(u_, v)*dx - k*dot(nabla_grad(p_ - p_n), v)*dx

# The left-hand-side matrices do not change between time steps, so assemble once.
A1 = assemble(a1)
A2 = assemble(a2)
A3 = assemble(a3)

# Apply boundary conditions to the matrices (plain loops: apply() is called for
# its side effect, and a list comprehension would only display [None, ...]).
for bc in bcu:
    bc.apply(A1)
for bc in bcp:
    bc.apply(A2)
[37]:
[None]
[38]:
from tqdm.auto import tqdm

# Velocity/pressure snapshots are written to XDMF; flushing after every write
# lets the files be inspected while the run is still in progress.
xdmffile_u = XDMFFile('navier_stokes_cylinder/velocity.xdmf')
xdmffile_p = XDMFFile('navier_stokes_cylinder/pressure.xdmf')
xdmffile_u.parameters['flush_output'] = True
xdmffile_p.parameters['flush_output'] = True

# Time-stepping. The loop counter is `step` (not `n`) so the FacetNormal `n`
# referenced by the variational forms is not overwritten.
t = 0
for step in tqdm(range(num_steps), desc='Time-stepping'):

    # Update current time
    t += dt

    # Step 1: tentative velocity
    b1 = assemble(L1)
    for bc in bcu:
        bc.apply(b1)
    solve(A1, u_.vector(), b1, 'bicgstab', 'hypre_amg')

    # Step 2: pressure correction
    b2 = assemble(L2)
    for bc in bcp:
        bc.apply(b2)
    solve(A2, p_.vector(), b2, 'bicgstab', 'hypre_amg')

    # Step 3: velocity correction
    b3 = assemble(L3)
    solve(A3, u_.vector(), b3, 'cg', 'sor')

    # Save solution to file (XDMF/HDF5)
    xdmffile_u.write(u_, t)
    xdmffile_p.write(p_, t)

    # Advance: current solution becomes the previous one
    u_n.assign(u_)
    p_n.assign(p_)

# Release the XDMF/HDF5 file handles once the run is complete.
xdmffile_u.close()
xdmffile_p.close()

# Plot only the final state: plotting inside the loop overlaid every time step
# into the same axes and slowed the run down.
plt.figure()
plot(u_, title='Velocity')
plt.figure()
plot(p_, title='Pressure');
100%|██████████| 250/250 [00:29<00:00,  8.38it/s]
../../_images/notebooks_pde_pde_stokes_10_1.png
[39]:
# Show the velocity file handle and the simulated final time.
for item in (xdmffile_u, T):
    print(item)
<dolfin.cpp.io.XDMFFile object at 0x7fa22815e570>
0.25
[ ]:
# Choose between the Drosophila spatial transcriptomics data and the toy dataset.
drosophila = True

if drosophila:
    filepath = path.join('../../../experiments', 'dros-kr', 'partial', 'savedmodel')

    dataset = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data')

    data = next(iter(dataset))
    tx, y_target = data
    lengthscale = 10

    # The overview figure compares three genes, so load the kni and gt datasets
    # alongside kr. NOTE(review): assumes the loader accepts gene='kni' and
    # gene='gt' (suggested by the dataset names used for the figure) — confirm.
    kr_dataset = dataset
    kni_dataset = DrosophilaSpatialTranscriptomics(gene='kni', data_dir='../../../data')
    gt_dataset = DrosophilaSpatialTranscriptomics(gene='gt', data_dir='../../../data')

    # Outer loop over i, inner over genes: one figure row per i (2 rows x 3 genes).
    images = [get_image(gene_dataset.orig_data, i)
              for i in range(2, 4)
              for gene_dataset in (kr_dataset, kni_dataset, gt_dataset)]

else:
    filepath = path.join('../../../experiments', 'toy-spatial', 'partial', 'savedmodel')

    dataset = ToySpatialTranscriptomics(data_dir='../../../data/')
    data = next(iter(dataset))
    tx, y_target = data
    lengthscale = 0.2

# NOTE(review): appears unused — the GP cell samples inducing points with a 0.3 fraction.
num_inducing = int(tx.shape[1] * 5/6)

# Sorted unique time (row 0) and space (row 1) coordinates, and the plotting extent.
ts = tx[0, :].unique().sort()[0].numpy()
xs = tx[1, :].unique().sort()[0].numpy()
t_diff = ts[-1]-ts[0]
x_diff = xs[-1]-xs[0]
extent = [ts[0], ts[-1], xs[0], xs[-1]]

if drosophila:
    plot_spatiotemporal_data(images, extent, nrows=2, ncols=3)

Set up GP model

[5]:
# Use 30% of the observed (t, x) locations as fixed inducing inputs. A single
# shared permutation keeps each point's time and space coordinates paired and
# guarantees distinct inducing points; two independent permutations combine the
# time of one observation with the position of another and can yield duplicate
# (t, x) pairs, which make the inducing covariance matrix ill-conditioned.
num_points = tx.shape[1]
inducing_idx = torch.randperm(num_points)[:int(0.3 * num_points)]
inducing_points = torch.stack([
    tx[0, inducing_idx],
    tx[1, inducing_idx]
], dim=1).unsqueeze(0)  # shape (1, num_inducing, 2)

gp_kwargs = dict(use_ard=True,
                 use_scale=False,
                 # lengthscale_constraint=Interval(0.1, 0.3),
                 learn_inducing_locations=False,
                 initial_lengthscale=lengthscale)
gp_model = MultiOutputGP(inducing_points, 1, **gp_kwargs)
gp_model.double();

print(inducing_points.shape)
fig, ax = plt.subplots()
ax.scatter(inducing_points[0, :, 0], inducing_points[0, :, 1])
ax.set(xlabel='t', ylabel='x', title='Inducing points');
torch.Size([1, 504, 2])
[5]:
<matplotlib.collections.PathCollection at 0x7fb5280eb450>
../../_images/notebooks_pde_pde_stokes_14_2.png

Set up PDE

[6]:
t_range = (ts[0], ts[-1])
print(t_range)
time_steps = dataset.num_discretised
print(time_steps)

# NOTE(review): `mesh` is still the 2-D channel mesh built for the Navier–Stokes
# demo above, while this data has a single spatial coordinate; the otherwise
# unused `interval_mesh` import suggests a 1-D mesh was intended — confirm.
fenics_model = ReactionDiffusion(t_range, time_steps, mesh)

config = VariationalConfiguration(initial_conditions=False, num_samples=25)


def _fixed_scalar(value):
    """Return a non-trainable float64 Parameter of shape (1, 1) filled with ``value``."""
    return Parameter(value * torch.ones((1, 1), dtype=torch.float64), requires_grad=False)


# PDE kinetic parameters, held fixed (requires_grad=False).
sensitivity, decay, diffusion = _fixed_scalar(1.0), _fixed_scalar(0.1), _fixed_scalar(0.01)
fenics_params = [sensitivity, decay, diffusion]

lfm = PartialLFM(1, gp_model, fenics_model, fenics_params, config)
(0.0, 1.0)
40
[9]:
# Randomly mark 30% of the observations as training points.
n_obs = tx.shape[1]
train_mask = torch.zeros_like(tx[0, :])
train_idx = torch.randperm(n_obs)[:int(0.3 * n_obs)]
train_mask[train_idx] = 1

# NOTE(review): num_data counts every observation, not only the masked training
# subset — confirm this matches the num_data used by the model's ELBO.
num_training = n_obs
variational_optimizer = NGD(lfm.variational_parameters(), num_data=num_training, lr=0.1)
parameter_optimizer = Adam(lfm.nonvariational_parameters(), lr=0.07)
optimizers = [variational_optimizer, parameter_optimizer]

trainer = PDETrainer(
    lfm,
    optimizers,
    dataset,
    track_parameters=list(lfm.fenics_named_parameters.keys()),
    train_mask=train_mask.bool(),
    warm_variational=1,
)
t_sorted, dp [0.    0.025 0.05  0.075 0.1   0.125 0.15  0.175 0.2   0.225 0.25  0.275
 0.3   0.325 0.35  0.375 0.4   0.425 0.45  0.475 0.5   0.525 0.55  0.575
 0.6   0.625 0.65  0.675 0.7   0.725 0.75  0.775 0.8   0.825 0.85  0.875
 0.9   0.925 0.95  0.975 1.   ] 0.025
x dp is set to 0.025
t_sorted, dp [0.    0.025 0.05  0.075 0.1   0.125 0.15  0.175 0.2   0.225 0.25  0.275
 0.3   0.325 0.35  0.375 0.4   0.425 0.45  0.475 0.5   0.525 0.55  0.575
 0.6   0.625 0.65  0.675 0.7   0.725 0.75  0.775 0.8   0.825 0.85  0.875
 0.9   0.925 0.95  0.975 1.   ] 0.025

Now let's draw some samples from the GP and look at the corresponding LFM output.

[10]:
# Number of distinct time points and distinct spatial points in the training
# inputs (row 0 indexes time and row 1 space, as in the ts/xs computation earlier).
num_t = trainer.tx[0, :].unique().shape[0]
num_x = trainer.tx[1, :].unique().shape[0]

# gp_model.covar_module.lengthscale = 0.3*0.3 * 2
# Evaluate the GP at every training input; the transpose presumably makes the
# inputs point-major, (N, 2) — TODO confirm against MultiOutputGP.
out = gp_model(trainer.tx.transpose(0, 1))

# Draw num_samples latent-force samples; permute(0, 2, 1) swaps the last two axes
# (presumably (samples, N, outputs) -> (samples, outputs, N) — confirm).
sample = out.sample(torch.Size([lfm.config.num_samples])).permute(0, 2, 1)
# Reference signal: column 2 of the dataset's original data, reordered by
# trainer.t_sorted. NOTE(review): assumes column 2 is the intended ground truth.
real = torch.tensor(dataset.orig_data[trainer.t_sorted, 2]).unsqueeze(0)


# Compare the mean GP sample with the reference on the space-time grid.
# .view(num_t, num_x) assumes points are ordered time-major — TODO confirm.
plot_spatiotemporal_data(
    [sample.mean(0)[0].detach().view(num_t, num_x).transpose(0, 1),
    real.squeeze().view(num_t, num_x).transpose(0, 1)],
    extent,
    titles=['Prediction', 'Ground truth']
)

# Reshape to (num_samples, 1, num_t, num_x), tiling the reference to the same
# batch size, so both can be passed to the PDE solver.
sample = sample.view(lfm.config.num_samples, 1, num_t, num_x)
real = real.repeat(lfm.config.num_samples, 1, 1)
real = real.view(lfm.config.num_samples, 1, num_t, num_x)

# Push each latent sample, and the tiled reference, through the PDE.
out = lfm.solve_pde(sample)
real_out = lfm.solve_pde(real)

# Mean PDE output over samples vs. the output for the first (identical) copy of the reference.
plot_spatiotemporal_data(
    [out.mean(0).detach().transpose(0, 1),
    real_out[0].detach().transpose(0, 1)],
    extent,
    titles=['Prediction', 'Ground truth']
)
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
<ipython-input-10-d904d801b3d5> in <module>
     27     real_out[0].detach().transpose(0, 1)],
     28     extent,
---> 29     titles=['Prediction', 'Ground truth']
     30 )
     31

~/Documents/proj/reggae/lafomo/plot/misc.py in plot_spatiotemporal_data(images, extent, nrows, ncols, titles)
     22                      cbar_mode="each",
     23                      cbar_size="7%",
---> 24                      cbar_pad="2%",
     25                  )
     26     aspect = (extent[1]-extent[0])/ (extent[3]-extent[2])

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412
    413     return wrapper

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, fig, rect, nrows_ncols, ngrids, direction, axes_pad, add_all, share_all, aspect, label_mode, cbar_mode, cbar_location, cbar_pad, cbar_size, cbar_set_cax, axes_class)
    434                 direction=direction, axes_pad=axes_pad,
    435                 share_all=share_all, share_x=True, share_y=True, aspect=aspect,
--> 436                 label_mode=label_mode, axes_class=axes_class)
    437         else:  # Only show deprecation in that case.
    438             super().__init__(

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412
    413     return wrapper

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, fig, rect, nrows_ncols, ngrids, direction, axes_pad, add_all, share_all, share_x, share_y, label_mode, axes_class, aspect)
    206         self.axes_llc = self.axes_column[0][-1]
    207
--> 208         self._init_locators()
    209
    210         if add_all:

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in _init_locators(self)
    474                 self.axes_all[0].figure, self._divider.get_position(),
    475                 orientation=self._colorbar_location)
--> 476             for _ in range(self.ngrids)]
    477
    478         cb_mode = self._colorbar_mode

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in <listcomp>(.0)
    474                 self.axes_all[0].figure, self._divider.get_position(),
    475                 orientation=self._colorbar_location)
--> 476             for _ in range(self.ngrids)]
    477
    478         cb_mode = self._colorbar_mode

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, orientation, *args, **kwargs)
     25         self._default_label_on = True
     26         self._locator = None  # deprecated.
---> 27         super().__init__(*args, **kwargs)
     28
     29     @cbook._rename_parameter("3.2", "locator", "ticks")

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axes/_base.py in __init__(self, fig, rect, facecolor, frameon, sharex, sharey, label, xscale, yscale, box_aspect, **kwargs)
    509
    510         self._rasterization_zorder = None
--> 511         self.cla()
    512
    513         # funcs used to format x and y - fall back on major formatters

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in cla(self)
     80     def cla(self):
     81         super().cla()
---> 82         self._config_axes()
     83
     84

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in _config_axes(self)
     69         ax = self
     70         ax.set_navigate(False)
---> 71         ax.axis[:].toggle(all=False)
     72         b = self._default_label_on
     73         ax.axis[self.orientation].toggle(all=b)

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/mpl_axes.py in __call__(self, *args, **kwargs)
     14     def __call__(self, *args, **kwargs):
     15         for m in self._objects:
---> 16             m(*args, **kwargs)
     17
     18

~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/mpl_axes.py in toggle(self, all, ticks, ticklabels, label)
    124         if _ticklabels is not None:
    125             tickparam = {labelOn: _ticklabels}
--> 126             self._axis.set_tick_params(**tickparam)
    127
    128         if _label is not None:

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axis.py in set_tick_params(self, which, reset, **kw)
    837                 self._major_tick_kw.update(kwtrans)
    838                 for tick in self.majorTicks:
--> 839                     tick._apply_params(**kwtrans)
    840             if which in ['minor', 'both']:
    841                 self._minor_tick_kw.update(kwtrans)

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axis.py in _apply_params(self, **kw)
    381         label_kw = {k[5:]: v for k, v in kw.items()
    382                     if k in ['labelsize', 'labelcolor']}
--> 383         self.label1.set(**label_kw)
    384         self.label2.set(**label_kw)
    385         for k, v in label_kw.items():

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/artist.py in set(self, **kwargs)
   1086     def set(self, **kwargs):
   1087         """A property batch setter.  Pass *kwargs* to set properties."""
-> 1088         kwargs = cbook.normalize_kwargs(kwargs, self)
   1089         move_color_to_start = False
   1090         if "color" in kwargs:

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412
    413     return wrapper

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412
    413     return wrapper

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
    409                          else deprecation_addendum,
    410                 **kwargs)
--> 411         return func(*inner_args, **inner_kwargs)
    412
    413     return wrapper

~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/__init__.py in normalize_kwargs(kw, alias_mapping, required, forbidden, allowed)
   1738     elif (isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist)
   1739           or isinstance(alias_mapping, Artist)):
-> 1740         alias_mapping = getattr(alias_mapping, "_alias_map", {})
   1741
   1742     to_canonical = {alias: canonical

KeyboardInterrupt:
../../_images/notebooks_pde_pde_stokes_19_1.png
<Figure size 432x288 with 0 Axes>
[ ]:
# Train the latent force model; the argument's meaning is defined by
# PDETrainer.train (presumably the number of epochs — confirm).
trainer.train(2)
[24]:
print(sample.shape)

# Mean latent sample on the space-time grid. `reshape(num_t, num_x)` accepts both
# the flat (samples, 1, num_t * num_x) layout and the gridded
# (samples, 1, num_t, num_x) layout of `sample`. Indexing `[0]` and transposing
# the flat layout raised IndexError. Detach for plotting, as in the earlier cell.
fig, ax = plt.subplots()
latent_img = ax.imshow(sample.mean(0).detach().reshape(num_t, num_x).transpose(0, 1))
fig.colorbar(latent_img, ax=ax)
ax.set_title('Mean GP sample')

# Mean PDE output over samples.
fig, ax = plt.subplots()
ax.imshow(out.mean(0).detach())
ax.set_title('Mean PDE output');
# lfm.save(filepath)
torch.Size([25, 1, 512])
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-24-d175470bbe4e> in <module>
      1 print(sample.shape)
      2
----> 3 plt.imshow(sample.mean(0)[0].transpose(0, 1))
      4 plt.colorbar()
      5 plt.figure()

IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
[25]:
# Reload a saved PartialLFM from `filepath`, rebuilding its GP with the same
# constructor arguments used when the model was first created.
load_kwargs = dict(
    gp_cls=MultiOutputGP,
    gp_args=[inducing_points, 1],
    gp_kwargs=gp_kwargs,
    lfm_args=[1, fenics_model, fenics_params, config],
)
lfm = PartialLFM.load(filepath, **load_kwargs)

gp_model = lfm.gp_model

# NOTE(review): unlike the earlier setup, this uses a single Adam optimizer over
# all parameters and passes no train_mask — confirm this is intended.
optimizer = Adam(lfm.parameters(), lr=0.07)
trainer = PDETrainer(
    lfm,
    optimizer,
    dataset,
    track_parameters=list(lfm.fenics_named_parameters.keys()),
)
t_sorted, dp [53.925 60.175 66.425 72.675 78.925 85.175 91.425 97.675] 6.25
x dp is set to 1.0
t_sorted, dp [25.5 26.5 27.5 28.5 29.5 30.5 31.5 32.5 33.5 34.5 35.5 36.5 37.5 38.5
 39.5 40.5 41.5 42.5 43.5 44.5 45.5 46.5 47.5 48.5 49.5 50.5 51.5 52.5
 53.5 54.5 55.5 56.5 57.5 58.5 59.5 60.5 61.5 62.5 63.5 64.5 65.5 66.5
 67.5 68.5 69.5 70.5 71.5 72.5 73.5 74.5 75.5 76.5 77.5 78.5 79.5 80.5
 81.5 82.5 83.5 84.5 85.5 86.5 87.5 88.5] 1.0
[26]:
from lafomo.utilities.torch import smse, cia, q2

tx = trainer.tx
num_t = tx[0, :].unique().shape[0]
num_x = tx[1, :].unique().shape[0]

# Predictive mean and variance of the fitted model at the training inputs, from
# a single forward pass (no gradients needed for evaluation); the metrics below
# need both. Named `lfm_output` so the body-force `f` from the Navier–Stokes
# cells is not overwritten.
# NOTE(review): assumes lfm(tx) returns an object exposing .mean and .variance.
with torch.no_grad():
    lfm_output = lfm(tx)
f_mean = lfm_output.mean.detach()
f_var = lfm_output.variance.detach()

y_target = trainer.y_target[0]
ts = tx[0, :].unique().sort()[0].numpy()
xs = tx[1, :].unique().sort()[0].numpy()
t_diff = ts[-1] - ts[0]
x_diff = xs[-1] - xs[0]
extent = [ts[0], ts[-1], xs[0], xs[-1]]

f_mean_test = f_mean.squeeze()
f_var_test = f_var.squeeze()
print(y_target.shape, f_mean_test.shape)

print(q2(y_target, f_mean_test))
print(cia(y_target, f_mean_test, f_var_test).item())
print(smse(y_target, f_mean_test).mean().item())
---------------------------------------------------------------------------
NameError                                 Traceback (most recent call last)
<ipython-input-26-2d90327e70b2> in <module>
     12 x_diff = xs[-1] - xs[0]
     13 extent = [ts[0], ts[-1], xs[0], xs[-1]]
---> 14 print(y_target.shape, f_mean.squeeze().shape)
     15 f_mean_test = f_mean.squeeze()
     16 f_var_test = f_var.squeeze()

NameError: name 'f_mean' is not defined
[32]:
plotter = Plotter(lfm, np.arange(1))

# Final traced value of each tracked FEniCS parameter, passed through softplus.
# NOTE(review): labels assume the keys come in the order sensitivity, decay,
# diffusion, and softplus assumes values are stored unconstrained — confirm.
labels = ['Sensitivity', 'Decay', 'Diffusion']
kinetics = [
    softplus(trainer.parameter_trace[key][-1]).squeeze().numpy()
    for key in lfm.fenics_named_parameters.keys()
]

plotter.plot_double_bar(kinetics, labels)

# plotter.plot_latents()
../../_images/notebooks_pde_pde_stokes_24_0.png
[ ]: