dynamic shape support for interpolate(antialias=True) backward (#141198)

Registers a Python meta kernel for aten._upsample_bilinear2d_aa_backward so that the backward of antialiased bilinear interpolation can be traced under FakeTensorMode with symbolic shapes; the corresponding xfails in the AOTAutograd and OpInfo test suites are removed.

Fixes https://github.com/pytorch/pytorch/issues/141187

Pull Request resolved: https://github.com/pytorch/pytorch/pull/141198
Approved by: https://github.com/ezyang, https://github.com/Chillee
ghstack dependencies: #141161
Brian Hirsh, 2025-01-15 11:11:32 -08:00, committed by PyTorch MergeBot
parent 4831f89790
commit d7f45fc575
4 changed files with 51 additions and 5 deletions
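
A minimal sketch of the user-facing scenario this fixes, assuming a PyTorch build that includes this change; the function, input sizes, and compile flags below are illustrative rather than taken from the PR:

import torch

def f(x):
    # Antialiased bilinear downsample whose backward previously failed to
    # trace with dynamic shapes.
    return torch.nn.functional.interpolate(
        x, size=[256, 256], mode="bilinear", align_corners=False, antialias=True
    )

compiled_f = torch.compile(f, dynamic=True)
x = torch.randn(1, 3, 2005, 1920, requires_grad=True)
compiled_f(x).sum().backward()
assert x.grad.shape == x.shape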


@@ -6501,9 +6501,6 @@ symbolic_aot_autograd_failures = {
        "nn.functional.nll_loss", ""
    ),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail("trace", ""),  # Cannot call sizes() on tensor with symbolic sizes/strides
    xfail(
        "_upsample_bilinear2d_aa"
    ),  # RuntimeError: isIntList() INTERNAL ASSERT FAILED Expected IntList but got GenericList
    decorate(
        "linalg.householder_product",
        decorator=unittest.skipIf(IS_MACOS and IS_X86, "flaky"),


@@ -1954,7 +1954,6 @@ class FakeTensorDispatchCache(TestCase):
            extract_tensor_metadata(res4),
        )

    @unittest.skipIf(not RUN_CUDA, "requires cuda")
    def test_wrapper_tensor_subclass_different_device(self):
        class DifferentDeviceTensor(torch.Tensor):
@@ -2007,6 +2006,29 @@ class FakeTensorDispatchCache(TestCase):
        assert isinstance(fake_wrapped_a, DifferentDeviceTensor)
        self.assertFalse(fake_wrapped_a.inner_tensor.is_cpu)

    def test__upsample_bilinear2d_aa_backward_dynamic_shapes(self):
        def f(x):
            return torch.nn.functional.interpolate(
                x,
                size=[256, 256],
                mode='bilinear',
                align_corners=False,
                antialias=True,
            )

        shape_env = ShapeEnv()
        fake_m = FakeTensorMode(shape_env=shape_env)
        x = fake_m.from_tensor(
            torch.randn(1, 3, 2005, 1920, requires_grad=True),
            symbolic_context=StatelessSymbolicContext(
                dynamic_sizes=[DimDynamic.STATIC, DimDynamic.STATIC, DimDynamic.DYNAMIC, DimDynamic.DYNAMIC],
                constraint_sizes=[None, None, None, None]
            ),
        )
        with fake_m, enable_python_dispatcher():
            out = f(x)
            out.sum().backward()
        self.assertEqual(x.shape, x.grad.shape)

    def test_cache_tuple_outputs(self):
        """


@@ -6535,6 +6535,34 @@ def meta_upsample_bimode2d_aa(
    )


@register_meta([aten._upsample_bilinear2d_aa_backward.default])
def meta_upsample_bimode2d_aa_backward(
    grad_output,
    output_size,
    input_size,
    align_corners,
    scales_h=None,
    scales_w=None,
):
    full_output_size = upsample_common_check(
        input_size, output_size, num_spatial_dims=2
    )
    torch._check(
        grad_output.ndim == 4,
        lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
    )
    for i in range(4):
        torch._check(
            grad_output.shape[i] == full_output_size[i],
            lambda: f"""
Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]}
but got grad_output_size({i}) = {grad_output.size(i)}""",
        )
    return grad_output.new_empty(input_size).to(
        memory_format=utils.suggest_memory_format(grad_output)
    )


# From aten/src/ATen/native/cuda/AmpKernels.cu
@register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
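
For context, a meta kernel like the one above only propagates metadata (shape, dtype, memory format) and never touches tensor data. A rough, illustrative sketch of that role in general, using the "meta" device with sizes borrowed from the test; the call site and sizes here are an illustration, not part of the commit:

import torch

# grad_output lives on the "meta" device: it has shape/dtype but no storage.
grad_output = torch.empty(1, 3, 256, 256, device="meta")

# The meta kernel computes only the metadata of grad_input (its shape matches
# input_size) without running the real backward computation.
grad_input = torch.ops.aten._upsample_bilinear2d_aa_backward(
    grad_output,
    output_size=[256, 256],
    input_size=[1, 3, 2005, 1920],
    align_corners=False,
)
print(grad_input.shape)  # torch.Size([1, 3, 2005, 1920])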


@@ -15660,7 +15660,6 @@ op_db: List[OpInfo] = [
        skips=(
            DecorateInfo(unittest.expectedFailure, 'TestJit', 'test_variant_consistency_jit'),
            DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
            DecorateInfo(unittest.expectedFailure, 'TestEagerFusionOpInfo', 'test_aot_autograd_symbolic_exhaustive'),
            DecorateInfo(unittest.expectedFailure, 'TestInductorOpInfo', 'test_comprehensive'),
            DecorateInfo(unittest.expectedFailure, 'TestMathBits', 'test_neg_view'),
        )),