gramt-mono / patching_utils.py
GokseninYuksel's picture
Upload model
24e9ac9 verified
import torch
from torch import nn
def generate_patches(input, fstride, tstride, fshape, tshape):
    r"""Slice a batched image tensor into sliding-window patches.

    A window of size ``(fshape, tshape)`` is moved over the last two axes of
    ``input`` with step ``(fstride, tstride)``; the resulting windows are
    stacked along a new patch axis in row-major order (first spatial axis
    varies slowest).

    Args:
        input: tensor of shape :math:`(B, C, H, W)`.
        fstride: window step along axis 2.
        tstride: window step along axis 3.
        fshape: window size along axis 2.
        tshape: window size along axis 3.

    Returns:
        A contiguous tensor of shape :math:`(B, N, C, fshape, tshape)`, where
        :math:`N` is the number of window positions.

    Examples:
        >>> x = torch.arange(9.).view(1, 1, 3, 3)
        >>> patches = generate_patches(x, fstride=1, tstride=1, fshape=2, tshape=3)
        >>> patches.shape
        torch.Size([1, 2, 1, 2, 3])
        >>> patches[:, -1]
        tensor([[[[3., 4., 5.],
                  [6., 7., 8.]]]])
    """
    n_batch, n_channels = input.shape[:2]
    spatial_axes = tuple(range(2, input.dim()))
    window_specs = ((fshape, fstride), (tshape, tstride))

    # Each unfold turns one spatial axis into a "window position" axis and
    # appends a trailing axis that holds the window contents.
    windows = input
    for axis, (size, step) in zip(spatial_axes, window_specs):
        windows = windows.unfold(axis, size, step)

    # Move the window-position axes in front of the channel axis so that
    # flattening them produces the patch axis N.
    offset = len(spatial_axes)
    content_axes = tuple(axis + offset for axis in spatial_axes)
    axis_order = (0, *spatial_axes, 1, *content_axes)
    windows = windows.permute(*axis_order).contiguous()

    return windows.view(n_batch, -1, n_channels, fshape, tshape)
def combine_patches(
    patches,
    original_size,
    fstride,
    tstride,
    fshape,
    tshape,
    eps: float = 1e-8,
):
    r"""Reassemble an image tensor from (possibly overlapping) patches.

    Patches are summed back into place with ``torch.nn.functional.fold`` and
    the sum is divided by the number of patches that cover each pixel, so
    overlapping regions are averaged rather than accumulated.

    Args:
        patches: patch tensor of shape :math:`(B, N, C, fshape, tshape)`,
            with the patch axis ordered row-major over window positions.
        original_size: ``(H, W)`` of the tensor to rebuild.
        fstride: window step along the first spatial axis.
        tstride: window step along the second spatial axis.
        fshape: window size along the first spatial axis.
        tshape: window size along the second spatial axis.
        eps: added to the per-pixel coverage count before dividing, so that
            pixels covered by no patch do not divide by zero.

    Returns:
        Tensor of shape :math:`(B, C, H, W)` with the dtype of ``patches``.
        Non-floating inputs are processed in ``float32`` and cast back at the
        end.

    Raises:
        ValueError: if ``patches`` is not 5-dimensional.

    Example:
        >>> x = torch.arange(16.).view(1, 1, 4, 4)
        >>> patches = generate_patches(x, 2, 2, 2, 2)
        >>> torch.equal(combine_patches(patches, (4, 4), 2, 2, 2, 2), x)
        True

    .. note::
        This function is meant to be used together with
        :func:`generate_patches`, called with the same stride/shape values.
    """
    if patches.ndim != 5:
        raise ValueError(
            f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}"
        )

    functional = torch.nn.functional
    geometry = {
        "kernel_size": (fshape, tshape),
        "stride": (fstride, tstride),
    }
    output_size = (original_size[0], original_size[1])

    # fold/unfold need a floating dtype; remember the caller's dtype so the
    # result can be cast back afterwards.
    source_dtype = patches.dtype
    needs_cast = not torch.is_floating_point(patches)
    if needs_cast:
        patches = patches.float()

    # (B, N, C, h, w) -> (B, C*h*w, N): the column layout fold() expects.
    batch_size, num_patches = patches.shape[0], patches.shape[1]
    columns = patches.permute(0, 2, 3, 4, 1).reshape(batch_size, -1, num_patches)

    # Coverage map: how many patches touch each pixel. It depends only on the
    # patch geometry, so compute it once at 1x1xHxW and let the division
    # broadcast over batch and channel.
    ones = torch.ones(
        1,
        1,
        *output_size,
        device=patches.device,
        dtype=patches.dtype,
    )
    coverage = functional.fold(
        input=functional.unfold(ones, **geometry),
        output_size=output_size,
        **geometry,
    )

    # Sum of all patch contributions per pixel (overlaps are accumulated).
    summed = functional.fold(input=columns, output_size=output_size, **geometry)

    # Remove the saturation effect due to multiple summations.
    restored_tensor = summed / (coverage + eps)
    if needs_cast:
        restored_tensor = restored_tensor.to(source_dtype)
    return restored_tensor
# get the shape of intermediate representation.
def _conv_output_length(size, kernel, stride, axis_name):
    """Return the output length of an unpadded, undilated convolution on one axis."""
    if kernel <= 0 or stride <= 0:
        # RuntimeError (not ValueError) is kept on purpose: it is the type the
        # previous convolution-based probe raised for invalid geometry.
        raise RuntimeError(
            f"{axis_name}: kernel size and stride must be positive, "
            f"got kernel={kernel}, stride={stride}"
        )
    if kernel > size:
        raise RuntimeError(
            f"{axis_name}: kernel size ({kernel}) can't be greater than "
            f"input size ({size})"
        )
    return int((size - kernel) // stride + 1)


def get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape):
    """Return the patch-grid size ``(f_dim, t_dim)`` for a patching setup.

    The result is the spatial output shape that a 2-D convolution with kernel
    ``(fshape, tshape)``, stride ``(fstride, tstride)``, no padding and no
    dilation produces on an ``input_fdim x input_tdim`` input, i.e.
    ``(size - kernel) // stride + 1`` per axis.

    The value is computed arithmetically: no tensor or layer is created, so
    the call is cheap and does not advance the global torch random state.

    Args:
        fstride: stride along the first spatial axis.
        tstride: stride along the second spatial axis.
        input_fdim: input length along the first spatial axis.
        input_tdim: input length along the second spatial axis.
        fshape: kernel size along the first spatial axis.
        tshape: kernel size along the second spatial axis.

    Returns:
        Tuple ``(f_dim, t_dim)`` of Python ints.

    Raises:
        RuntimeError: if a kernel size or stride is not positive, or a kernel
            is larger than the corresponding input dimension.
    """
    f_dim = _conv_output_length(input_fdim, fshape, fstride, "frequency axis")
    t_dim = _conv_output_length(input_tdim, tshape, tstride, "time axis")
    return f_dim, t_dim