diff --git a/test/expect/HasDecompTest.test_has_decomposition.expect b/test/expect/HasDecompTest.test_has_decomposition.expect
index 74ead8e1773..042959c22cd 100644
--- a/test/expect/HasDecompTest.test_has_decomposition.expect
+++ b/test/expect/HasDecompTest.test_has_decomposition.expect
@@ -823,6 +823,8 @@ aten::from_file
 aten::from_file.out
 aten::full.names
 aten::full.names_out
+aten::full_like
+aten::full_like.out
 aten::gather
 aten::gather.out
 aten::geqrf
diff --git a/test/export/test_experimental.py b/test/export/test_experimental.py
index 168a5846338..641dd586edb 100644
--- a/test/export/test_experimental.py
+++ b/test/export/test_experimental.py
@@ -52,8 +52,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
     neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
-    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
-    div_1 = torch.ops.aten.div.Scalar(full, 1);  full = None
+    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
+    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
     neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone);  expand = clone = None
@@ -98,8 +98,8 @@ def forward(self, p_linear_weight, p_linear_bias, c_lifted_tensor_0, x):
     sum_1 = torch.ops.aten.sum.dim_IntList(mul, []);  mul = None
     neg = torch.ops.aten.neg.default(sum_1);  sum_1 = None
     div = torch.ops.aten.div.Scalar(neg, 1);  neg = None
-    full = torch.ops.aten.full.default([], 1, dtype = torch.float32, layout = torch.strided, device = device(type='cpu'), pin_memory = False)
-    div_1 = torch.ops.aten.div.Scalar(full, 1);  full = None
+    full_like = torch.ops.aten.full_like.default(div, 1, pin_memory = False, memory_format = torch.preserve_format)
+    div_1 = torch.ops.aten.div.Scalar(full_like, 1);  full_like = None
     neg_1 = torch.ops.aten.neg.default(div_1);  div_1 = None
     expand = torch.ops.aten.expand.default(neg_1, [3]);  neg_1 = None
     mul_1 = torch.ops.aten.mul.Tensor(expand, clone);  expand = clone = None
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index 9e761e82ed8..17a0da43a2a 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -432,8 +432,6 @@ def check_model(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
-    # TODO: enable this for all tests
-    exact_stride=False,
 ):
     kwargs = kwargs or {}
     torch._dynamo.reset()
@@ -546,7 +544,6 @@ def check_model(
             rtol=rtol,
             equal_nan=True,
             exact_dtype=exact_dtype,
-            exact_stride=exact_stride,
         )
         # In case of input mutations, check that inputs are the same
         self.assertEqual(
@@ -557,7 +554,6 @@ def check_model(
             equal_nan=True,
             # our testing sometimes uses higher precision inputs for the reference
             exact_dtype=False,
-            exact_stride=exact_stride,
         )
     else:
         for correct_val, actual_val in zip(correct_flat, actual_flat):
@@ -571,8 +567,6 @@ def check_model(
                 assert correct_val.layout == actual_val.layout
                 if exact_dtype:
                     assert correct_val.dtype == actual_val.dtype
-                if exact_stride:
-                    assert correct_val.stride() == actual_val.stride()

     if check_gradient:
         actual = output_process_fn_grad(actual)
@@ -626,7 +620,6 @@ def check_model(
                 rtol=grad_rtol or rtol,
                 equal_nan=True,
                 exact_dtype=exact_dtype,
-                exact_stride=exact_stride,
             )

     torch._dynamo.reset()
@@ -652,8 +645,6 @@ def check_model_gpu(
     check_gradient=False,
     check_has_compiled=True,
     output_process_fn_grad=lambda x: x,
-    # TODO: enable this for all tests
-    exact_stride=False,
 ):
     kwargs = kwargs or {}
     if hasattr(model, "to"):
@@ -680,7 +671,6 @@ def check_model_gpu(
         check_gradient=check_gradient,
         check_has_compiled=check_has_compiled,
         output_process_fn_grad=output_process_fn_grad,
-        exact_stride=exact_stride,
     )

     if check_lowp:
@@ -713,7 +703,6 @@ def check_model_gpu(
             check_gradient=check_gradient,
             check_has_compiled=check_has_compiled,
             output_process_fn_grad=output_process_fn_grad,
-            exact_stride=exact_stride,
         )


@@ -6971,12 +6960,6 @@ def forward(self, arg0_1: "Sym(s77)", arg1_1: "Sym(s27)", arg2_1: "Sym(s53)", ar

         self.common(fn, (torch.randn(8),))

-    def test_full_like_stride(self):
-        def fn(a):
-            return torch.full_like(a, 3)
-
-        self.common(fn, (torch.randn(4, 5, 6).transpose(1, -1),), exact_stride=True)
-
     def test_full_truncation(self):
         def fn(a):
             return a + torch.full_like(a, 7.777)
diff --git a/test/test_decomp.py b/test/test_decomp.py
index 53ef92dba61..07dcd8252c5 100644
--- a/test/test_decomp.py
+++ b/test/test_decomp.py
@@ -545,11 +545,6 @@ comprehensive_failures = {
     xfail(
         "nn.functional.upsample_bilinear", "", dtypes=(torch.uint8,)
     ),  # off by one error
-    skip(
-        "nn.functional.nll_loss",
-        "",
-        dtypes=(torch.float64, torch.float32, torch.bfloat16, torch.float16),
-    ),  # non-deterministic
 }


@@ -866,16 +861,7 @@ def forward(self, scores_1, mask_1, value_1):
            assert len(real_out) == len(decomp_out)

            if do_relative_check:
-                device_arg = kwargs.get("device", None)
-
-                def upcast(x):
-                    if (isinstance(x, Tensor) and x.device.type == "mps") or (
-                        device_arg and torch.device(device_arg).type == "mps"
-                    ):
-                        return upcast_tensor(x, dtype=torch.float32)
-                    else:
-                        return upcast_tensor(x, dtype=torch.float64)
-
+                upcast = partial(upcast_tensor, dtype=torch.float64)
                real_out_double, _ = tree_flatten(
                    func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                )
diff --git a/test/test_nestedtensor.py b/test/test_nestedtensor.py
index f3ea420c814..f53268cb24d 100644
--- a/test/test_nestedtensor.py
+++ b/test/test_nestedtensor.py
@@ -8530,6 +8530,14 @@ BACKWARD_SKIPS_AND_XFAILS = [

 COMPILE_FORWARD_SKIPS_AND_XFAILS = [
     *FORWARD_SKIPS_AND_XFAILS,
+    # Needs investigation in AOTAutograd: len(unwrapped_args) == num_args_tallied assertion fails
+    # e.g. Expected 5 == 4
+    XFailRule(
+        error_type=AssertionError,
+        op_match_fn=lambda device, op: (op.full_name == "fill"),
+        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
+        name="fill_aot_autograd_bug_with_transposed_input",
+    ),
     # Bug: cross-device conversions with to() result in new nested ints within compile only
     XFailRule(
         error_type=AssertionError,
@@ -8573,6 +8581,12 @@ COMPILE_FORWARD_SKIPS_AND_XFAILS = [
         sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
         name="crazy_aot_autograd_bug1",
     ),
+    # Bug: also no idea what's going on here: needs investigation within AOTAutograd
+    XFailRule(
+        op_match_fn=lambda device, op: (op.full_name == "isreal"),
+        sample_match_fn=lambda device, sample: ("noncontig_transposed" in sample.name),
+        name="crazy_aot_autograd_bug2",
+    ),
 ]

 COMPILE_BACKWARD_SKIPS_AND_XFAILS = [
diff --git a/test/test_ops.py b/test/test_ops.py
index c4b257ef138..0f079e5c45e 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -2294,6 +2294,7 @@ class TestRefsOpsInfo(TestCase):
         "_refs.empty_strided",
         "_refs.equal",
         "_refs.full",
+        "_refs.full_like",
         "_refs.is_complex",
         "_refs.to",
         "_refs.mvlgamma",
@@ -2408,6 +2409,7 @@ class TestRefsOpsInfo(TestCase):
         "_refs.unflatten",
         "_refs.sum_to_size",
         # ref implementation missing kwargs
+        "_refs.full_like",  # missing "layout"
         "_refs.scalar_tensor",  # missing "layout"
         # other
         "_refs.block_diag",  # only refs._block_diag_iterable is in decomposition table
diff --git a/torch/_decomp/__init__.py b/torch/_decomp/__init__.py
index 8f61fa15f9b..abb94b109cc 100644
--- a/torch/_decomp/__init__.py
+++ b/torch/_decomp/__init__.py
@@ -346,7 +346,6 @@ def _core_aten_decompositions_post_autograd() -> dict[
            aten.floor_divide,
            aten.frac,
            aten.frac_,
-            aten.full_like,
            aten._fused_moving_avg_obs_fq_helper,
            aten.gelu_,
            aten.gelu_backward,
diff --git a/torch/_inductor/decomposition.py b/torch/_inductor/decomposition.py
index 3f75a7ab6a9..08c3abc9f23 100644
--- a/torch/_inductor/decomposition.py
+++ b/torch/_inductor/decomposition.py
@@ -625,6 +625,28 @@ def randn_like(
     ).to(memory_format=get_like_layout(self, memory_format))


+@register_decomposition(aten.full_like)
+def full_like(
+    self: torch.Tensor,
+    fill_value: Union[int, float],
+    *,
+    dtype: Optional[torch.dtype] = None,
+    layout: Optional[torch.layout] = None,
+    device: Optional[torch.device] = None,
+    pin_memory: bool = False,
+    requires_grad: bool = False,
+    memory_format: torch.memory_format = torch.preserve_format,
+) -> torch.Tensor:
+    return torch.full(
+        [*self.size()],
+        fill_value,
+        dtype=dtype or self.dtype,
+        layout=layout or self.layout,
+        device=device or self.device,
+        requires_grad=requires_grad,
+    ).to(memory_format=get_like_layout(self, memory_format))
+
+
 @register_decomposition(aten.randint_like.default)
 def randint_like(
     self: torch.Tensor,
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index 600df3233e7..5db712372a1 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -3177,6 +3177,7 @@ def _full(fill_value, device, dtype, size):
     )


+@register_lowering(aten.full_like, type_promotion_kind=None)
 def full_like(x, fill_value, **kwargs):
     return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)

@@ -6120,17 +6121,6 @@ def fill_(x, fill_value):
     return mutate_to(x, full_like(x, fill_value))


-@register_lowering(prims.fill, type_promotion_kind=None)
-def prims_fill(x, fill_value):
-    dtype = x.get_dtype()
-    return Pointwise.create(
-        device=x.get_device(),
-        dtype=dtype,
-        inner_fn=lambda _: ops.constant(fill_value, dtype),
-        ranges=list(x.get_size()),
-    )
-
-
 @register_lowering(aten.copy_, type_promotion_kind=None)
 def copy_(dst, src, non_blocking=False):
     if dst is src:
diff --git a/torch/_refs/__init__.py b/torch/_refs/__init__.py
index d627aab5827..8fd234c8a0e 100644
--- a/torch/_refs/__init__.py
+++ b/torch/_refs/__init__.py
@@ -5588,26 +5588,9 @@ def full(
         pin_memory=pin_memory,
         requires_grad=requires_grad,
     )
-    return prims.fill(e, fill_value)  # type: ignore[arg-type]
+    return torch.fill(e, fill_value)  # type: ignore[arg-type]


-def _get_shape_permutation_like(
-    a: TensorLikeType, layout: torch.layout
-) -> tuple[ShapeType, StrideType]:
-    assert layout == torch.strided
-
-    physical_layout = utils.compute_elementwise_output_logical_to_physical_perm(a)
-    shape = [a.shape[l] for l in physical_layout]
-
-    permutation = [0] * len(shape)
-    for p, l in enumerate(physical_layout):
-        permutation[l] = p
-
-    return (shape, permutation)
-
-
-@register_decomposition(aten.full_like)
-@out_wrapper()
 def full_like(
     a: TensorLikeType,
     fill_value: NumberType,
@@ -5619,36 +5602,16 @@ def full_like(
     requires_grad: bool = False,
     memory_format: torch.memory_format = torch.preserve_format,
 ) -> TensorLikeType:
-    dtype = a.dtype if dtype is None else dtype
-    layout = a.layout if layout is None else layout
-    device = a.device if device is None else device
-
-    if memory_format != torch.preserve_format:
-        result = torch.full(
-            a.shape,
-            fill_value,
-            dtype=dtype,
-            layout=layout,
-            device=device,
-            pin_memory=pin_memory,
-            requires_grad=requires_grad,
-        )
-        return result.to(memory_format=memory_format)
-
-    else:
-        shape, permutation = _get_shape_permutation_like(a, layout)
-        result = torch.full(
-            shape,
-            fill_value,
-            dtype=dtype,
-            layout=layout,
-            device=device,
-            pin_memory=pin_memory,
-            requires_grad=requires_grad,
-        )
-        if permutation == list(range(len(permutation))):
-            return result
-        return result.permute(permutation).clone()
+    e = torch.empty_like(
+        a,
+        dtype=dtype,
+        layout=layout,
+        device=device,
+        pin_memory=pin_memory,
+        requires_grad=requires_grad,
+        memory_format=memory_format,
+    )
+    return fill(e, fill_value)


 @register_decomposition(aten.zeros_like)
diff --git a/torch/testing/_internal/common_methods_invocations.py b/torch/testing/_internal/common_methods_invocations.py
index 55c1780961c..daf42f4bba5 100644
--- a/torch/testing/_internal/common_methods_invocations.py
+++ b/torch/testing/_internal/common_methods_invocations.py
@@ -1923,7 +1923,7 @@ def sample_inputs_full_like(self, device, dtype, requires_grad, **kwargs):
     def get_val(dtype):
         return make_tensor([], dtype=dtype, device="cpu").item()

-    double_dtype = torch.double if torch.device(device).type != "mps" else torch.float
+    double_dtype = torch.double if device != "mps:0" else torch.float
     inputs = [
         ((), get_val(dtype), {}),
         ((S, S), get_val(dtype), {}),
@@ -24603,10 +24603,6 @@ python_ref_db = [
            DecorateInfo(unittest.skip('output is non-deterministic'), 'TestCommon', 'test_compare_cpu'),
        ),
    ),
-    PythonRefInfo(
-        "_refs.full_like",
-        torch_opinfo_name="full_like",
-    ),
    PythonRefInfo(
        "_refs.randn",
        torch_opinfo_name="randn",
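For context on the behavior change above: the new inductor decomposition rebuilds full_like as torch.full over the input's sizes followed by a .to(memory_format=...), so the result keeps the input's values and suggested memory format but not necessarily its exact strides (which is why test_full_like_stride is removed). Below is a minimal sketch, not part of the diff, assuming preserve_format resolves to contiguous_format for a plain transposed input (the real code defers to get_like_layout); full_like_sketch is a hypothetical name introduced only for illustration.

import torch

def full_like_sketch(self, fill_value, *, dtype=None, memory_format=torch.preserve_format):
    # torch.full over the input's sizes, then restore a "like" memory format.
    out = torch.full(
        [*self.size()],
        fill_value,
        dtype=dtype or self.dtype,
        device=self.device,
    )
    if memory_format == torch.preserve_format:
        # Simplification: the real decomposition calls get_like_layout(self, memory_format).
        memory_format = torch.contiguous_format
    return out.to(memory_format=memory_format)

x = torch.randn(4, 5, 6).transpose(1, -1)  # non-contiguous (transposed) input
assert torch.equal(full_like_sketch(x, 3), torch.full_like(x, 3))  # values agree
# Strides can differ: the sketch returns a contiguous tensor, while eager full_like
# with preserve_format may keep the transposed strides.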