|
|
import torch |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
def generate_patches(input, fstride, tstride, fshape, tshape):
    r"""Extract (possibly overlapping) 2-D patches from a 4-D tensor and stack them.

    The last two dimensions of ``input`` are tiled without padding by windows of
    size ``(fshape, tshape)`` taken every ``(fstride, tstride)`` elements.
    Patches are ordered row-major over the window grid: the window index along
    dimension 2 varies slowest, the one along dimension 3 fastest.

    Args:
        input: tensor with shape :math:`(B, C, H, W)`.
        fstride: step between consecutive windows along dimension 2 (``H``).
        tstride: step between consecutive windows along dimension 3 (``W``).
        fshape: window size along dimension 2.
        tshape: window size along dimension 3.

    Returns:
        Tensor of patches with shape :math:`(B, N, C, fshape, tshape)`, where
        :math:`N = N_H \cdot N_W`,
        :math:`N_H = \lfloor (H - fshape) / fstride \rfloor + 1` and
        :math:`N_W = \lfloor (W - tshape) / tstride \rfloor + 1`.

    Raises:
        ValueError: if ``input`` is not 4-dimensional.

    Examples:
        >>> input = torch.arange(9.).view(1, 1, 3, 3)
        >>> patches = generate_patches(input, 1, 1, 2, 3)
        >>> patches.shape
        torch.Size([1, 2, 1, 2, 3])
        >>> patches[:, -1]
        tensor([[[[3., 4., 5.],
                  [6., 7., 8.]]]])
    """
    # Only the two trailing dims of a BxCxHxW tensor are patched; any other rank
    # previously produced a garbage reshape (3-D) or a permute error (5-D+).
    if input.dim() != 4:
        raise ValueError(
            f"Invalid input shape, we expect BxCxHxW. Got: {input.shape}"
        )

    batch_size, num_channels = input.shape[:2]

    # Tensor.unfold appends the window as a new trailing dim, so after both calls
    # the layout is (B, C, N_H, N_W, fshape, tshape).
    patches = input.unfold(2, fshape, fstride).unfold(3, tshape, tstride)

    # Move the window-grid dims in front of the channels:
    # (B, N_H, N_W, C, fshape, tshape), then merge N_H and N_W into N.
    patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
    return patches.view(batch_size, -1, num_channels, fshape, tshape)
|
|
|
|
|
|
|
|
def combine_patches(
    patches,
    original_size,
    fstride,
    tstride,
    fshape,
    tshape,
    eps: float = 1e-8,
):
    r"""Restore a 4-D tensor from stacked patches, averaging overlapping regions.

    Inverts the row-major patch layout produced by ``generate_patches`` with the
    same strides and patch sizes. Each patch is summed back into its window, and
    every output element is divided by the number of windows covering it (plus
    ``eps``). Output elements covered by no window come out as zero.

    Args:
        patches: tensor with shape :math:`(B, N, C, fshape, tshape)`.
        original_size: spatial size ``(H, W)`` of the tensor to restore.
        fstride: step between windows along ``H``.
        tstride: step between windows along ``W``.
        fshape: window size along ``H``.
        tshape: window size along ``W``.
        eps: added to the per-element window count before dividing.

    Return:
        Tensor with shape :math:`(B, C, H, W)` and the same dtype as ``patches``.
        Inputs that are not float32/float64 (integers, float16, bfloat16, ...)
        are processed in float32 and cast back. For integer dtypes the cast
        truncates toward zero.

    Raises:
        ValueError: if ``patches`` is not 5-D, or if its last two dimensions
            differ from ``(fshape, tshape)``.

    Example:
        >>> x = torch.arange(16).view(1, 1, 4, 4)
        >>> patches = generate_patches(x, 2, 2, 2, 2)
        >>> combine_patches(patches, (4, 4), 2, 2, 2, 2)
        tensor([[[[ 0,  1,  2,  3],
                  [ 4,  5,  6,  7],
                  [ 8,  9, 10, 11],
                  [12, 13, 14, 15]]]])

    .. note::
        This function is supposed to be used in conjunction with :func:`generate_patches`.
    """
    if patches.ndim != 5:
        raise ValueError(
            f"Invalid input shape, we expect BxNxCxHxW. Got: {patches.shape}"
        )
    # A mismatched patch size can otherwise be silently reinterpreted by fold
    # (it only checks that C*h*w is divisible by fshape*tshape).
    if patches.shape[3] != fshape or patches.shape[4] != tshape:
        raise ValueError(
            f"Invalid patch size, we expect {fshape}x{tshape}. "
            f"Got: {tuple(patches.shape[3:])}"
        )

    orig_dtype = patches.dtype
    # Integers cannot be averaged, and in float16 the default eps (1e-8)
    # underflows to 0, which turns uncovered elements into 0/0 = NaN.
    # So run fold and the normalization in float32 unless the input is already
    # float32/float64.
    if orig_dtype in (torch.float32, torch.float64):
        work_dtype = orig_dtype
    else:
        work_dtype = torch.float32

    batch_size, num_patches, num_channels = patches.shape[:3]
    restored_size = (original_size[0], original_size[1])
    kernel_size = (fshape, tshape)
    stride = (fstride, tstride)
    fold = torch.nn.functional.fold

    # (B, N, C, fshape, tshape) -> (B, C * fshape * tshape, N): the column
    # layout fold expects, with the patch index along the last dim.
    columns = (
        patches.to(work_dtype)
        .permute(0, 2, 3, 4, 1)
        .reshape(batch_size, -1, num_patches)
    )

    # Per-element count of covering windows: unfold a tensor of ones with the
    # same geometry and fold it back (overlaps are summed by fold).
    ones = torch.ones(
        batch_size,
        num_channels,
        *restored_size,
        device=patches.device,
        dtype=work_dtype,
    )
    norm_map = fold(
        torch.nn.functional.unfold(ones, kernel_size=kernel_size, stride=stride),
        output_size=restored_size,
        kernel_size=kernel_size,
        stride=stride,
    )

    # Sum of every patch's contribution at each output element.
    summed = fold(
        columns,
        output_size=restored_size,
        kernel_size=kernel_size,
        stride=stride,
    )

    restored_tensor = summed / (norm_map + eps)
    if work_dtype != orig_dtype:
        restored_tensor = restored_tensor.to(orig_dtype)
    return restored_tensor
|
|
|
|
|
|
|
|
|
|
|
def get_shape(fstride, tstride, input_fdim, input_tdim, fshape, tshape):
    """Return the ``(f_dim, t_dim)`` grid size of un-padded 2-D patching.

    This is the spatial output size of
    ``nn.Conv2d(kernel_size=(fshape, tshape), stride=(fstride, tstride))``
    (no padding, no dilation) applied to an ``input_fdim x input_tdim`` input,
    i.e. ``(size - kernel) // stride + 1`` per axis. It is computed
    arithmetically instead of by running a randomly initialized convolution.
    That avoids the allocation and, more importantly, leaves the global torch
    RNG state untouched.

    Args:
        fstride: stride along the first spatial axis.
        tstride: stride along the second spatial axis.
        input_fdim: input size along the first spatial axis.
        input_tdim: input size along the second spatial axis.
        fshape: kernel/patch size along the first spatial axis.
        tshape: kernel/patch size along the second spatial axis.

    Returns:
        Tuple ``(f_dim, t_dim)`` of ints.

    Raises:
        RuntimeError: if a stride or kernel size is not positive, or a kernel is
            larger than the input along its axis. (RuntimeError matches what
            ``nn.Conv2d`` raised in the previous implementation.)
    """

    def _out_size(size, kernel, stride, axis):
        # Output length of a 1-D valid (un-padded, dilation 1) sliding window.
        if stride <= 0:
            raise RuntimeError(f"non-positive stride along {axis} is not supported: {stride}")
        if kernel <= 0:
            raise RuntimeError(f"non-positive kernel size along {axis} is not supported: {kernel}")
        if kernel > size:
            raise RuntimeError(
                f"Kernel size along {axis} ({kernel}) can't be greater than "
                f"actual input size ({size})"
            )
        return (size - kernel) // stride + 1

    f_dim = _out_size(input_fdim, fshape, fstride, "f")
    t_dim = _out_size(input_tdim, tshape, tstride, "t")
    return f_dim, t_dim
|
|
|