mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Add decomposition for unsqueeze_copy (#130942)
* Extracted from #128416 Pull Request resolved: https://github.com/pytorch/pytorch/pull/130942 Approved by: https://github.com/peterbell10
This commit is contained in:
parent
3c1562158e
commit
bdf5a6dca9
|
|
@ -1316,8 +1316,6 @@ aten::unique_dim_consecutive.out
|
|||
aten::unsafe_split.Tensor_out
|
||||
aten::unsafe_split_with_sizes.out
|
||||
aten::unsqueeze_
|
||||
aten::unsqueeze_copy
|
||||
aten::unsqueeze_copy.out
|
||||
aten::upsample_bicubic2d_backward
|
||||
aten::upsample_bicubic2d_backward.grad_input
|
||||
aten::upsample_bilinear2d_backward
|
||||
|
|
|
|||
|
|
@ -1429,6 +1429,7 @@ class TestOperators(TestCase):
|
|||
xfail("masked.cumprod", ""),
|
||||
xfail("renorm"), # hit vmap fallback, which is disabled
|
||||
xfail("t_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
|
@ -1566,6 +1567,7 @@ class TestOperators(TestCase):
|
|||
"index_fill"
|
||||
), # aten::_unique hit the vmap fallback which is currently disabled
|
||||
xfail("t_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
}
|
||||
),
|
||||
)
|
||||
|
|
|
|||
|
|
@ -4363,6 +4363,7 @@ class TestVmapOperatorsOpInfo(TestCase):
|
|||
xfail("as_strided"),
|
||||
xfail("as_strided_copy"),
|
||||
xfail("t_copy"),
|
||||
xfail("unsqueeze_copy"),
|
||||
xfail("istft"),
|
||||
xfail("nonzero"),
|
||||
xfail("nn.functional.fractional_max_pool2d"),
|
||||
|
|
|
|||
|
|
@ -1263,6 +1263,11 @@ EXPECTED_SKIPS_OR_FAILS_WITH_DTYPES: Tuple[onnx_test_common.DecorateMeta, ...] =
|
|||
dtypes=onnx_test_common.BOOL_TYPES + onnx_test_common.INT_TYPES,
|
||||
reason=onnx_test_common.reason_onnx_script_does_not_support("Floor", "bool, int"),
|
||||
),
|
||||
xfail(
|
||||
"unsqueeze_copy",
|
||||
reason="OnnxExporterError: Failed to export model",
|
||||
dtypes=(torch.int8, torch.uint8, torch.int16),
|
||||
),
|
||||
xfail(
|
||||
"where",
|
||||
dtypes=onnx_test_common.BOOL_TYPES,
|
||||
|
|
|
|||
|
|
@ -352,6 +352,7 @@ def mps_ops_modifier(ops):
|
|||
'unsafe_chunk',
|
||||
'unsafe_split',
|
||||
'unsqueeze',
|
||||
'unsqueeze_copy',
|
||||
'view_as',
|
||||
'view_as_real',
|
||||
'view',
|
||||
|
|
|
|||
|
|
@ -200,6 +200,7 @@ GRADIENT_IMPLEMENTED_FOR_COMPLEX = {
|
|||
"permute",
|
||||
"squeeze",
|
||||
"unsqueeze",
|
||||
"unsqueeze_copy",
|
||||
"resize",
|
||||
"resize_as",
|
||||
"tril",
|
||||
|
|
|
|||
|
|
@ -460,6 +460,7 @@ def core_aten_decompositions() -> Dict[torch._ops.OperatorBase, Callable]:
|
|||
aten._unsafe_masked_index_put_accumulate,
|
||||
aten.unsafe_split.Tensor,
|
||||
aten.unsafe_split_with_sizes,
|
||||
aten.unsqueeze_copy,
|
||||
aten._unsafe_view,
|
||||
aten.upsample_linear1d,
|
||||
aten.upsample_bilinear2d,
|
||||
|
|
|
|||
|
|
@ -294,6 +294,7 @@ __all__ = [
|
|||
"unfold",
|
||||
"unfold_copy",
|
||||
"unsqueeze",
|
||||
"unsqueeze_copy",
|
||||
"view",
|
||||
"view_as",
|
||||
"view_copy",
|
||||
|
|
@ -6321,6 +6322,7 @@ expand_copy = _make_copy_from_view(aten.expand)
|
|||
# no sparse support. See narrow_copy_sparse in core.
|
||||
narrow_copy = _make_copy_from_view(aten.narrow)
|
||||
t_copy = _make_copy_from_view(aten.t)
|
||||
unsqueeze_copy = _make_copy_from_view(aten.unsqueeze)
|
||||
view_copy = _make_copy_from_view(aten.view)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -19653,6 +19653,29 @@ op_db: List[OpInfo] = [
|
|||
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
|
||||
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
|
||||
sample_inputs_func=sample_unsqueeze),
|
||||
OpInfo('unsqueeze_copy',
|
||||
dtypes=all_types_and_complex_and(torch.bool, torch.float16, torch.bfloat16, torch.chalf),
|
||||
supports_out=True,
|
||||
supports_forward_ad=True,
|
||||
supports_fwgrad_bwgrad=True,
|
||||
# See https://github.com/pytorch/pytorch/pull/78358
|
||||
check_batched_forward_grad=False,
|
||||
# vmap does not support inplace views
|
||||
check_inplace_batched_forward_grad=False,
|
||||
assert_jit_shape_analysis=True,
|
||||
assert_autodiffed=True,
|
||||
autodiff_fusible_nodes=[], # aliases inputs, shouldn't be fused
|
||||
autodiff_nonfusible_nodes=[], # aliases inputs, shouldn't be fused
|
||||
sample_inputs_func=sample_unsqueeze,
|
||||
skips=(
|
||||
DecorateInfo(unittest.expectedFailure, 'TestDTensorOps', 'test_dtensor_op_db'),
|
||||
DecorateInfo(
|
||||
unittest.expectedFailure,
|
||||
'TestJit',
|
||||
'test_variant_consistency_jit',
|
||||
dtypes=(torch.float32,),
|
||||
),
|
||||
)),
|
||||
BinaryUfuncInfo('xlogy',
|
||||
aliases=('special.xlogy',),
|
||||
dtypes=all_types_and(torch.bool, torch.half, torch.bfloat16),
|
||||
|
|
@ -23946,6 +23969,11 @@ python_ref_db = [
|
|||
"_refs.unsqueeze",
|
||||
torch_opinfo_name="unsqueeze",
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.unsqueeze_copy",
|
||||
torch_opinfo_name="unsqueeze_copy",
|
||||
supports_out=True,
|
||||
),
|
||||
PythonRefInfo(
|
||||
"_refs.view",
|
||||
torch_opinfo_name="view",
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user