Implement col2im decomposition and fix im2col and add a few preconditions (#85541)
As per title.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85541
Approved by: https://github.com/jansel
This commit is contained in:
parent 1f38abb5d2
commit 787028cadb
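For context, every precondition added below revolves around the per-dimension sliding-block count 1 + (size + 2*padding - dilation*(kernel - 1) - 1) // stride. A minimal sketch of that arithmetic, with input and kernel values borrowed from the updated tests (illustrative only, not part of the commit):

import torch
import torch.nn as nn

def sliding_blocks(size, kernel, dilation=1, padding=0, stride=1):
    # Per-dimension count of sliding blocks, as used by the new checks.
    return 1 + (size + 2 * padding - dilation * (kernel - 1) - 1) // stride

# A (2, 2) spatial input with kernel_size=(2, 3) yields (1, 0) blocks,
# which the checks now reject with "its components must be at least one".
print(sliding_blocks(2, 2), sliding_blocks(2, 3))  # 1 0
unfold = nn.Unfold(kernel_size=(2, 3))
try:
    unfold(torch.randn(1, 2, 2, 2))
except RuntimeError as e:
    print(e)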
@@ -36,6 +36,13 @@ static inline void col2im_shape_check(
       dilation_height,
       " dilation_width: ",
       dilation_width);
+  TORCH_CHECK(
+      pad_width >= 0 && pad_height >= 0,
+      "padding should be non-negative, but got pad_height: ",
+      pad_height,
+      " pad_width: ",
+      pad_width);
+
   int64_t ndim = input.ndimension();
   // allow dim=0 only the batch dimension.
@@ -218,7 +225,7 @@ static inline void im2col_shape_check(
         output_height,
         ", ",
         output_width,
-        "), which is too small (non-positive).");
+        "), but its components must be at least one.");
   }
 }
@@ -1225,7 +1225,7 @@ TEST_F(ModulesTest, Unfold) {
       model(torch::randn({1, 2, 2, 2})),
       "Given input with spatial size (2, 2), kernel_size=(2, 3), "
       "dilation=(1, 1), padding=(0, 0), calculated shape of the array of "
-      "sliding blocks as (1, 0), which is too small (non-positive).");
+      "sliding blocks as (1, 0), but its components must be at least one.");
 }
 }
@@ -3931,6 +3931,8 @@ class TestFunctionalTracing(JitTestCase):
         "max_unpool1d": PROXY_ITERATED,
         "max_unpool2d": PROXY_ITERATED,
         "max_unpool3d": PROXY_ITERATED,
+        "fold": PROXY_ITERATED,
+        "unfold": PROXY_ITERATED,
 
         "adaptive_max_pool1d_with_indices": ARG_TYPE_MISMATCH,
         "fractional_max_pool2d_with_indices": ARG_TYPE_MISMATCH,
@@ -3955,7 +3957,6 @@ class TestFunctionalTracing(JitTestCase):
         "embedding": CONTROL_FLOW,
         "embedding_bag": CONTROL_FLOW,
         "feature_alpha_dropout": CONTROL_FLOW,
-        "fold": CONTROL_FLOW,
         "gaussian_nll_loss": CONTROL_FLOW,
         "glu": CONTROL_FLOW,
         "grid_sample": CONTROL_FLOW,
@@ -3992,7 +3993,6 @@ class TestFunctionalTracing(JitTestCase):
         "threshold": CONTROL_FLOW,
         "triplet_margin_loss": CONTROL_FLOW,
         "triplet_margin_with_distance_loss": CONTROL_FLOW,
-        "unfold": CONTROL_FLOW,
         "upsample": CONTROL_FLOW,
 
         "upsample_bilinear": INTERPOLATE_ARGS_CONFLICT,
@@ -693,7 +693,6 @@ meta_dispatch_expected_failures = {
     aten.bincount.default : {i64, i8, i32, i16, u8},
     aten.bucketize.Tensor : {f16, i8, f64, i64, bf16, f32, i32, i16, u8},
     aten.bucketize.Tensor_out : {f16, i8, f64, i64, bf16, f32, i32, i16, u8},
-    aten.col2im.default : {c64, f32, f64, c128},
     aten.equal.default : {c64, f16, i8, f64, c128, i64, bf16, f32, i32, b8, i16, u8},
     aten.frexp.Tensor : {bf16, f32, f16, f64},
     aten.grid_sampler_3d.default : {f32, f64},
@@ -10783,20 +10783,17 @@ tensor(..., device='meta', size=(1,), requires_grad=True)""")
         # input wrong dimension
 
         unfold = nn.Unfold(kernel_size=(2, 3))
 
-        with self.assertRaisesRegex(NotImplementedError, r"Only 4D input Tensors are supported"):
-            unfold(torch.randn(1, 5, 2))
-
         # calculated output shape is too small
-        with self.assertRaisesRegex(RuntimeError, r"too small \(non-positive\)"):
+        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
             unfold = nn.Unfold(kernel_size=(2, 3))
             unfold(torch.randn(1, 2, 2, 2))
 
-        with self.assertRaisesRegex(RuntimeError, r"too small \(non-positive\)"):
+        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
             unfold = nn.Unfold(kernel_size=(5, 3), padding=(1, 1))
             unfold(torch.randn(1, 2, 2, 3))
 
-        with self.assertRaisesRegex(RuntimeError, r"too small \(non-positive\)"):
+        with self.assertRaisesRegex(RuntimeError, r"its components must be at least one"):
             unfold = nn.Unfold(kernel_size=(1, 3), padding=(1, 1), dilation=(1, 2))
             unfold(torch.randn(1, 2, 2, 2))
 
@@ -2,6 +2,7 @@ import functools
 import operator
 import sys
 from enum import Enum
+from functools import partial, reduce
 from itertools import product
 from typing import Callable, cast, Iterable, List, Optional, Tuple
 
@@ -68,18 +69,18 @@ def type_casts(
     return inner
 
 
-compute_only_pw_cast_for_opmath = functools.partial(
+compute_only_pw_cast_for_opmath = partial(
     type_casts,
     type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
     compute_dtype_only=True,
 )
-pw_cast_for_opmath = functools.partial(
+pw_cast_for_opmath = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
 )
-reduction_complex_to_real = functools.partial(
+reduction_complex_to_real = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.COMPLEX_TO_FLOAT
 )
-pw_cast_for_int_to_real = functools.partial(
+pw_cast_for_int_to_real = partial(
     type_casts, type_promotion=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
 )
 
@@ -693,7 +694,28 @@ def _log_softmax_backward_data(
     return _cast_grad_to_input_dtype(grad_output, grad_input, input_dtype)
 
 
+def _im2col_col2im_indices_along_dim(
+    input_d, kernel_d, dilation_d, padding_d, stride_d, device
+):
+    """Utility function to implement im2col and col2im"""
+    blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
+
+    arange_kw = partial(torch.arange, dtype=torch.int64, device=device)
+
+    # Stride kernel over input and find starting indices along dim d
+    blocks_d_indices = arange_kw(0, blocks_d, stride_d).unsqueeze(0)
+
+    # Apply dilation on kernel and find its indices along dim d
+    kernel_grid = arange_kw(0, kernel_d * dilation_d, dilation_d).unsqueeze(-1)
+
+    # Broadcast and add kernel starting positions (indices) with
+    # kernel_grid along dim d, to get block indices along dim d
+    return blocks_d_indices + kernel_grid
+
+
 @register_decomposition(aten.im2col)
 @out_wrapper()
 @pw_cast_for_opmath
 def im2col(
     input: Tensor,
     kernel_size: List[int],
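For reference, a standalone sketch of the index pattern that the new _im2col_col2im_indices_along_dim helper builds along a single dimension (values are illustrative, not from the commit): one row per kernel tap, one column per block start.

import torch

input_d, kernel_d, dilation_d, padding_d, stride_d = 4, 2, 1, 0, 1
blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)  # 3 valid block starts

blocks_d_indices = torch.arange(0, blocks_d, stride_d, dtype=torch.int64).unsqueeze(0)  # [[0, 1, 2]]
kernel_grid = torch.arange(0, kernel_d * dilation_d, dilation_d, dtype=torch.int64).unsqueeze(-1)  # [[0], [1]]

# Broadcasting adds each kernel offset to each block start along dim d:
# tensor([[0, 1, 2],
#         [1, 2, 3]])
print(blocks_d_indices + kernel_grid)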
@@ -701,60 +723,175 @@ def im2col(
     padding: List[int],
     stride: List[int],
 ) -> Tensor:
-    utils.check(input.dim() == 4, lambda: "im2col(): only 4D input supported")
     utils.check(len(kernel_size) == 2, lambda: "im2col(): only 2D kernel supported")
     utils.check(len(dilation) == 2, lambda: "im2col(): only 2D dilation supported")
     utils.check(len(padding) == 2, lambda: "im2col(): only 2D padding supported")
     utils.check(len(stride) == 2, lambda: "im2col(): only 2D stride supported")
 
-    batch_dim = input.size(0)
-    channel_dim = input.size(1)
-    input_h = input.size(2)
-    input_w = input.size(3)
-
-    stride_h, stride_w = stride[0], stride[1]
-    padding_h, padding_w = padding[0], padding[1]
-    dilation_h, dilation_w = dilation[0], dilation[1]
-    kernel_h, kernel_w = kernel_size[0], kernel_size[1]
-
-    def _get_im2col_indices_along_dim(
-        input_d, kernel_d, dilation_d, padding_d, stride_d
-    ):
-        blocks_d = input_d + padding_d * 2 - dilation_d * (kernel_d - 1)
-
-        # Stride kernel over input and find starting indices along dim d
-        blocks_d_indices = torch.arange(
-            0, blocks_d, stride_d, dtype=torch.int64, device=input.device
-        ).unsqueeze(0)
-        num_blocks = (blocks_d - 1) // stride_d + 1
-
-        # Apply dilation on kernel and find its indices along dim d
-        kernel_grid = torch.arange(
-            0, kernel_d * dilation_d, dilation_d, dtype=torch.int64, device=input.device
-        ).unsqueeze(-1)
-
-        # Broadcast and add kernel starting positions (indices) with
-        # kernel_grid along dim d, to get block indices along dim d
-        block_mask = blocks_d_indices + kernel_grid
-
-        return block_mask, num_blocks
-
-    blocks_row_indices, num_blocks_row = _get_im2col_indices_along_dim(
-        input_h, kernel_h, dilation_h, padding_h, stride_h
-    )
-    blocks_col_indices, num_blocks_col = _get_im2col_indices_along_dim(
-        input_w, kernel_w, dilation_w, padding_w, stride_w
-    )
-
-    padded_input = F.pad(input, (padding_w, padding_w, padding_h, padding_h))
+    def check_positive(param, param_name, strict=True):
+        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
+        utils.check(
+            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
+        )
+
+    check_positive(kernel_size, "kernel_size")
+    check_positive(dilation, "dilation")
+    check_positive(padding, "padding", strict=False)
+    check_positive(stride, "stride")
+
+    shape = input.shape
+    ndim = len(shape)
+    utils.check(
+        ndim in (3, 4) and all(d != 0 for d in shape[-3:]),
+        lambda: "Expected 3D or 4D (batch mode) tensor for input with possible 0 batch size "
+        f"and non-zero dimensions, but got: {tuple(shape)}",
+    )
+    output_size = tuple(
+        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
+        for out, pad, dil, ker, st in zip(
+            shape[-2:], padding, dilation, kernel_size, stride
+        )
+    )
+    utils.check(
+        all(c > 0 for c in output_size),
+        lambda: f"Given an input with spacial size {tuple(shape[-2:])}, "
+        f"kernel_size={kernel_size}, dilation={dilation}, "
+        f"padding={padding}, stride={stride}, "
+        "the calculated shape of the array of sliding blocks "
+        f"is {output_size}, but its components must be at least one.",
+    )
+    batched_input = ndim == 4
+    if not batched_input:
+        input = input.unsqueeze(0)
+
+    batch_dim, channel_dim, input_h, input_w = input.shape
+
+    stride_h, stride_w = stride
+    padding_h, padding_w = padding
+    dilation_h, dilation_w = dilation
+    kernel_h, kernel_w = kernel_size
+
+    blocks_row_indices = _im2col_col2im_indices_along_dim(
+        input_h, kernel_h, dilation_h, padding_h, stride_h, input.device
+    )
+    blocks_col_indices = _im2col_col2im_indices_along_dim(
+        input_w, kernel_w, dilation_w, padding_w, stride_w, input.device
+    )
+
+    padded_input = F.pad(input, (padding_h, padding_h, padding_w, padding_w))
 
     blocks_row_indices = blocks_row_indices.unsqueeze(-1).unsqueeze(-1)
     output = padded_input[:, :, blocks_row_indices, blocks_col_indices]
     output = output.permute(0, 1, 2, 4, 3, 5)
-    return output.reshape(
+    num_blocks_row = blocks_row_indices.size(1)
+    num_blocks_col = blocks_col_indices.size(1)
+    output = output.reshape(
         batch_dim, channel_dim * kernel_h * kernel_w, num_blocks_row * num_blocks_col
     )
+
+    if not batched_input:
+        output = output.squeeze(0)
+    return output
+
+
+@register_decomposition(aten.col2im)
+@out_wrapper()
+@pw_cast_for_opmath
+def col2im(
+    input: Tensor,
+    output_size: List[int],
+    kernel_size: List[int],
+    dilation: List[int],
+    padding: List[int],
+    stride: List[int],
+) -> Tensor:
+    utils.check(len(output_size) == 2, lambda: "only 2D output_size supported")
+    utils.check(len(kernel_size) == 2, lambda: "only 2D kernel supported")
+    utils.check(len(dilation) == 2, lambda: "only 2D dilation supported")
+    utils.check(len(padding) == 2, lambda: "only 2D padding supported")
+    utils.check(len(stride) == 2, lambda: "only 2D stride supported")
+
+    def check_positive(param, param_name, strict=True):
+        cond = all(p > 0 for p in param) if strict else all(p >= 0 for p in param)
+        utils.check(
+            cond, lambda: f"{param_name} should be greater than zero, but got {param}"
+        )
+
+    check_positive(kernel_size, "kernel_size")
+    check_positive(dilation, "dilation")
+    check_positive(padding, "padding", strict=False)
+    check_positive(stride, "stride")
+    check_positive(output_size, "output_size")
+
+    shape = input.shape
+    ndim = len(shape)
+    utils.check(
+        ndim in (2, 3) and all(d != 0 for d in shape[-2:]),
+        lambda: "Expected 2D or 3D (batch mode) tensor for input with possible 0 batch size "
+        f"and non-zero dimensions, but got: {tuple(shape)}",
+    )
+    prod_kernel_size = kernel_size[0] * kernel_size[1]
+    utils.check(
+        shape[-2] % prod_kernel_size == 0,
+        lambda: "Expected size of input's first non-batch dimension to be divisible by the "
+        f"product of kernel_size, but got input.shape[-2] = {shape[-2]} and "
+        f"kernel_size={kernel_size}",
+    )
+    col = [
+        1 + (out + 2 * pad - dil * (ker - 1) - 1) // st
+        for out, pad, dil, ker, st in zip(
+            output_size, padding, dilation, kernel_size, stride
+        )
+    ]
+    L = col[0] * col[1]
+    utils.check(
+        shape[-1] == L,
+        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
+        f"dilation={dilation}, padding={padding}, stride={stride}, "
+        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
+    )
+    utils.check(
+        L > 0,
+        lambda: f"Given output_size={output_size}, kernel_size={kernel_size}, "
+        f"dilation={dilation}, padding={padding}, stride={stride}, "
+        f"expected input.size(-1) to be {L} but got {shape[-1]}.",
+    )
+    batched_input = ndim == 3
+    if not batched_input:
+        input = input.unsqueeze(0)
+
+    shape = input.shape
+
+    out_h, out_w = output_size
+    stride_h, stride_w = stride
+    padding_h, padding_w = padding
+    dilation_h, dilation_w = dilation
+    kernel_h, kernel_w = kernel_size
+
+    # col2im is defined as the backwards of im2col, so we differentiate its decomposition by hand
+    input = input.reshape([shape[0], shape[1] // prod_kernel_size] + kernel_size + col)
+    input = input.permute(0, 1, 2, 4, 3, 5)
+
+    indices_row = _im2col_col2im_indices_along_dim(
+        out_h, kernel_h, dilation_h, padding_h, stride_h, input.device
+    )
+    indices_row = _unsqueeze_to_dim(indices_row, 4)
+    indices_col = _im2col_col2im_indices_along_dim(
+        out_w, kernel_w, dilation_w, padding_w, stride_w, input.device
+    )
+
+    output_padded_size = [o + 2 * p for o, p in zip(output_size, padding)]
+    output = input.new_zeros(
+        [shape[0], shape[1] // prod_kernel_size] + output_padded_size
+    )
+    idx = (None, None, indices_row, indices_col)
+    output = torch.ops.aten.index_put(output, idx, input, accumulate=True)
+    output = F.pad(output, (-padding_h, -padding_h, -padding_w, -padding_w))
+
+    if not batched_input:
+        output = output.squeeze(0)
+    return output
 
 
 # TODO: the type annotations on arguments are not quite right
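At the user level, the rewritten im2col and the new col2im decompositions sit behind torch.nn.functional.unfold and fold, whose wrappers appear further down in this diff. A small sketch with illustrative shapes (not part of the commit), showing the column layout im2col produces and the accumulation that col2im performs:

import torch
import torch.nn.functional as F

# unfold produces (N, C * prod(kernel_size), L) columns, which is exactly the
# layout the new col2im preconditions verify on their input.
x = torch.randn(1, 2, 4, 5)
cols = F.unfold(x, kernel_size=(2, 3), dilation=(1, 1), padding=(0, 0), stride=(1, 1))
print(cols.shape)  # torch.Size([1, 12, 9]): C*kh*kw = 2*2*3 = 12, L = 3*3 = 9

# Folding back accumulates overlapping patches into the (4, 5) output, which
# mirrors the index_put(..., accumulate=True) step in the decomposition.
y = F.fold(cols, output_size=(4, 5), kernel_size=(2, 3))
print(y.shape)  # torch.Size([1, 2, 4, 5])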
@@ -1892,7 +2029,7 @@ def _upsample_cubic_interp1d(coeffs: TensorSequenceType, ts: Tensor) -> Tensor:
 
     # Need this instead of just sum() to keep mypy happy
     def _sum_tensors(ts: Iterable[Tensor]) -> Tensor:
-        return functools.reduce(torch.add, ts)
+        return reduce(torch.add, ts)
 
 
 @register_decomposition(aten.grid_sampler_2d)
@@ -2174,7 +2311,7 @@ def matmul(tensor1, tensor2):
         # This can happen in e.g. [3, 5, 0] @ [0, 0].
         sizes_1 = t1.shape
         output_shape = list(sizes_1[:-1])
-        folded_dim1 = functools.reduce(operator.mul, output_shape)
+        folded_dim1 = reduce(operator.mul, output_shape)
 
         # Readjust output_shape if we are multiplying by a matrix
         t2_is_matrix = t2.dim() == 2
@@ -4662,16 +4662,7 @@ def unfold(
         return handle_torch_function(
             unfold, (input,), input, kernel_size, dilation=dilation, padding=padding, stride=stride
         )
-    if input.dim() == 4:
-        msg = "{} must be int or 2-tuple for 4D input"
-        assert_int_or_pair(kernel_size, "kernel_size", msg)
-        assert_int_or_pair(dilation, "dilation", msg)
-        assert_int_or_pair(padding, "padding", msg)
-        assert_int_or_pair(stride, "stride", msg)
-
-        return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
-    else:
-        raise NotImplementedError("Input Error: Only 4D input Tensors are supported (got {}D)".format(input.dim()))
+    return torch._C._nn.im2col(input, _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride))
 
 
 def fold(
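A brief sketch of the practical effect of this simplification, under the assumption that the unbatched path is accepted end to end after this change (shapes are illustrative, not from the commit): a 3D (C, H, W) input is no longer rejected in the Python wrapper with NotImplementedError but is forwarded to im2col, whose own shape checks then apply.

import torch
import torch.nn.functional as F

# Assumption: with the Python-level 4D check removed, an unbatched (C, H, W)
# input reaches im2col directly instead of raising NotImplementedError.
patches = F.unfold(torch.randn(2, 4, 5), kernel_size=(2, 3))
print(patches.shape)  # expected: torch.Size([12, 9]), i.e. (C*kh*kw, L)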
@@ -4693,20 +4684,9 @@ def fold(
         return handle_torch_function(
             fold, (input,), input, output_size, kernel_size, dilation=dilation, padding=padding, stride=stride
         )
-    if input.dim() == 3 or input.dim() == 2:
-        msg = "{} must be int or 2-tuple for 3D input"
-        assert_int_or_pair(output_size, "output_size", msg)
-        assert_int_or_pair(kernel_size, "kernel_size", msg)
-        assert_int_or_pair(dilation, "dilation", msg)
-        assert_int_or_pair(padding, "padding", msg)
-        assert_int_or_pair(stride, "stride", msg)
-
-        return torch._C._nn.col2im(
-            input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
-        )
-    else:
-        raise NotImplementedError("Input Error: Only unbatched (2D) or batched (3D) input Tensors"
-                                  f"are supported (got {input.dim()}D)")
+    return torch._C._nn.col2im(
+        input, _pair(output_size), _pair(kernel_size), _pair(dilation), _pair(padding), _pair(stride)
+    )
 
 #
 # multihead attention