Navier–Stokes: incompressible flow past a cylinder
[21]:
import numpy as np
import torch
from torch.nn import Parameter
from torch.optim import Adam
from gpytorch.optim import NGD
from gpytorch.constraints import Interval
from matplotlib import pyplot as plt
from mpl_toolkits.axes_grid1 import ImageGrid
from os import path
from lafomo.utilities.torch import get_image
from lafomo.datasets import DrosophilaSpatialTranscriptomics
from lafomo.models import MultiOutputGP, PartialLFM
from lafomo.models.pdes import ReactionDiffusion
from lafomo.datasets import ToySpatialTranscriptomics, P53Data
from lafomo.configuration import VariationalConfiguration
from lafomo.plot import Plotter, plot_spatiotemporal_data
from lafomo.trainers import PDETrainer
from lafomo.utilities.torch import discretise, softplus
from lafomo.utilities.fenics import interval_mesh
[22]:
from fenics import *
from dolfin import *
from mshr import *
[31]:
# Simulation parameters. Both the final time and the number of steps are
# divided by 20 (from 5.0 and 5000), so the step size dt stays the same.
T = 5.0/20              # final time
num_steps = 5000//20    # number of time steps
dt = T / num_steps      # time step size
mu = 0.001              # dynamic viscosity
rho = 1                 # density
[32]:
# Generate the mesh: a 2.2 x 0.41 rectangular channel with a circle of
# radius 0.05 centred at (0.2, 0.2) subtracted from it.
# The obstacle geometry gets its own name so it is not confused with the
# `cylinder` boundary-marker string defined in the boundary-condition cell.
channel = Rectangle(Point(0, 0), Point(2.2, 0.41))
cylinder_geometry = Circle(Point(0.2, 0.2), 0.05)
domain = channel - cylinder_geometry
mesh = generate_mesh(domain, 32)  # 32 = resolution argument of mshr.generate_mesh
# Trailing semicolon suppresses the list-of-Line2D repr in the cell output.
plot(mesh);
[32]:
[<matplotlib.lines.Line2D at 0x7fa22a2336d0>,
<matplotlib.lines.Line2D at 0x7fa22a233a10>]
[33]:
# Function spaces: piecewise-linear scalars for the pressure and
# piecewise-quadratic vectors for the velocity (a Taylor–Hood P2/P1 pairing).
Q = FunctionSpace(mesh, 'P', 1)
V = VectorFunctionSpace(mesh, 'P', 2)
[34]:
# Boundary markers, given as C++ expression strings that DOLFIN compiles.
inflow = 'near(x[0], 0)'
outflow = 'near(x[0], 2.2)'
walls = 'near(x[1], 0) || near(x[1], 0.41)'
cylinder = 'on_boundary && x[0]>0.1 && x[0]<0.3 && x[1]>0.1 && x[1]<0.3'
# Parabolic inflow profile; it peaks at 1.5 at mid-channel (x[1] = 0.205).
inflow_profile = ('4.0*1.5*x[1]*(0.41 - x[1]) / pow(0.41, 2)', '0')

# Pressure: fixed to zero on the outflow boundary.
bcp_outflow = DirichletBC(Q, Constant(0), outflow)
bcp = [bcp_outflow]

# Velocity: prescribed inflow profile, no-slip on the walls and the cylinder.
bcu_inflow = DirichletBC(V, Expression(inflow_profile, degree=2), inflow)
bcu_walls = DirichletBC(V, Constant((0, 0)), walls)
bcu_cylinder = DirichletBC(V, Constant((0, 0)), cylinder)
bcu = [bcu_inflow, bcu_walls, bcu_cylinder]
[35]:
# Trial/test functions, solution containers and coefficients used by the
# three-step pressure-correction forms below.
u = TrialFunction(V)
v = TestFunction(V)
p = TrialFunction(Q)
q = TestFunction(Q)
# Solutions at the previous (u_n, p_n) and current (u_, p_) time steps
u_n = Function(V)
u_ = Function(V)
p_n = Function(Q)
p_ = Function(Q)
# Expressions used in the variational forms
U = 0.5*(u_n + u)  # average of the previous and the new (unknown) velocity
n = FacetNormal(mesh)
f = Constant((0, 0))  # body force
k = Constant(dt)
# Wrap the physical parameters as Constants. Going through float(...) makes
# the cell idempotent: re-running it no longer wraps a Constant in a Constant.
mu = Constant(float(mu))
rho = Constant(float(rho))
[36]:
# Symmetric gradient
def epsilon(u):
    """Return the symmetric part of the gradient of ``u``: sym(nabla_grad(u))."""
    grad_u = nabla_grad(u)
    return sym(grad_u)
# Stress tensor
def sigma(u, p):
    """Return the stress tensor 2*mu*epsilon(u) - p*I.

    Reads the module-level viscosity ``mu``. The identity has dimension
    ``len(u)``.
    """
    viscous_part = 2*mu*epsilon(u)
    pressure_part = p*Identity(len(u))
    return viscous_part - pressure_part
[37]:
# Variational forms for the three-step pressure-correction scheme.
# Step 1: tentative velocity
F1 = rho*dot((u - u_n) / k, v)*dx \
   + rho*dot(dot(u_n, nabla_grad(u_n)), v)*dx \
   + inner(sigma(U, p_n), epsilon(v))*dx \
   + dot(p_n*n, v)*ds - dot(mu*nabla_grad(U)*n, v)*ds \
   - dot(f, v)*dx
a1 = lhs(F1)
L1 = rhs(F1)
# Step 2: pressure correction
a2 = dot(nabla_grad(p), nabla_grad(q))*dx
L2 = dot(nabla_grad(p_n), nabla_grad(q))*dx - (1/k)*div(u_)*q*dx
# Step 3: velocity correction
a3 = dot(u, v)*dx
L3 = dot(u_, v)*dx - k*dot(nabla_grad(p_ - p_n), v)*dx
# The bilinear forms do not depend on the time step, so assemble them once.
A1 = assemble(a1)
A2 = assemble(a2)
A3 = assemble(a3)
# Apply the Dirichlet conditions to the matrices. bc.apply is called only for
# its side effect, so use plain loops rather than list comprehensions, which
# would build and display a throwaway list of None.
for bc in bcu:
    bc.apply(A1)
for bc in bcp:
    bc.apply(A2)
[37]:
[None]
[38]:
# Time-stepping: solve the three pressure-correction steps for num_steps
# steps and write velocity and pressure to XDMF files.
from tqdm import tqdm

xdmffile_u = XDMFFile('navier_stokes_cylinder/velocity.xdmf')
xdmffile_p = XDMFFile('navier_stokes_cylinder/pressure.xdmf')
# Flush after every write so partial results can be inspected while running.
xdmffile_u.parameters['flush_output'] = True
xdmffile_p.parameters['flush_output'] = True

t = 0
try:
    # The loop variable is `step`, not `n`: `n` holds the FacetNormal used by
    # the forms, and overwriting it breaks re-running the forms cell.
    for step in tqdm(range(num_steps), desc='Time-stepping'):
        t += dt
        # Step 1: tentative velocity
        b1 = assemble(L1)
        for bc in bcu:
            bc.apply(b1)
        solve(A1, u_.vector(), b1, 'bicgstab', 'hypre_amg')
        # Step 2: pressure correction
        b2 = assemble(L2)
        for bc in bcp:
            bc.apply(b2)
        solve(A2, p_.vector(), b2, 'bicgstab', 'hypre_amg')
        # Step 3: velocity correction
        b3 = assemble(L3)
        solve(A3, u_.vector(), b3, 'cg', 'sor')
        # Save the current solution (XDMF/HDF5)
        xdmffile_u.write(u_, t)
        xdmffile_p.write(p_, t)
        # Advance: the current solution becomes the previous one
        u_n.assign(u_)
        p_n.assign(p_)
finally:
    # Release the file handles even if a solve fails part-way.
    xdmffile_u.close()
    xdmffile_p.close()

# Plot the final state once, in separate figures. The previous code called
# plot twice on every step, redrawing into the current figure each time.
plt.figure()
plot(u_, title='Velocity')
plt.figure()
plot(p_, title='Pressure');
100%|██████████| 250/250 [00:29<00:00, 8.38it/s]
[39]:
# Sanity check: the velocity output file handle and the simulated end time.
print(xdmffile_u, T, sep='\n')
<dolfin.cpp.io.XDMFFile object at 0x7fa22815e570>
0.25
[ ]:
# Choose the dataset: the Drosophila kr gap-gene data or the toy spatial set.
drosophila = True
if drosophila:
    filepath = path.join('../../../experiments', 'dros-kr', 'partial', 'savedmodel')
    dataset = DrosophilaSpatialTranscriptomics(gene='kr', data_dir='../../../data')
    lengthscale = 10
    # The comparison figure needs all three genes. These datasets used to be
    # referenced without ever being created, so a fresh kernel hit NameError.
    # NOTE(review): assumes gene='kni' and gene='gt' are supported by
    # DrosophilaSpatialTranscriptomics. Confirm against the data directory.
    kr_dataset = dataset
    kni_dataset = DrosophilaSpatialTranscriptomics(gene='kni', data_dir='../../../data')
    gt_dataset = DrosophilaSpatialTranscriptomics(gene='gt', data_dir='../../../data')
    # Order: (i=2: kr, kni, gt), then (i=3: kr, kni, gt), for the 2x3 grid below.
    images = [get_image(gene_dataset.orig_data, i)
              for i in range(2, 4)
              for gene_dataset in [kr_dataset, kni_dataset, gt_dataset]]
else:
    filepath = path.join('../../../experiments', 'toy-spatial', 'partial', 'savedmodel')
    dataset = ToySpatialTranscriptomics(data_dir='../../../data/')
    lengthscale = 0.2

# Both datasets yield (tx, y_target); rows 0 and 1 of tx are the t and x coordinates.
data = next(iter(dataset))
tx, y_target = data
num_inducing = int(tx.shape[1] * 5/6)
ts = tx[0, :].unique().sort()[0].numpy()
xs = tx[1, :].unique().sort()[0].numpy()
t_diff = ts[-1] - ts[0]
x_diff = xs[-1] - xs[0]
extent = [ts[0], ts[-1], xs[0], xs[-1]]
if drosophila:
    plot_spatiotemporal_data(images, extent, nrows=2, ncols=3)
Set up the GP model
[5]:
# Use a random 30% of the observed (t, x) locations as fixed inducing points.
# One permutation is shared by both coordinates, so every inducing point is an
# actual data location. The previous code drew separate permutations for t
# and x, pairing the time of one point with the position of another.
inducing_idx = torch.randperm(tx.shape[1])[:int(0.3 * tx.shape[1])]
inducing_points = torch.stack([
    tx[0, inducing_idx],
    tx[1, inducing_idx]
], dim=1).unsqueeze(0)  # (1, num_selected, 2)
gp_kwargs = dict(use_ard=True,
                 use_scale=False,
                 # lengthscale_constraint=Interval(0.1, 0.3),
                 learn_inducing_locations=False,
                 initial_lengthscale=lengthscale)
gp_model = MultiOutputGP(inducing_points, 1, **gp_kwargs)
gp_model.double();
print(inducing_points.shape)
plt.scatter(inducing_points[0, :, 0], inducing_points[0, :, 1]);
torch.Size([1, 504, 2])
[5]:
<matplotlib.collections.PathCollection at 0x7fb5280eb450>
Set up the PDE
[6]:
# Reaction-diffusion model over the dataset's time range and spatial grid.
t_range = (ts[0], ts[-1])
print(t_range)
time_steps = dataset.num_discretised
print(time_steps)
# Build a dedicated mesh from the observed spatial coordinates. Previously
# this cell used the global `mesh`, which by this point is the 2-D
# Navier–Stokes channel mesh generated earlier in the notebook.
# NOTE(review): assumes interval_mesh takes the sorted unique spatial
# coordinates (xs). Confirm against lafomo.utilities.fenics.interval_mesh.
spatial_mesh = interval_mesh(xs)
fenics_model = ReactionDiffusion(t_range, time_steps, spatial_mesh)
config = VariationalConfiguration(
    initial_conditions=False,
    num_samples=25
)
# PDE parameters, all frozen (requires_grad=False).
sensitivity = Parameter(torch.ones((1, 1), dtype=torch.float64), requires_grad=False)
decay = Parameter(0.1*torch.ones((1, 1), dtype=torch.float64), requires_grad=False)
diffusion = Parameter(0.01*torch.ones((1, 1), dtype=torch.float64), requires_grad=False)
fenics_params = [sensitivity, decay, diffusion]
lfm = PartialLFM(1, gp_model, fenics_model, fenics_params, config)
(0.0, 1.0)
40
[9]:
# Train on a random 30% of the points (train_mask). NGD updates the
# variational parameters and Adam updates the non-variational ones.
num_training = tx.shape[1]
train_mask = torch.zeros_like(tx[0, :])
train_mask[torch.randperm(num_training)[:int(0.3 * num_training)]] = 1
# NOTE(review): num_data counts all points, although only the masked 30%
# appear to be trained on. Confirm which count PDETrainer's loss expects.
variational_optimizer = NGD(lfm.variational_parameters(), num_data=num_training, lr=0.1)
parameter_optimizer = Adam(lfm.nonvariational_parameters(), lr=0.07)
optimizers = [variational_optimizer, parameter_optimizer]
trainer = PDETrainer(
    lfm,
    optimizers,
    dataset,
    track_parameters=list(lfm.fenics_named_parameters.keys()),
    train_mask=train_mask.bool(),
    warm_variational=1,
)
t_sorted, dp [0. 0.025 0.05 0.075 0.1 0.125 0.15 0.175 0.2 0.225 0.25 0.275
0.3 0.325 0.35 0.375 0.4 0.425 0.45 0.475 0.5 0.525 0.55 0.575
0.6 0.625 0.65 0.675 0.7 0.725 0.75 0.775 0.8 0.825 0.85 0.875
0.9 0.925 0.95 0.975 1. ] 0.025
x dp is set to 0.025
t_sorted, dp [0. 0.025 0.05 0.075 0.1 0.125 0.15 0.175 0.2 0.225 0.25 0.275
0.3 0.325 0.35 0.375 0.4 0.425 0.45 0.475 0.5 0.525 0.55 0.575
0.6 0.625 0.65 0.675 0.7 0.725 0.75 0.775 0.8 0.825 0.85 0.875
0.9 0.925 0.95 0.975 1. ] 0.025
Now let’s look at some samples from the GP and the corresponding LFM (PDE) output.
[10]:
# Draw GP samples on the training inputs, then pass both the samples and the
# ground truth through the PDE solver for comparison.
num_t = trainer.tx[0, :].unique().shape[0]
num_x = trainer.tx[1, :].unique().shape[0]
# The GP output distribution gets its own name. `out` is reserved for the PDE
# solution tensor, which the plotting cell further down relies on.
gp_output = gp_model(trainer.tx.transpose(0, 1))
sample = gp_output.sample(torch.Size([lfm.config.num_samples])).permute(0, 2, 1)
# NOTE(review): column 2 of orig_data is used as the observed signal. Confirm.
real = torch.tensor(dataset.orig_data[trainer.t_sorted, 2]).unsqueeze(0)
plot_spatiotemporal_data(
    [sample.mean(0)[0].detach().view(num_t, num_x).transpose(0, 1),
     real.squeeze().view(num_t, num_x).transpose(0, 1)],
    extent,
    titles=['Prediction', 'Ground truth']
)
# Reshape to (num_samples, 1, num_t, num_x) for the PDE solver.
sample = sample.view(lfm.config.num_samples, 1, num_t, num_x)
real = real.repeat(lfm.config.num_samples, 1, 1)
real = real.view(lfm.config.num_samples, 1, num_t, num_x)
out = lfm.solve_pde(sample)
real_out = lfm.solve_pde(real)
plot_spatiotemporal_data(
    [out.mean(0).detach().transpose(0, 1),
     real_out[0].detach().transpose(0, 1)],
    extent,
    titles=['Prediction', 'Ground truth']
)
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
<ipython-input-10-d904d801b3d5> in <module>
27 real_out[0].detach().transpose(0, 1)],
28 extent,
---> 29 titles=['Prediction', 'Ground truth']
30 )
31
~/Documents/proj/reggae/lafomo/plot/misc.py in plot_spatiotemporal_data(images, extent, nrows, ncols, titles)
22 cbar_mode="each",
23 cbar_size="7%",
---> 24 cbar_pad="2%",
25 )
26 aspect = (extent[1]-extent[0])/ (extent[3]-extent[2])
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
409 else deprecation_addendum,
410 **kwargs)
--> 411 return func(*inner_args, **inner_kwargs)
412
413 return wrapper
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, fig, rect, nrows_ncols, ngrids, direction, axes_pad, add_all, share_all, aspect, label_mode, cbar_mode, cbar_location, cbar_pad, cbar_size, cbar_set_cax, axes_class)
434 direction=direction, axes_pad=axes_pad,
435 share_all=share_all, share_x=True, share_y=True, aspect=aspect,
--> 436 label_mode=label_mode, axes_class=axes_class)
437 else: # Only show deprecation in that case.
438 super().__init__(
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
409 else deprecation_addendum,
410 **kwargs)
--> 411 return func(*inner_args, **inner_kwargs)
412
413 return wrapper
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, fig, rect, nrows_ncols, ngrids, direction, axes_pad, add_all, share_all, share_x, share_y, label_mode, axes_class, aspect)
206 self.axes_llc = self.axes_column[0][-1]
207
--> 208 self._init_locators()
209
210 if add_all:
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in _init_locators(self)
474 self.axes_all[0].figure, self._divider.get_position(),
475 orientation=self._colorbar_location)
--> 476 for _ in range(self.ngrids)]
477
478 cb_mode = self._colorbar_mode
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in <listcomp>(.0)
474 self.axes_all[0].figure, self._divider.get_position(),
475 orientation=self._colorbar_location)
--> 476 for _ in range(self.ngrids)]
477
478 cb_mode = self._colorbar_mode
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in __init__(self, orientation, *args, **kwargs)
25 self._default_label_on = True
26 self._locator = None # deprecated.
---> 27 super().__init__(*args, **kwargs)
28
29 @cbook._rename_parameter("3.2", "locator", "ticks")
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axes/_base.py in __init__(self, fig, rect, facecolor, frameon, sharex, sharey, label, xscale, yscale, box_aspect, **kwargs)
509
510 self._rasterization_zorder = None
--> 511 self.cla()
512
513 # funcs used to format x and y - fall back on major formatters
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in cla(self)
80 def cla(self):
81 super().cla()
---> 82 self._config_axes()
83
84
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/axes_grid.py in _config_axes(self)
69 ax = self
70 ax.set_navigate(False)
---> 71 ax.axis[:].toggle(all=False)
72 b = self._default_label_on
73 ax.axis[self.orientation].toggle(all=b)
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/mpl_axes.py in __call__(self, *args, **kwargs)
14 def __call__(self, *args, **kwargs):
15 for m in self._objects:
---> 16 m(*args, **kwargs)
17
18
~/miniconda3/envs/wishart/lib/python3.7/site-packages/mpl_toolkits/axes_grid1/mpl_axes.py in toggle(self, all, ticks, ticklabels, label)
124 if _ticklabels is not None:
125 tickparam = {labelOn: _ticklabels}
--> 126 self._axis.set_tick_params(**tickparam)
127
128 if _label is not None:
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axis.py in set_tick_params(self, which, reset, **kw)
837 self._major_tick_kw.update(kwtrans)
838 for tick in self.majorTicks:
--> 839 tick._apply_params(**kwtrans)
840 if which in ['minor', 'both']:
841 self._minor_tick_kw.update(kwtrans)
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/axis.py in _apply_params(self, **kw)
381 label_kw = {k[5:]: v for k, v in kw.items()
382 if k in ['labelsize', 'labelcolor']}
--> 383 self.label1.set(**label_kw)
384 self.label2.set(**label_kw)
385 for k, v in label_kw.items():
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/artist.py in set(self, **kwargs)
1086 def set(self, **kwargs):
1087 """A property batch setter. Pass *kwargs* to set properties."""
-> 1088 kwargs = cbook.normalize_kwargs(kwargs, self)
1089 move_color_to_start = False
1090 if "color" in kwargs:
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
409 else deprecation_addendum,
410 **kwargs)
--> 411 return func(*inner_args, **inner_kwargs)
412
413 return wrapper
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
409 else deprecation_addendum,
410 **kwargs)
--> 411 return func(*inner_args, **inner_kwargs)
412
413 return wrapper
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/deprecation.py in wrapper(*inner_args, **inner_kwargs)
409 else deprecation_addendum,
410 **kwargs)
--> 411 return func(*inner_args, **inner_kwargs)
412
413 return wrapper
~/miniconda3/envs/wishart/lib/python3.7/site-packages/matplotlib/cbook/__init__.py in normalize_kwargs(kw, alias_mapping, required, forbidden, allowed)
1738 elif (isinstance(alias_mapping, type) and issubclass(alias_mapping, Artist)
1739 or isinstance(alias_mapping, Artist)):
-> 1740 alias_mapping = getattr(alias_mapping, "_alias_map", {})
1741
1742 to_canonical = {alias: canonical
KeyboardInterrupt:
<Figure size 432x288 with 0 Axes>
[ ]:
# Short training run.
# NOTE(review): the argument is presumably the number of epochs. Confirm
# against PDETrainer.train.
trainer.train(2)
[24]:
# Compare the mean GP sample (latent input) with the mean PDE output.
# view(num_t, num_x) handles `sample` being either flat
# (num_samples, 1, num_t * num_x) or already gridded; the flat case used to
# raise IndexError on transpose. Both panels are transposed so x runs
# vertically, matching plot_spatiotemporal_data above.
print(sample.shape)
fig_latent, ax_latent = plt.subplots()
latent_img = ax_latent.imshow(sample.mean(0)[0].detach().view(num_t, num_x).transpose(0, 1))
fig_latent.colorbar(latent_img, ax=ax_latent)
ax_latent.set_title('Mean GP sample')
fig_out, ax_out = plt.subplots()
ax_out.imshow(out.mean(0).detach().transpose(0, 1))
ax_out.set_title('Mean PDE output');
# lfm.save(filepath)
torch.Size([25, 1, 512])
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-24-d175470bbe4e> in <module>
1 print(sample.shape)
2
----> 3 plt.imshow(sample.mean(0)[0].transpose(0, 1))
4 plt.colorbar()
5 plt.figure()
IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)
[25]:
# Restore a previously saved model from `filepath` and build a trainer for it.
lfm = PartialLFM.load(
    filepath,
    gp_cls=MultiOutputGP,
    gp_args=[inducing_points, 1],
    gp_kwargs=gp_kwargs,
    lfm_args=[1, fenics_model, fenics_params, config],
)
gp_model = lfm.gp_model
# NOTE(review): unlike the training cell, this uses a single Adam over all
# parameters and passes no train_mask. Confirm this is intended.
optimizer = torch.optim.Adam(lfm.parameters(), lr=0.07)
trainer = PDETrainer(
    lfm,
    optimizer,
    dataset,
    track_parameters=list(lfm.fenics_named_parameters.keys()),
)
t_sorted, dp [53.925 60.175 66.425 72.675 78.925 85.175 91.425 97.675] 6.25
x dp is set to 1.0
t_sorted, dp [25.5 26.5 27.5 28.5 29.5 30.5 31.5 32.5 33.5 34.5 35.5 36.5 37.5 38.5
39.5 40.5 41.5 42.5 43.5 44.5 45.5 46.5 47.5 48.5 49.5 50.5 51.5 52.5
53.5 54.5 55.5 56.5 57.5 58.5 59.5 60.5 61.5 62.5 63.5 64.5 65.5 66.5
67.5 68.5 69.5 70.5 71.5 72.5 73.5 74.5 75.5 76.5 77.5 78.5 79.5 80.5
81.5 82.5 83.5 84.5 85.5 86.5 87.5 88.5] 1.0
[26]:
# Evaluate the model on the trainer's inputs with the Q2, CIA and SMSE metrics.
from lafomo.utilities.torch import smse, cia, q2

tx = trainer.tx
num_t = tx[0, :].unique().shape[0]
num_x = tx[1, :].unique().shape[0]
# Predictive mean and variance. These lines were commented out, so the
# metrics below raised NameError on f_mean. The model is called once so the
# mean and the variance come from the same prediction.
# NOTE(review): assumes lfm(tx) returns an object exposing .mean and
# .variance, as the original commented-out code did. Confirm.
f_pred = lfm(tx)
f_mean = f_pred.mean.detach()
f_var = f_pred.variance.detach()
y_target = trainer.y_target[0]
ts = tx[0, :].unique().sort()[0].numpy()
xs = tx[1, :].unique().sort()[0].numpy()
t_diff = ts[-1] - ts[0]
x_diff = xs[-1] - xs[0]
extent = [ts[0], ts[-1], xs[0], xs[-1]]
f_mean_test = f_mean.squeeze()
f_var_test = f_var.squeeze()
print(y_target.shape, f_mean_test.shape)
print(q2(y_target, f_mean_test))
print(cia(y_target, f_mean_test, f_var_test).item())
print(smse(y_target, f_mean_test).mean().item())
---------------------------------------------------------------------------
NameError Traceback (most recent call last)
<ipython-input-26-2d90327e70b2> in <module>
12 x_diff = xs[-1] - xs[0]
13 extent = [ts[0], ts[-1], xs[0], xs[-1]]
---> 14 print(y_target.shape, f_mean.squeeze().shape)
15 f_mean_test = f_mean.squeeze()
16 f_var_test = f_var.squeeze()
NameError: name 'f_mean' is not defined
[32]:
# Bar plot of the last recorded value of each tracked PDE parameter, passed
# through softplus. The labels assume fenics_named_parameters is ordered
# (sensitivity, decay, diffusion), matching fenics_params.
plotter = Plotter(lfm, np.arange(1))
labels = ['Sensitivity', 'Decay', 'Diffusion']
kinetics = [
    softplus(trainer.parameter_trace[name][-1]).squeeze().numpy()
    for name in lfm.fenics_named_parameters.keys()
]
plotter.plot_double_bar(kinetics, labels)
[ ]: