Add decomposition for unsqueeze_copy (#130942)

* Extracted from #128416
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130942
Approved by: https://github.com/peterbell10
This commit is contained in:
Tom Ritchford 2024-07-29 17:32:06 +00:00 committed by PyTorch MergeBot
parent 3c1562158e
commit bdf5a6dca9
9 changed files with 41 additions and 2 deletions

View File

@ -1316,8 +1316,6 @@ aten::unique_dim_consecutive.out
aten::unsafe_split.Tensor_out
aten::unsafe_split_with_sizes.out
aten::unsqueeze_
aten::unsqueeze_copy
aten::unsqueeze_copy.out
aten::upsample_bicubic2d_backward
aten::upsample_bicubic2d_backward.grad_input
aten::upsample_bilinear2d_backward

View File

@ -1429,6 +1429,7 @@ class TestOperators(TestCase):
xfail("masked.cumprod", ""),
xfail("renorm"), # hit vmap fallback, which is disabled
xfail("t_copy"),
xfail("unsqueeze_copy"),
}
),
)
@ -1566,6 +1567,7 @@ class TestOperators(TestCase):
"index_fill"
), # aten::_unique hit the vmap fallback which is currently disabled
xfail("t_copy"),
xfail("unsqueeze_copy"),
}
),
)

View File

@ -4363,6 +4363,7 @@ class TestVmapOperatorsOpInfo(TestCase):
xfail("as_strided"),
xfail("as_strided_copy"),
xfail("t_copy"),
xfail("unsqueeze_copy"),
xfail("istft"),
xfail("nonzero"),
xfail("nn.functional.fractional_max_pool2d"),

View File

@ -1263,6 +1263,11 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
),
xfail(
"unsqueeze_copy",
reason="OnnxExporterError: Failed to export model",
dtypes=(torch.int8, torch.uint8, torch.int16),
),
xfail(
"where",
dtypes=onnx_test_common.BOOL_TYPES,

View File

@ -352,6 +352,7 @@ def mps_ops_modifier(ops):
'unsafe_chunk',
'unsafe_split',
'unsqueeze',
'unsqueeze_copy',
'view_as',
'view_as_real',
'view',

View File

@ -200,6 +200,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
"permute",
"squeeze",
"unsqueeze",
"unsqueeze_copy",
"resize",
"resize_as",
"tril",

View File

@ -460,6 +460,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
aten._unsafe_masked_index_put_accumulate,
aten.unsafe_split.Tensor,
aten.unsafe_split_with_sizes,
aten.unsqueeze_copy,
aten._unsafe_view,
aten.upsample_linear1d,
aten.upsample_bilinear2d,

View File

@ -294,6 +294,7 @@ __all__ = [
"unfold",
"unfold_copy",
"unsqueeze",
"unsqueeze_copy",
"view",
"view_as",
"view_copy",
@ -6321,6 +6322,7 @@ expand_copy = _make_copy_from_view(aten.expand)
# no sparse support. See narrow_copy_sparse in core.
narrow_copy = _make_copy_from_view(aten.narrow)
t_copy = _make_copy_from_view(aten.t)
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
view_copy = _make_copy_from_view(aten.view)

View File

@ -19653,6 +19653,29 @@ op_db: List[OpInfo] = [
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
sample_inputs_func=sample_unsqueeze),
OpInfo('unsqueeze_copy',
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
supports_out=True,
supports_forward_ad=True,
supports_fwgrad_bwgrad=True,
# See https://github.com/pytorch/pytorch/pull/78358
check_batched_forward_grad=False,
# vmap does not support inplace views
check_inplace_batched_forward_grad=False,
assert_jit_shape_analysis=True,
assert_autodiffed=True,
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
sample_inputs_func=sample_unsqueeze,
skips=(
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
DecorateInfo(
unittest.expectedFailure,
'TestJit',
'test_variant_consistency_jit',
dtypes=(torch.float32,),
),
)),
BinaryUfuncInfo('xlogy',
aliases=('special.xlogy',),
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
@ -23946,6 +23969,11 @@ python_ref_db = [
"_refs.unsqueeze",
torch_opinfo_name="unsqueeze",
),
PythonRefInfo(
"_refs.unsqueeze_copy",
torch_opinfo_name="unsqueeze_copy",
supports_out=True,
),
PythonRefInfo(
"_refs.view",
torch_opinfo_name="view",