import pdb

import normflows as nf
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from normflows.distributions import BaseDistribution


def sanitize_locals(args_dict, ignore_keys=None):
    if ignore_keys is None:
        ignore_keys = []
    if not isinstance(ignore_keys, list):
        ignore_keys = [ignore_keys]
    _dict = args_dict.copy()
    _dict.pop("self")
    class_name = _dict.pop("__class__").__name__
    class_params = {k: v for k, v in _dict.items() if k not in ignore_keys}
    return {class_name: class_params}
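
# Usage sketch (illustrative, not part of the original module): sanitize_locals is
# intended to be called with `locals()` from inside a constructor that uses the
# zero-argument super(), so that `__class__` is present in locals(). The Dummy
# class below is hypothetical:
#
#   class Dummy:
#       def __init__(self, a, b=2):
#           super().__init__()
#           self.config = sanitize_locals(locals(), ignore_keys="b")
#
#   Dummy(1).config  # -> {"Dummy": {"a": 1}}
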
def build_flows(
    latent_size, num_flows=4, num_blocks_per_flow=2, hidden_units=128, context_size=64
):
    # Define flows
    flows = []
    flows.append(
        nf.flows.MaskedAffineAutoregressive(
            latent_size,
            hidden_features=hidden_units,
            num_blocks=num_blocks_per_flow,
            context_features=context_size,
        )
    )
    for i in range(num_flows):
        flows += [
            nf.flows.CoupledRationalQuadraticSpline(
                latent_size,
                num_blocks=num_blocks_per_flow,
                num_hidden_channels=hidden_units,
                num_context_channels=context_size,
            )
        ]
        flows += [nf.flows.LULinearPermute(latent_size)]

    # Set base distribution
    context_encoder = nn.Sequential(
        nn.Linear(context_size, context_size),
        nn.SiLU(),
        # output mean and scales for K=latent_size dimensions
        nn.Linear(context_size, latent_size * 2),
    )
    q0 = ConditionalDiagGaussian(latent_size, context_encoder)

    # Construct flow model
    model = nf.ConditionalNormalizingFlow(q0, flows)

    return model
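
# Usage sketch (assumed shapes, not part of the original module): the returned
# nf.ConditionalNormalizingFlow scores a batch of latent vectors against a
# matching batch of context vectors.
#
#   flows = build_flows(latent_size=5, context_size=64)
#   z = torch.randn(8, 5)       # e.g. per-patch, per-sigma features
#   ctx = torch.randn(8, 64)    # e.g. positional context embeddings
#   log_px = flows.log_prob(z, context=ctx)   # -> shape (8,)
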
class ConditionalDiagGaussian(BaseDistribution):
    """
    Conditional multivariate Gaussian distribution with a diagonal
    covariance matrix. Its parameters are produced by a context encoder,
    where the context is the variable being conditioned on.
    """

    def __init__(self, shape, context_encoder):
        """Constructor

        Args:
            shape: Tuple with the shape of the data; if int, shape has one dimension
            context_encoder: Computes the mean and log standard deviation of the
                Gaussian; the mean is the first half of the last dimension of the
                encoder output, the log standard deviation the second half
        """
        super().__init__()
        if isinstance(shape, int):
            shape = (shape,)
        if isinstance(shape, list):
            shape = tuple(shape)
        self.shape = shape
        self.n_dim = len(shape)
        self.d = np.prod(shape)
        self.context_encoder = context_encoder

    def forward(self, num_samples=1, context=None):
        encoder_output = self.context_encoder(context)
        split_ind = encoder_output.shape[-1] // 2
        mean = encoder_output[..., :split_ind]
        log_scale = encoder_output[..., split_ind:]
        eps = torch.randn(
            (num_samples,) + self.shape, dtype=mean.dtype, device=mean.device
        )
        z = mean + torch.exp(log_scale) * eps
        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow(eps, 2), list(range(1, self.n_dim + 1))
        )
        return z, log_p

    def log_prob(self, z, context=None):
        encoder_output = self.context_encoder(context)
        split_ind = encoder_output.shape[-1] // 2
        mean = encoder_output[..., :split_ind]
        log_scale = encoder_output[..., split_ind:]
        log_p = -0.5 * self.d * np.log(2 * np.pi) - torch.sum(
            log_scale + 0.5 * torch.pow((z - mean) / torch.exp(log_scale), 2),
            list(range(1, self.n_dim + 1)),
        )
        return log_p
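
# Shape sketch (assumed numbers): with latent_size=5 the context encoder built in
# build_flows maps a (B, 64) context to a (B, 10) output; the first 5 entries are
# the mean and the last 5 the log standard deviation, so
#
#   q0 = ConditionalDiagGaussian(5, nn.Linear(64, 10))
#   z, log_q = q0(num_samples=8, context=torch.randn(8, 64))
#   # z: (8, 5), log_q: (8,)
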
def get_emb(sin_inp):
    """
    Gets a base embedding for one dimension with sin and cos intertwined
    """
    emb = torch.stack((sin_inp.sin(), sin_inp.cos()), dim=-1)
    return torch.flatten(emb, -2, -1)

class PositionalEncoding2D(nn.Module):
    def __init__(self, channels):
        """
        :param channels: The number of channels of the positional encoding to produce.
        """
        super(PositionalEncoding2D, self).__init__()
        self.org_channels = channels
        channels = int(np.ceil(channels / 4) * 2)
        self.channels = channels
        inv_freq = 1.0 / (10000 ** (torch.arange(0, channels, 2).float() / channels))
        self.register_buffer("inv_freq", inv_freq)
        self.register_buffer("cached_penc", None, persistent=False)

    def forward(self, tensor):
        """
        :param tensor: A 4d, channel-first tensor of size (batch_size, ch, x, y)
        :return: Positional encoding of size (x, y, 2 * self.channels), cached
            across calls with the same spatial size
        """
        if len(tensor.shape) != 4:
            raise RuntimeError("The input tensor has to be 4d!")

        if (
            self.cached_penc is not None
            and self.cached_penc.shape[:2] == tensor.shape[2:4]
        ):
            return self.cached_penc

        self.cached_penc = None
        batch_size, orig_ch, x, y = tensor.shape
        pos_x = torch.arange(x, device=tensor.device, dtype=self.inv_freq.dtype)
        pos_y = torch.arange(y, device=tensor.device, dtype=self.inv_freq.dtype)
        sin_inp_x = torch.einsum("i,j->ij", pos_x, self.inv_freq)
        sin_inp_y = torch.einsum("i,j->ij", pos_y, self.inv_freq)
        emb_x = get_emb(sin_inp_x).unsqueeze(1)
        emb_y = get_emb(sin_inp_y)
        emb = torch.zeros(
            (x, y, self.channels * 2),
            device=tensor.device,
            dtype=tensor.dtype,
        )
        emb[:, :, : self.channels] = emb_x
        emb[:, :, self.channels : 2 * self.channels] = emb_y
        self.cached_penc = emb
        return self.cached_penc
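
# Shape sketch (assumed sizes): for a channel-first feature map with spatial size
# 16 x 16 and an encoder built with channels=128, the module returns a single
# (16, 16, 128) grid of sin/cos features that is shared across the batch.
#
#   penc = PositionalEncoding2D(channels=128)
#   grid = penc(torch.randn(2, 5, 16, 16))   # -> (16, 16, 128)
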
class SpatialNormer(nn.Module):
    def __init__(
        self,
        in_channels,  # channels will be the number of sigma scales in the input
        kernel_size=3,
        stride=2,
        padding=1,
    ):
        """
        Computes a local L2 norm over each patch, separately for every sigma scale.
        Note that the convolution reduces the (score) channel dimension,
        so (b, num_sigmas, c, h, w) -> (b, num_sigmas, new_h, new_w).
        """
        super().__init__()
        self.conv = nn.Conv3d(
            in_channels,
            in_channels,
            kernel_size,
            # This is the real trick that ensures each
            # sigma dimension is normed separately
            groups=in_channels,
            stride=(1, stride, stride),
            padding=(0, padding, padding),
            bias=False,
        )
        self.conv.weight.data.fill_(1)  # all-ones weights
        self.conv.weight.requires_grad = False  # freeze weights

    def forward(self, x):
        return self.conv(x.square()).pow_(0.5).squeeze(2)
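
# Behaviour sketch (assumed sizes): each output entry is sqrt(sum of squares) over
# a kernel_size^3 neighbourhood, i.e. a local L2 norm of the per-sigma score patch.
# With the default kernel_size=3 the score-channel dimension collapses to 1 only
# when c == 3, which is what the squeeze(2) above relies on.
#
#   pooler = SpatialNormer(in_channels=5)
#   pooled = pooler(torch.randn(2, 5, 3, 32, 32))   # -> (2, 5, 16, 16)
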
class PatchFlow(torch.nn.Module):
    def __init__(
        self,
        input_size,
        patch_size=3,
        context_embedding_size=128,
        num_flows=4,
        num_blocks_per_flow=2,
        hidden_units=128,
    ):
        super().__init__()
        self.config = sanitize_locals(locals(), ignore_keys="input_size")
        num_sigmas, c, h, w = input_size
        self.local_pooler = SpatialNormer(
            in_channels=num_sigmas, kernel_size=patch_size
        )
        self.flows = build_flows(
            latent_size=num_sigmas,
            context_size=context_embedding_size,
            num_flows=num_flows,
            num_blocks_per_flow=num_blocks_per_flow,
            hidden_units=hidden_units,
        )
        self.position_encoding = PositionalEncoding2D(channels=context_embedding_size)

        # Cache the positional encodings for the pooled spatial size
        _, _, ctx_h, ctx_w = self.local_pooler(
            torch.empty((1, num_sigmas, c, h, w))
        ).shape
        self.position_encoding(torch.empty(1, 1, ctx_h, ctx_w))
        assert self.position_encoding.cached_penc.shape[-1] == context_embedding_size

    def init_weights(self):
        # Initialize weights with Xavier
        linear_modules = list(
            filter(lambda m: isinstance(m, nn.Linear), self.flows.modules())
        )
        total = len(linear_modules)
        for idx, m in enumerate(linear_modules):
            # Last layer gets initialized with zeros
            if idx == total - 1:
                nn.init.zeros_(m.weight.data)
            else:
                nn.init.xavier_uniform_(m.weight.data)
            if m.bias is not None:
                nn.init.zeros_(m.bias.data)

    def forward(self, x, chunk_size=32):
        b, s, c, h, w = x.shape
        x_norm = self.local_pooler(x)
        _, _, new_h, new_w = x_norm.shape
        context = self.position_encoding(x_norm)

        # num_patches x channels
        local_ctx = rearrange(context, "h w c -> (h w) c")
        patches = rearrange(x_norm, "b c h w -> (h w) b c")

        nchunks = (patches.shape[0] + chunk_size - 1) // chunk_size
        patches = patches.chunk(nchunks, dim=0)
        ctx_chunks = local_ctx.chunk(nchunks, dim=0)

        patch_logpx = []
        # gc = repeat(global_ctx, "b c -> (n b) c", n=self.patch_batch_size)
        for p, ctx in zip(patches, ctx_chunks):
            # num patches in chunk (<= chunk_size)
            n = p.shape[0]
            ctx = repeat(ctx, "n c -> (n b) c", b=b)
            p = rearrange(p, "n b c -> (n b) c")

            # Compute log densities for each patch
            logpx = self.flows.log_prob(p, context=ctx)
            logpx = rearrange(logpx, "(n b) -> n b", n=n, b=b)
            patch_logpx.append(logpx)

        # Convert back to image
        logpx = torch.cat(patch_logpx, dim=0)
        logpx = rearrange(logpx, "(h w) b -> b 1 h w", b=b, h=new_h, w=new_w)
        return logpx.contiguous()
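
    # Usage sketch (assumed sizes, not part of the original module): for score
    # tensors with 5 noise scales, 3 image channels and 32x32 resolution, the
    # model maps a batch of scores to a per-patch log-likelihood map at the
    # pooled resolution.
    #
    #   flow = PatchFlow(input_size=(5, 3, 32, 32))
    #   flow.init_weights()
    #   scores = torch.randn(2, 5, 3, 32, 32)
    #   logpx = flow(scores)   # -> (2, 1, 16, 16)
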
    @staticmethod
    def stochastic_step(
        scores, x_batch, flow_model, opt=None, train=False, n_patches=32, device="cpu"
    ):
        if train:
            flow_model.train()
            opt.zero_grad(set_to_none=True)
        else:
            flow_model.eval()

        patches, context = PatchFlow.get_random_patches(
            scores, x_batch, flow_model, n_patches
        )
        patch_feature = patches.to(device)
        context_vector = context.to(device)
        patch_feature = rearrange(patch_feature, "n b c -> (n b) c")
        context_vector = rearrange(context_vector, "n b c -> (n b) c")

        # global_pooled_image = flow_model.global_pooler(x_batch)
        # global_context = flow_model.global_attention(global_pooled_image)
        # gctx = repeat(global_context, "b c -> (n b) c", n=n_patches)
        # # Concatenate global context to local context
        # context_vector = torch.cat([context_vector, gctx], dim=1)

        # z, ldj = flow_model.flows.inverse_and_log_det(
        #     patch_feature,
        #     context=context_vector,
        # )

        loss = flow_model.flows.forward_kld(patch_feature, context_vector)
        loss *= n_patches

        if train:
            loss.backward()
            opt.step()

        return loss.item() / n_patches

    @staticmethod
    def get_random_patches(scores, x_batch, flow_model, n_patches):
        b = scores.shape[0]
        h = flow_model.local_pooler(scores)
        patches = rearrange(h, "b c h w -> (h w) b c")

        context = flow_model.position_encoding(h)
        context = rearrange(context, "h w c -> (h w) c")
        context = repeat(context, "n c -> n b c", b=b)

        # conserve gpu memory
        patches = patches.cpu()
        context = context.cpu()

        # Get random patches
        total_patches = patches.shape[0]
        shuffled_idx = torch.randperm(total_patches)
        rand_idx_batch = shuffled_idx[:n_patches]

        return patches[rand_idx_batch], context[rand_idx_batch]
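
# Training sketch (assumed shapes and hyperparameters, not part of the original
# module): stochastic_step subsamples n_patches random patch positions per call,
# so each update sees only a slice of the spatial grid.
#
#   flow = PatchFlow(input_size=(5, 3, 32, 32))
#   flow.init_weights()
#   opt = torch.optim.Adam(flow.flows.parameters(), lr=1e-4)
#   x_batch = torch.randn(4, 3, 32, 32)       # stand-in for the input images
#   scores = torch.randn(4, 5, 3, 32, 32)     # stand-in for score-network output
#   loss = PatchFlow.stochastic_step(
#       scores, x_batch, flow, opt=opt, train=True, n_patches=32, device="cpu"
#   )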