Switching to conditional Gaussian
flowutils.py (+63 −7)
@@ -5,6 +5,61 @@ import numpy as np
 import torch
 import torch.nn as nn
 from einops import rearrange, repeat
+from normflows.distributions import BaseDistribution
+
+
+class ConditionalDiagGaussian(BaseDistribution):
+    """
+    Conditional multivariate Gaussian distribution with diagonal
+    covariance matrix, parameters are obtained by a context encoder,
+    context meaning the variable to condition on
+    """
+
+    def __init__(self, shape, context_encoder):
+        """Constructor
+
+        Args:
+            shape: Tuple with shape of data, if int shape has one dimension
+            context_encoder: Computes mean and log of the standard deviation
+                of the Gaussian, mean is the first half of the last dimension
+                of the encoder output, log of the standard deviation the second
+                half
+        """
+        super().__init__()
+        if isinstance(shape, int):
+            shape = (shape,)
+        if isinstance(shape, list):
+            shape = tuple(shape)
+        self.shape = shape
+        self.n_dim = len(shape)
+        self.d = np.prod(shape)
+        self.context_encoder = context_encoder
+
+    def forward(self, num_samples=1, context=None):
+        encoder_output = self.context_encoder(context)
+        split_ind = encoder_output.shape[-1] // 2
+        mean = encoder_output[..., :split_ind]
+        log_scale = encoder_output[..., split_ind:]
+        eps = torch.randn(
+            (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device
+        )
+        z = mean + torch.exp(log_scale) * eps
+        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
+            log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
+        )
+        return z, log_p
+
+    def log_prob(self, z, context=None):
+        encoder_output = self.context_encoder(context)
+        split_ind = encoder_output.shape[-1] // 2
+        mean = encoder_output[..., :split_ind]
+        log_scale = encoder_output[..., split_ind:]
+        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
+            log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2),
+            list(range(1, self.n_dim + 1)),
+        )
+        return log_p
+
 
 
 def build_flows(
@@ -26,13 +81,14 @@ def build_flows(
 
     # Set base distribution
 
-
-
-
-
-
+    context_encoder = nn.Sequential(
+        nn.Linear(context_size, context_size),
+        nn.SiLU(),
+        # output mean and scales for K=latent_size dimensions
+        nn.Linear(context_size, latent_size * 2)
+    )
 
-    q0 =
+    q0 = ConditionalDiagGaussian(latent_size, context_encoder)
 
     # Construct flow model
     model = nf.ConditionalNormalizingFlow(q0, flows)
@@ -239,7 +295,7 @@ class PatchFlow(torch.nn.Module):
             context=context_vector,
         )
 
-        loss = -torch.mean(flow_model.flow.q0.log_prob(z) + ldj)
+        loss = -torch.mean(flow_model.flow.q0.log_prob(z, context_vector) + ldj)
         loss *= n_patches
 
         if train:
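A minimal sanity check of the new base distribution (not part of the commit; the sizes and the torch.distributions cross-check are illustrative, assuming ConditionalDiagGaussian from the diff above is in scope):

import torch
import torch.nn as nn

# Illustrative sizes; the real values come from build_flows' arguments.
latent_size, context_size, batch = 4, 8, 3

# Same encoder shape as in the diff: mean and log-scale per latent dimension.
context_encoder = nn.Sequential(
    nn.Linear(context_size, context_size),
    nn.SiLU(),
    nn.Linear(context_size, latent_size * 2),
)
q0 = ConditionalDiagGaussian(latent_size, context_encoder)  # class from the diff above

context = torch.randn(batch, context_size)
z, log_p = q0(num_samples=batch, context=context)  # one sample per context row

# Cross-check log_prob against a factorized Normal with the same parameters.
mean, log_scale = context_encoder(context).chunk(2, dim=-1)
ref = torch.distributions.Normal(mean, log_scale.exp()).log_prob(z).sum(-1)
assert torch.allclose(q0.log_prob(z, context), ref, atol=1e-5)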
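The last hunk is the reason log_prob gained a context argument: with a conditional base, the density of z is undefined without the context that parameterizes its mean and log-scale. Continuing the sketch above with stand-ins for PatchFlow's per-patch tensors (z, ldj, and n_patches here are illustrative, not the actual PatchFlow code):

n_patches = 5
ldj = torch.zeros(batch, requires_grad=True)  # stand-in for the flow's log|det J|
loss = -torch.mean(q0.log_prob(z, context) + ldj) * n_patches
loss.backward()  # gradients reach the context encoder through mean/log-scale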