diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp
index 41f24436530..00ec4395b0c 100644
--- a/c10/core/DispatchKeySet.cpp
+++ b/c10/core/DispatchKeySet.cpp
@@ -46,8 +46,8 @@ bool isBackendDispatchKey(DispatchKey t) {
 // math_dispatch_keyset contains all keys in backend_dispatch_keyset and
 // autograd_dispatch_keyset Alias key DispatchKey::CompositeImplicitAutograd
 // maps to [math_dispatch_keyset x full_backend_mask]
-constexpr DispatchKeySet math_dispatch_keyset = backend_dispatch_keyset |
-    autograd_dispatch_keyset | DispatchKeySet(DispatchKey::Python);
+constexpr DispatchKeySet math_dispatch_keyset =
+    backend_dispatch_keyset | autograd_dispatch_keyset;

 DispatchKeySet getRuntimeDispatchKeySet(DispatchKey t) {
   TORCH_INTERNAL_ASSERT(t != DispatchKey::Undefined);
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 00be5ff4b73..16f64a0229d 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -279,8 +279,6 @@ CROSS_REF_EXCLUDE_SET = {
     # CompositeAutogradImplicit
     # See https://github.com/pytorch/pytorch/issues/81669
     (None, None, "nn.functional.relu6"),
-    (None, None, "nn.functional.mish"),
-    (None, None, "nn.functional.silu"),
 }
diff --git a/test/test_functionalization.py b/test/test_functionalization.py
index 719634d3dbc..d6be39c2063 100644
--- a/test/test_functionalization.py
+++ b/test/test_functionalization.py
@@ -102,9 +102,10 @@ class TestFunctionalization(TestCase):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
-    add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
     view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2])
     mul_tensor = torch.ops.aten.mul.Tensor(view_copy_default_1, view_copy_default_1); view_copy_default_1 = None
     return add_tensor
@@ -126,10 +127,11 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
-    empty = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    add_tensor = torch.ops.aten.add.Tensor(view_copy_default, ones); view_copy_default = ones = None
+    empty_1 = torch.ops.aten.empty.memory_format([], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    add_tensor = torch.ops.aten.add.Tensor(view_copy_default, fill_scalar); view_copy_default = fill_scalar = None
     mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, add_tensor); add_tensor = None
     return mul_tensor
     """)
@@ -180,9 +182,10 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2])
-    add_tensor = torch.ops.aten.add.Tensor(a_1, ones); a_1 = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(a_1, fill_scalar); a_1 = fill_scalar = None
     view_copy_default_1 = torch.ops.aten.view_copy.default(add_tensor, [4, 2]); add_tensor = None
     return view_copy_default_1
     """)
@@ -275,9 +278,10 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(a_1)
-    add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
     diagonal_scatter_default = torch.ops.aten.diagonal_scatter.default(a_1, add_tensor); a_1 = add_tensor = None
     mul_tensor = torch.ops.aten.mul.Tensor(diagonal_scatter_default, diagonal_scatter_default); diagonal_scatter_default = None
     return mul_tensor
@@ -309,12 +313,13 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     split_copy_tensor = torch.ops.aten.split_copy.Tensor(a_1, 2)
     getitem = split_copy_tensor[0]
     getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
     diagonal_copy_default = torch.ops.aten.diagonal_copy.default(getitem_1); getitem_1 = None
-    add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, ones); diagonal_copy_default = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(diagonal_copy_default, fill_scalar); diagonal_copy_default = fill_scalar = None
     split_copy_tensor_1 = torch.ops.aten.split_copy.Tensor(a_1, 2)
     getitem_2 = split_copy_tensor_1[0]
     getitem_3 = split_copy_tensor_1[1]; split_copy_tensor_1 = None
@@ -339,10 +344,11 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     transpose_copy_int = torch.ops.aten.transpose_copy.int(a_1, 1, 0)
     select_copy_int = torch.ops.aten.select_copy.int(transpose_copy_int, 0, 0); transpose_copy_int = None
-    add_tensor = torch.ops.aten.add.Tensor(select_copy_int, ones); select_copy_int = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(select_copy_int, fill_scalar); select_copy_int = fill_scalar = None
     transpose_copy_int_1 = torch.ops.aten.transpose_copy.int(a_1, 1, 0); a_1 = None
     select_scatter_default = torch.ops.aten.select_scatter.default(transpose_copy_int_1, add_tensor, 0, 0); transpose_copy_int_1 = add_tensor = None
     transpose_copy_int_2 = torch.ops.aten.transpose_copy.int(select_scatter_default, 1, 0); select_scatter_default = None
@@ -367,8 +373,10 @@ def forward(self, a_1):
     view_copy_default = torch.ops.aten.view_copy.default(a_1, [8]); a_1 = None
-    arange = torch.ops.aten.arange.default(4, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
-    arange_1 = torch.ops.aten.arange.default(4, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([0], dtype = torch.int64, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    arange = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.int64, layout = torch.strided, device = device(type='cpu'))
+    empty_1 = torch.ops.aten.empty.memory_format([0], dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
+    arange_1 = torch.ops.aten.arange.start_step(0, 4, 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'))
     index_put_default = torch.ops.aten.index_put.default(view_copy_default, [arange], arange_1); view_copy_default = arange = arange_1 = None
     view_copy_default_1 = torch.ops.aten.view_copy.default(index_put_default, [4, 2])
     return index_put_default
@@ -390,7 +398,8 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     view_copy_default = torch.ops.aten.view_copy.default(a_1, [4, 2]); a_1 = None
     add_tensor = torch.ops.aten.add.Tensor(view_copy_default, 1); view_copy_default = None
     mul_tensor = torch.ops.aten.mul.Tensor(add_tensor, 2)
@@ -413,8 +422,8 @@ def forward(self, a_1):
 def forward(self, a_1):
     ge_scalar = torch.ops.aten.ge.Scalar(a_1, 0); a_1 = None
-    _to_copy_default = torch.ops.aten._to_copy.default(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
-    return _to_copy_default
+    to_dtype_layout = torch.ops.aten.to.dtype_layout(ge_scalar, dtype = torch.float32, layout = torch.strided); ge_scalar = None
+    return to_dtype_layout
     """)

     @skipIfTorchDynamo("Test does not work with TorchDynamo")
@@ -472,7 +481,8 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
     view_copy_default = torch.ops.aten.view_copy.default(add_tensor, [8])
     _reshape_alias_copy_default = torch.ops.aten._reshape_alias_copy.default(view_copy_default, [2, 4], [4, 1]); view_copy_default = None
@@ -482,7 +492,7 @@ def forward(self, a_1):
     split_copy_tensor = torch.ops.aten.split_copy.Tensor(squeeze_copy_default, 2); squeeze_copy_default = None
     getitem = split_copy_tensor[0]
     getitem_1 = split_copy_tensor[1]; split_copy_tensor = None
-    add_tensor_1 = torch.ops.aten.add.Tensor(getitem, ones); getitem = ones = None
+    add_tensor_1 = torch.ops.aten.add.Tensor(getitem, fill_scalar); getitem = fill_scalar = None
     select_copy_int = torch.ops.aten.select_copy.int(_reshape_alias_copy_default, 0, 0); _reshape_alias_copy_default = None
     clone_default = torch.ops.aten.clone.default(add_tensor_1, memory_format = torch.contiguous_format)
     _unsafe_view_default = torch.ops.aten._unsafe_view.default(clone_default, [4]); clone_default = None
@@ -518,9 +528,10 @@ def forward(self, a_1):
 def forward(self, a_1):
-    ones = torch.ops.aten.ones.default([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    empty = torch.ops.aten.empty.memory_format([4, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    fill_scalar = torch.ops.aten.fill.Scalar(empty, 1.0); empty = None
     view_default = torch.ops.aten.view.default(a_1, [4, 2]); a_1 = None
-    add_tensor = torch.ops.aten.add.Tensor(view_default, ones); view_default = ones = None
+    add_tensor = torch.ops.aten.add.Tensor(view_default, fill_scalar); view_default = fill_scalar = None
     view_default_1 = torch.ops.aten.view.default(add_tensor, [4, 2])
     mul_tensor = torch.ops.aten.mul.Tensor(view_default_1, view_default_1); view_default_1 = None
     return add_tensor
@@ -569,9 +580,12 @@ def forward(self, a_1):
 def forward(self, a_1):
-    zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    add_tensor = torch.ops.aten.add.Tensor(a_1, a_1); a_1 = None
+    empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    zero_default = torch.ops.aten.zero.default(empty); empty = None
+    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
+    diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
+    copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
+    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
     return add_tensor
     """)
@@ -583,10 +597,12 @@ def forward(self, a_1):
 def forward(self, a_1):
-    zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    expand_copy_default = torch.ops.aten.expand_copy.default(a_1, [2])
-    add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
+    empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    zero_default = torch.ops.aten.zero.default(empty); empty = None
+    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
+    diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
+    copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
+    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
     return add_tensor
     """)
@@ -598,12 +614,14 @@ def forward(self, a_1):
 def forward(self, a_1):
-    zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
-    add_tensor = torch.ops.aten.add.Tensor(_to_copy_default, a_1); _to_copy_default = a_1 = None
+    empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    zero_default = torch.ops.aten.zero.default(empty); empty = None
+    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
+    diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
+    copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
+    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
     return add_tensor
-    """)  # noqa: B950
+    """)

         # Test 4: copy_() with different dtype, different shape
         self.assert_functionalization(f, torch.ones(1, dtype=torch.long))
@@ -613,13 +631,14 @@ def forward(self, a_1):
 def forward(self, a_1):
-    zeros = torch.ops.aten.zeros.default([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
-    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zeros); zeros = None
-    _to_copy_default = torch.ops.aten._to_copy.default(a_1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
-    expand_copy_default = torch.ops.aten.expand_copy.default(_to_copy_default, [2]); _to_copy_default = None
-    add_tensor = torch.ops.aten.add.Tensor(expand_copy_default, a_1); expand_copy_default = a_1 = None
+    empty = torch.ops.aten.empty.memory_format([2, 2], dtype = torch.float32, device = device(type='cpu'), pin_memory = False)
+    zero_default = torch.ops.aten.zero.default(empty); empty = None
+    diagonal_copy_default = torch.ops.aten.diagonal_copy.default(zero_default)
+    diagonal_copy_default_1 = torch.ops.aten.diagonal_copy.default(zero_default); zero_default = None
+    copy_default = torch.ops.aten.copy.default(diagonal_copy_default_1, a_1); diagonal_copy_default_1 = None
+    add_tensor = torch.ops.aten.add.Tensor(copy_default, a_1); copy_default = a_1 = None
     return add_tensor
-    """)  # noqa: B950
+    """)

     def test_expand_symint(self):
         # Once some existing SymInt bugs are ironed out, we should update