mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
dynamic shape support for interpolate(antialias=True) backward (#141198)
Fixes https://github.com/pytorch/pytorch/issues/141187 Pull Request resolved: https://github.com/pytorch/pytorch/pull/141198 Approved by: https://github.com/ezyang, https://github.com/Chillee ghstack dependencies: #141161
This commit is contained in:
parent
4831f89790
commit
d7f45fc575
|
|
@ -6501,9 +6501,6 @@ symbolic_aot_autograd_failures = {
|
|||
"nn.functional.nll_loss", ""
|
||||
), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail("trace", ""), # Cannot call sizes() on tensor with symbolic sizes/strides
|
||||
xfail(
|
||||
"_upsample_bilinear2d_aa"
|
||||
), # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
|
||||
decorate(
|
||||
"linalg.householder_product",
|
||||
decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),
|
||||
|
|
|
|||
|
|
@ -1954,7 +1954,6 @@ class FakeTensorDispatchCache(TestCase):
|
|||
extract_tensor_metadata(res4),
|
||||
)
|
||||
|
||||
|
||||
@unittest.skipIf(not RUN_CUDA, "requires cuda")
|
||||
def test_wrapper_tensor_subclass_different_device(self):
|
||||
class DifferentDeviceTensor(torch.Tensor):
|
||||
|
|
@ -2007,6 +2006,29 @@ class FakeTensorDispatchCache(TestCase):
|
|||
assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
|
||||
self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)
|
||||
|
||||
def test__upsample_bilinear2d_aa_backward_dynamic_shapes(self):
|
||||
def f(x):
|
||||
return torch.nn.functional.interpolate(
|
||||
x,
|
||||
size=[256, 256],
|
||||
mode='bilinear',
|
||||
align_corners=False,
|
||||
antialias=True,
|
||||
)
|
||||
|
||||
shape_env = ShapeEnv()
|
||||
fake_m = FakeTensorMode(shape_env=shape_env)
|
||||
x = fake_m.from_tensor(
|
||||
torch.randn(1, 3, 2005, 1920, requires_grad=True),
|
||||
symbolic_context=StatelessSymbolicContext(
|
||||
dynamic_sizes=[DimDynamic.STATIC, DimDynamic.STATIC, DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
|
||||
constraint_sizes=[None, None, None, None]
|
||||
),
|
||||
)
|
||||
with fake_m, enable_python_dispatcher():
|
||||
out = f(x)
|
||||
out.sum().backward()
|
||||
self.assertEqual(x.shape, x.grad.shape)
|
||||
|
||||
def test_cache_tuple_outputs(self):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -6535,6 +6535,34 @@ def meta_upsample_bimode2d_aa(
|
|||
)
|
||||
|
||||
|
||||
@register_meta([aten._upsample_bilinear2d_aa_backward.default])
|
||||
def meta_upsample_bimode2d_aa_backward(
|
||||
grad_output,
|
||||
output_size,
|
||||
input_size,
|
||||
align_corners,
|
||||
scales_h=None,
|
||||
scales_w=None,
|
||||
):
|
||||
full_output_size = upsample_common_check(
|
||||
input_size, output_size, num_spatial_dims=2
|
||||
)
|
||||
torch._check(
|
||||
grad_output.ndim == 4,
|
||||
lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
|
||||
)
|
||||
for i in range(4):
|
||||
torch._check(
|
||||
grad_output.shape[i] == full_output_size[i],
|
||||
lambda: f"""
|
||||
Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]}
|
||||
but got grad_output_size({i}) = {grad_output.size(i)}""",
|
||||
)
|
||||
return grad_output.new_empty(input_size).to(
|
||||
memory_format=utils.suggest_memory_format(grad_output)
|
||||
)
|
||||
|
||||
|
||||
# From aten/src/ATen/native/cuda/AmpKernels.cu
|
||||
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
|
||||
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
|
||||
|
|
|
|||
|
|
@ -15660,7 +15660,6 @@ op_db: List[OpInfo] = [
|
|||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
|
||||
DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
|
||||
)),
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user