Commit Graph

113 Commits

Author SHA1 Message Date
PyTorch MergeBot
c88c0e6c65 Revert "[Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)"
This reverts commit d255b34c0a.

Reverted https://github.com/pytorch/pytorch/pull/137119 on behalf of https://github.com/malfet due to Need to revert to be able to revert https://github.com/pytorch/pytorch/pull/136910 ([comment](https://github.com/pytorch/pytorch/pull/137119#issuecomment-2400401262))
2024-10-08 17:09:26 +00:00
Michael Lazos
d255b34c0a [Dynamo] Handle torch function subclass/mode dispatch on generic tensor methods (#137119)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/137119
Approved by: https://github.com/williamwen42
ghstack dependencies: #137114, #137115, #137116, #137117, #137120, #137227
2024-10-07 18:55:26 +00:00
Michael Lazos
041960a1ce [Dynamo] Automatically in-graph traceable tensor subclass ctors (#135151)
Fixes https://github.com/pytorch/pytorch/issues/114389

Previously, dynamo would attempt to trace through the `__init__` of traceable tensor subclasses, since their constructors are AOT dispatcher traceable by definition, dynamo should automatically put these in the graph like we do for any other tensors. Not doing this is difficult because dynamo would need to apply mutations post tensor subclass creation in the graph.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135151
Approved by: https://github.com/bdhirsh
2024-09-06 12:23:38 +00:00
Avik Chaudhuri
43f4947d44 fix fake tensor tolist implementation (#135131)
Summary:
When exporting for training with `tolist`, we do not hit `FunctionalTensor.tolist` since we do not functionalize. Unfortunately, this means we hit `FakeTensor.tolist`, which creates unbacked symints that are not backed by proxies.

Rather than trying to patch up this low-level implementation, we replace it with essentially what `FunctionalTensor.tolist` does, which is higher-level: we essentially desugar to `item()` calls and let it take care of unbacked symints.

Test Plan:
Some expected failures are gone now.
Also found a test for `tolist` that was written when `FunctionalTensor.tolist` was implemented but not really doing much; repurposed it now to exercise more modes.

Differential Revision: D62197742

Pull Request resolved: https://github.com/pytorch/pytorch/pull/135131
Approved by: https://github.com/ezyang
2024-09-05 23:20:31 +00:00
Xuehai Pan
5229b52bf2 [dynamo] support cls.__base__ (#133969)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133969
Approved by: https://github.com/jansel
2024-08-20 20:03:31 +00:00
soulitzer
4af4910b1a Reland "Construct NJT without graph breaks" (#133196)
This reverts commit 154d40ca488e6979ce9c2de89d8a35b53129ebea.

and adds changes from https://github.com/pytorch/pytorch/pull/133061

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133196
Approved by: https://github.com/ezyang
ghstack dependencies: #133145
2024-08-14 01:11:13 +00:00
soulitzer
05de2b2d0f Revert "Construct NJT without graph breaks" (#133145)
This reverts commit 911154271309667b55dfb963ec6384bd0048019b.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133145
Approved by: https://github.com/YuqingJ
2024-08-10 03:11:16 +00:00
Zhengxu Chen
942ffd1b2d Make the __module__ name of HOO to be always "torch.ops.higher_order" (#132775)
Summary: It seems that we can just make this the default so that in the future all the ops printed in the graph should be like torch.ops.higher_order

Test Plan: CI

Differential Revision: D60530900

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132775
Approved by: https://github.com/ydwu4, https://github.com/zou3519
2024-08-08 16:55:09 +00:00
Edward Z. Yang
aec6332356 Only thunkify proxies in some situations (#132421)
The goal of this PR is to avoid stack overflow when we create extremely long chains of thunks, and then evaluate them (e.g., as occurs if you sum(long list of symint)). The basic idea behind this PR is to only thunkify proxies if they're being created in places where they may or may not be used--crucially, symint operations that occur in user code we are tracing are eagerly placed into the graph, even if they may eventually be dead.

I annotated the PR with explanation of changes.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132421
Approved by: https://github.com/Skylion007, https://github.com/zou3519
ghstack dependencies: #132674, #132675
2024-08-08 12:03:06 +00:00
angelayi
a270800f0b [export][reland] Add print_readable to unflattened module (#132817)
Reland https://github.com/pytorch/pytorch/pull/128617

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132817
Approved by: https://github.com/pianpwk
2024-08-08 06:05:30 +00:00
soulitzer
f50621989b Construct NJT without graph breaks (#130292)
Combines contributions from https://github.com/pytorch/pytorch/pull/130505

Some context can be found in this large comment block:

a5b64d39fd/test/dynamo/test_subclasses.py (L1667-L1681)

Changes in this PR
- For each tensor fakified, check the nested int registry in eager, and eagerly symbolicize if that tensor has already been associated with nested int in eager.
- Adds a separate counter stored on FakeTensorMode as a fake analog to _tensor_id_counter (which keeps track of unique tensors). This counter is initialized to the global eager tensor id counter upon creation of the FakeTensorMode, and needs to be reset when the same FakeTensorMode is reused to trace again (in this PR, we piggyback on the epoch incrementing logic).
- (refactor) Today, we store FakeTensor -> symbolic nested int in the global registry. With this PR, symbolic nested int is stored directly on the FakeTensor. (Eager still caches nested int in the registry, though we should avoid this at some point.)

Basically unchanged, but worth noting:
- `__tensor_unflatten__` is still responsible for determining whether we should cache for now. The logic is somewhat simplified.
- to_copy is still using the trick of updating two different tensors in the registry to point to the same nested int. This is kind of broken, but we try to leave it as is, and plan a better fix with the UnionFind stack.

Differential Revision: [D60406772](https://our.internmc.facebook.com/intern/diff/D60406772)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130292
Approved by: https://github.com/bdhirsh
ghstack dependencies: #131916, #131803
2024-08-06 17:03:39 +00:00
soulitzer
406b50835b Use FakeTensor cache for subclass inner tensors (#131803)
Rewrite of original PR in https://github.com/pytorch/pytorch/pull/130291

To answer review comments from https://github.com/pytorch/pytorch/pull/130291#pullrequestreview-2166671953:

> At a higher level, do we need this?

Today, this should not change the behavior of anything. But an invariant of "same tensor always corresponds to the same FakeTensor" is nice (from discussion with @bdhirsh).

> Why does this happen?

Today, both dynamo and meta_utils do some recursion when it comes to FakeTensors. So whenever we fakify a subclass, the process would roughly like:

```
wrap_to_fake (subclass)
   meta_utils (subclass)
      meta_utils (values) -> not cached because we use callback
      meta_utils(offsets) -> not cached because we use callback
  wrap_to_fake (values)
  wrap_to_fake (offsets) -> cached because we rely on top-level meta_utils
```

However, we know that:
- Caching only occurs at the top-level of meta_utils.
- The return value of the top-level wrap_to_fake is returned.

This means that after all of this:
- The fakified subclass holds inner FakeTensors that are NOT part of the cache
- values/offsets are Fakified a second time, and those instances are cached.

Differential Revision: [D60406773](https://our.internmc.facebook.com/intern/diff/D60406773)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131803
Approved by: https://github.com/ezyang
ghstack dependencies: #131916
2024-08-06 17:03:39 +00:00
soulitzer
a94c441e48 Fix symbolic nested int printing (#131916)
Differential Revision: [D60406775](https://our.internmc.facebook.com/intern/diff/D60406775)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/131916
Approved by: https://github.com/Skylion007, https://github.com/jbschlosser
2024-08-06 17:03:39 +00:00
David Berard
5973aec671 [fx] python_code(verbose=True): show size/strides for all tensors (#132192)
python_code(verbose=True) (or print_readable()) generates a string with the code representing the fx graph, with extra annotations indicating the size or stride of the tensor. Currently, it'll only shows sizes/strides for FakeTensors provided in metadata. For subclass tensors like NestedTensor, the outer class (provided in the node metadata) will be a non-FakeTensor and the inner tensors will be fake. This PR expands the conditional to show sizes/strides for all tensors, not just FakeTensors.

Testing: I ran this test script (below), ran it with `TORCH_LOGS=+dynamo` and found in the logs the graph shown below - we see that the input nested tensor has sizes and strides associated with it. Also, I stacked a diff on top of this one that forces the readable graph to be generated whenever PT2 is in use in tests, which should hopefully find any issues; https://github.com/pytorch/pytorch/pull/132195 shows no significant failures except for preexisting failures.

test script:
```python
import torch

def fn(x):
    return x.cos()

nt = torch.nested.nested_tensor_from_jagged(
    torch.randn(10, 10),
    torch.tensor([0, 1, 3, 6, 10]),
)

torch.compile(fn)(nt)
```

logs excerpt:
```
[0/0] [__graph_code] TRACED GRAPH
[0/0] [__graph_code]  ===== __compiled_fn_1 =====
[0/0] [__graph_code]  /data/users/dberard/pytorch/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.M

[0/0] [__graph_code]     def forward(self, L_x_: "f32[4, zf1, 10][10*zf1, 10, 1]cpu", zf1: "Sym(zf1)"):
[0/0] [__graph_code]         l_x_ = L_x_
[0/0] [__graph_code]
[0/0] [__graph_code]          # File: /data/users/dberard/scripts/nt_print_graph.py:4 in fn, code: return x.c

[0/0] [__graph_code]         cos: "f32[4, zf1, 10][10*zf1, 10, 1]cpu" = l_x_.cos();  l_x_ = None
[0/0] [__graph_code]         return (cos,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/132192
Approved by: https://github.com/Chillee
2024-08-03 02:54:32 +00:00
PyTorch MergeBot
3855ac5a5d Revert "[export] Add print_readable to unflattener (#128617)"
This reverts commit ab9791c0e3.

Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/angelayi due to never got landed internally due to weird flow... sorry ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2264224466))
2024-08-01 23:47:29 +00:00
Oguz Ulgen
920f0426ae Add None return type to init -- tests rest (#132376)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/132376
Approved by: https://github.com/jamesjwu
ghstack dependencies: #132335, #132351, #132352
2024-08-01 15:44:51 +00:00
angelayi
ab9791c0e3 [export] Add print_readable to unflattener (#128617)
Taking inspiration from `GraphModule.print_readable` (aka I copied its [code](17b45e905a/torch/fx/graph_module.py (L824))), I added a `print_readable` to the unflattened module, because it's kind of nontrivial to print the contents of this module.

Example print from `python test/export/test_unflatten.py -k test_unflatten_nested`
```
class UnflattenedModule(torch.nn.Module):
    def forward(self, x: "f32[2, 3]"):
        # No stacktrace found for following nodes
        rootparam: "f32[2, 3]" = self.rootparam

        # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:99 in forward, code: x = x * self.rootparam
        mul: "f32[2, 3]" = torch.ops.aten.mul.Tensor(x, rootparam);  x = rootparam = None

        # No stacktrace found for following nodes
        foo: "f32[2, 3]" = self.foo(mul);  mul = None
        bar: "f32[2, 3]" = self.bar(foo);  foo = None
        return (bar,)

    class foo(torch.nn.Module):
        def forward(self, mul: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child1param: "f32[2, 3]" = self.child1param
            nested: "f32[2, 3]" = self.nested(mul);  mul = None

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:79 in forward, code: return x + self.child1param
            add: "f32[2, 3]" = torch.ops.aten.add.Tensor(nested, child1param);  nested = child1param = None
            return add

        class nested(torch.nn.Module):
            def forward(self, mul: "f32[2, 3]"):
                # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:67 in forward, code: return x / x
                div: "f32[2, 3]" = torch.ops.aten.div.Tensor(mul, mul);  mul = None
                return div

    class bar(torch.nn.Module):
        def forward(self, add: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child2buffer: "f32[2, 3]" = self.child2buffer

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:87 in forward, code: return x - self.child2buffer
            sub: "f32[2, 3]" = torch.ops.aten.sub.Tensor(add, child2buffer);  add = child2buffer = None
            return sub
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128617
Approved by: https://github.com/zhxchen17, https://github.com/pianpwk
2024-07-30 00:41:44 +00:00
Xuehai Pan
918ece4f4d [BE][Easy][11/19] enforce style for empty lines in import segments in test/dy*/ (#129762)
See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by linter.

You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
2024-07-27 17:43:53 +00:00
PyTorch MergeBot
d6ae8bbf16 Revert "[export] Add print_readable to unflattener (#128617)"
This reverts commit 9fee87e4cd.

Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/clee2000 due to broke inductor/test_flex_attention https://github.com/pytorch/pytorch/actions/runs/9984688318/job/27595182606 433ef4e444 Not run on PR due to bad TD ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2236867975))
2024-07-18 15:31:51 +00:00
Michael Lazos
cf3f4285a8 Add recursive metadata guard test (#131002)
Ensures that nested tensors subclasses are guarded properly. It turns out this case is already handled [here](d77af49380/torch/_dynamo/variables/builder.py (L1496)) which will recursively wrap inner tensors adding metadata guards for them.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/131002
Approved by: https://github.com/bdhirsh
2024-07-18 08:24:43 +00:00
angelayi
9fee87e4cd [export] Add print_readable to unflattener (#128617)
Taking inspiration from `GraphModule.print_readable` (aka I copied its [code](17b45e905a/torch/fx/graph_module.py (L824))), I added a `print_readable` to the unflattened module, because it's kind of nontrivial to print the contents of this module.

Example print from `python test/export/test_unflatten.py -k test_unflatten_nested`
```
class UnflattenedModule(torch.nn.Module):
    def forward(self, x: "f32[2, 3]"):
        # No stacktrace found for following nodes
        rootparam: "f32[2, 3]" = self.rootparam

        # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:99 in forward, code: x = x * self.rootparam
        mul: "f32[2, 3]" = torch.ops.aten.mul.Tensor(x, rootparam);  x = rootparam = None

        # No stacktrace found for following nodes
        foo: "f32[2, 3]" = self.foo(mul);  mul = None
        bar: "f32[2, 3]" = self.bar(foo);  foo = None
        return (bar,)

    class foo(torch.nn.Module):
        def forward(self, mul: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child1param: "f32[2, 3]" = self.child1param
            nested: "f32[2, 3]" = self.nested(mul);  mul = None

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:79 in forward, code: return x + self.child1param
            add: "f32[2, 3]" = torch.ops.aten.add.Tensor(nested, child1param);  nested = child1param = None
            return add

        class nested(torch.nn.Module):
            def forward(self, mul: "f32[2, 3]"):
                # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:67 in forward, code: return x / x
                div: "f32[2, 3]" = torch.ops.aten.div.Tensor(mul, mul);  mul = None
                return div

    class bar(torch.nn.Module):
        def forward(self, add: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child2buffer: "f32[2, 3]" = self.child2buffer

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:87 in forward, code: return x - self.child2buffer
            sub: "f32[2, 3]" = torch.ops.aten.sub.Tensor(add, child2buffer);  add = child2buffer = None
            return sub
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128617
Approved by: https://github.com/zhxchen17, https://github.com/pianpwk
2024-07-18 01:36:01 +00:00
Michael Lazos
470f07c840 Add guard override capability for tensor subclass metadata (#130780)
Fixes https://github.com/pytorch/pytorch/issues/114405

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130780
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
ghstack dependencies: #130779
2024-07-17 19:13:53 +00:00
Michael Lazos
bea6762c01 Add guards on subclass metadata (#130779)
This PR adds guards in dynamo which verify the equality of tensor subclass metadata along with tests verifying the expected recompile behavior. The next PR adds the capability to override the guard behavior to possibly perform the check in a less expensive manner.

Toward fixing https://github.com/pytorch/pytorch/issues/114405

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130779
Approved by: https://github.com/anijain2305, https://github.com/bdhirsh
2024-07-17 19:13:52 +00:00
Michael Lazos
415d5e53ae Propagate buffer and parameter indices through AOT (#130393)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130393
Approved by: https://github.com/bdhirsh
2024-07-16 22:12:38 +00:00
PyTorch MergeBot
cbda8be537 Revert "Propagate buffer and parameter indices through AOT (#130393)"
This reverts commit 69a77389e2.

Reverted https://github.com/pytorch/pytorch/pull/130393 on behalf of https://github.com/clee2000 due to broke lint for torch/_functorch/_aot_autograd/subclass_utils.py https://github.com/pytorch/pytorch/actions/runs/9948630877/job/27483551649 80236dca90 lint was green on PR, probably a landrace ([comment](https://github.com/pytorch/pytorch/pull/130393#issuecomment-2231263753))
2024-07-16 15:43:34 +00:00
Michael Lazos
69a77389e2 Propagate buffer and parameter indices through AOT (#130393)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130393
Approved by: https://github.com/bdhirsh
ghstack dependencies: #130391, #130392, #130503
2024-07-16 00:25:38 +00:00
PyTorch MergeBot
dff9d68f18 Revert "Fix names conflict when lifting (#129817)"
This reverts commit 53cf46b8c6.

Reverted https://github.com/pytorch/pytorch/pull/129817 on behalf of https://github.com/clee2000 due to Failing inductor/test_flex_attention.py https://github.com/pytorch/pytorch/actions/runs/9940532858/job/27478084137 74da2a467f Sorry for the churn, possibly a landrace? ([comment](https://github.com/pytorch/pytorch/pull/129817#issuecomment-2229519886))
2024-07-15 22:08:45 +00:00
Zhanghan Wang
53cf46b8c6 Fix names conflict when lifting (#129817)
## Bug description
When pending args that are potentially to be lift [here](58f346c874/torch/_dynamo/output_graph.py (L1866)) having same base name, like `contiguous` and `contiguous_1`, the call into [create_graph_input](58f346c874/torch/_dynamo/output_graph.py (L2081)) can finally create a name ([here](58f346c874/torch/fx/graph.py (L1008))) that overwrite args to lift. And thus causing a wrong output of graph.

## Reproducing
Below is an reproduceable example,
```python
import logging
from typing import List

import torch
from functorch.compile import aot_module_simplified, make_boxed_func

@torch.library.custom_op("mylib::somefunc_forward", mutates_args=())
def somefunc_forward(
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    return torch.ones_like(input_)

@somefunc_forward.register_fake
def _(input_, shape, weight):
    return torch.empty_like(input_)

@torch.library.custom_op("mylib::somefunc_backward", mutates_args=())
def somefunc_backward(
    grad_output: torch.Tensor,
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    print(f"backward.{grad_output.shape=}")
    print(f"backward.{input_.shape=}")
    print(f"backward.{weight.shape=}")
    print(f"backward.{shape=}")
    assert list(weight.shape) == shape
    return torch.ones_like(weight)

@somefunc_backward.register_fake
def _(grad_output, input_, weight, shape):
    return torch.empty_like(weight)

def a_func(grad_output, input_, weight_, shape):
    return torch.ones_like(input_.sum() * weight_)

class SomeFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape):
        ctx.normalized_shape = normalized_shape
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output = somefunc_forward(input_, weight_, ctx.normalized_shape)
        ctx.save_for_backward(input_, weight_)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_ = ctx.saved_tensors
        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
        grad_weight = somefunc_backward(
            grad_output.contiguous(),
            input_,
            weight_,
            ctx.normalized_shape,
        )
        return None, grad_weight, None

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(7))

    def forward(self, x):
        return SomeFunc.apply(x, self.weight, [7])

model = MyModel()
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, graph_code=True)

def aot_print_backend(gm, sample_inputs):
    # Forward compiler capture
    def fw(gm, sample_inputs):
        print(f"----- fw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Backward compiler capture
    def bw(gm, sample_inputs):
        print(f"----- bw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Call AOTAutograd
    gm_forward = aot_module_simplified(
        gm, sample_inputs, fw_compiler=fw, bw_compiler=bw
    )
    return gm_forward

model = torch.compile(
    model,
    backend=aot_print_backend,
    dynamic=False,
)
out = model(torch.rand((128, 4, 7)))
out.mean().backward()
```

I can see log that showing calling into create_graph_input like
```log
V0629 02:08:46.839914 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous (none)
V0629 02:08:46.839998 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous_1 (none)
```

And the backward graph generate will be like
```log
class GraphModule(torch.nn.Module):
    def forward(self, function_ctx, somefunc_forward_default: "f32[128, 4, 7]", contiguous: "f32[128, 4, 7]", contiguous_1: "f32[7]"):
        contiguous_1 = contiguous
        contiguous_2 = contiguous_1

        # No stacktrace found for following nodes
        _set_grad_enabled = torch._C._set_grad_enabled(False)

         # File: /Users/bytedance/testtorch/test_custom_op_bug.py:61 in backward, code: grad_output.contiguous(),
        contiguous: "f32[128, 4, 7]" = somefunc_forward_default.contiguous();  somefunc_forward_default = None

         # File: /opt/tiger/pytorch/torch/_library/custom_ops.py:506 in __call__, code: return self._opoverload(*args, **kwargs)
        somefunc_backward_default: "f32[7]" = torch.ops.mylib.somefunc_backward.default(contiguous, contiguous_1, contiguous_2, [7]);  contiguous = contiguous_1 = contiguous_2 = None

        # No stacktrace found for following nodes
        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
        return (None, somefunc_backward_default)
```

The original code of `somefunc_backward` takes a input list of `grad_output`, `input_`, `weight` and `shape`, where `weight` should be shape of `torch.Size([7])`. However, in the graph, `contiguous1` and `contiguous_2` are assigned with `contiguous`, this leads to assertion failure I added in `somefunc_backward`.

## Environment
```log
Collecting environment information...
PyTorch version: 2.5.0a0+git0b7e8df
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.9.19 (main, May  6 2024, 14:39:30)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] numpy==2.0.0
[pip3] optree==0.11.0
[pip3] torch==2.5.0a0+git0b7e8df
[pip3] torchgraph==0.0.1
[conda] numpy                     2.0.0                    pypi_0    pypi
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] torch                     2.5.0a0+git0b7e8df           dev_0    <develop>
[conda] torchgraph                0.0.1                     dev_0    <develop>
```

## How to fix?

I put a naive fix that add the potential args to lift into the used_names. This visits private variables, will fix that if this issue makes sense to you.

@zou3519 @oulgen

Co-authored-by: rzou <zou3519@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129817
Approved by: https://github.com/zou3519
2024-07-15 18:49:12 +00:00
PyTorch MergeBot
1e897a0ca4 Revert "Fix names conflict when lifting (#129817)"
This reverts commit 74da2a467f.

Reverted https://github.com/pytorch/pytorch/pull/129817 on behalf of https://github.com/clee2000 due to broke dynamo/test_inline_inbuilt_nn_modules.py https://github.com/pytorch/pytorch/actions/runs/9940532858/job/27461141919 74da2a467f.  Test passed on PR, possibly a landrace? ([comment](https://github.com/pytorch/pytorch/pull/129817#issuecomment-2228993570))
2024-07-15 17:09:52 +00:00
Zhanghan Wang
74da2a467f Fix names conflict when lifting (#129817)
## Bug description
When pending args that are potentially to be lift [here](58f346c874/torch/_dynamo/output_graph.py (L1866)) having same base name, like `contiguous` and `contiguous_1`, the call into [create_graph_input](58f346c874/torch/_dynamo/output_graph.py (L2081)) can finally create a name ([here](58f346c874/torch/fx/graph.py (L1008))) that overwrite args to lift. And thus causing a wrong output of graph.

## Reproducing
Below is an reproduceable example,
```python
import logging
from typing import List

import torch
from functorch.compile import aot_module_simplified, make_boxed_func

@torch.library.custom_op("mylib::somefunc_forward", mutates_args=())
def somefunc_forward(
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    return torch.ones_like(input_)

@somefunc_forward.register_fake
def _(input_, shape, weight):
    return torch.empty_like(input_)

@torch.library.custom_op("mylib::somefunc_backward", mutates_args=())
def somefunc_backward(
    grad_output: torch.Tensor,
    input_: torch.Tensor,
    weight: torch.Tensor,
    shape: List[int],
) -> torch.Tensor:
    print(f"backward.{grad_output.shape=}")
    print(f"backward.{input_.shape=}")
    print(f"backward.{weight.shape=}")
    print(f"backward.{shape=}")
    assert list(weight.shape) == shape
    return torch.ones_like(weight)

@somefunc_backward.register_fake
def _(grad_output, input_, weight, shape):
    return torch.empty_like(weight)

def a_func(grad_output, input_, weight_, shape):
    return torch.ones_like(input_.sum() * weight_)

class SomeFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, weight, normalized_shape):
        ctx.normalized_shape = normalized_shape
        input_ = input.contiguous()
        weight_ = weight.contiguous()
        output = somefunc_forward(input_, weight_, ctx.normalized_shape)
        ctx.save_for_backward(input_, weight_)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        input_, weight_ = ctx.saved_tensors
        # grad_weight = a_func(grad_output, input_, weight_, ctx.normalized_shape)
        grad_weight = somefunc_backward(
            grad_output.contiguous(),
            input_,
            weight_,
            ctx.normalized_shape,
        )
        return None, grad_weight, None

class MyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(7))

    def forward(self, x):
        return SomeFunc.apply(x, self.weight, [7])

model = MyModel()
torch._logging.set_logs(dynamo=logging.DEBUG, aot=logging.DEBUG, graph_code=True)

def aot_print_backend(gm, sample_inputs):
    # Forward compiler capture
    def fw(gm, sample_inputs):
        print(f"----- fw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Backward compiler capture
    def bw(gm, sample_inputs):
        print(f"----- bw")
        gm.print_readable()
        return make_boxed_func(gm.forward)

    # Call AOTAutograd
    gm_forward = aot_module_simplified(
        gm, sample_inputs, fw_compiler=fw, bw_compiler=bw
    )
    return gm_forward

model = torch.compile(
    model,
    backend=aot_print_backend,
    dynamic=False,
)
out = model(torch.rand((128, 4, 7)))
out.mean().backward()
```

I can see log that showing calling into create_graph_input like
```log
V0629 02:08:46.839914 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous (none)
V0629 02:08:46.839998 8200981504 torch/_dynamo/output_graph.py:2042] [0/0] create_graph_input contiguous_1 (none)
```

And the backward graph generate will be like
```log
class GraphModule(torch.nn.Module):
    def forward(self, function_ctx, somefunc_forward_default: "f32[128, 4, 7]", contiguous: "f32[128, 4, 7]", contiguous_1: "f32[7]"):
        contiguous_1 = contiguous
        contiguous_2 = contiguous_1

        # No stacktrace found for following nodes
        _set_grad_enabled = torch._C._set_grad_enabled(False)

         # File: /Users/bytedance/testtorch/test_custom_op_bug.py:61 in backward, code: grad_output.contiguous(),
        contiguous: "f32[128, 4, 7]" = somefunc_forward_default.contiguous();  somefunc_forward_default = None

         # File: /opt/tiger/pytorch/torch/_library/custom_ops.py:506 in __call__, code: return self._opoverload(*args, **kwargs)
        somefunc_backward_default: "f32[7]" = torch.ops.mylib.somefunc_backward.default(contiguous, contiguous_1, contiguous_2, [7]);  contiguous = contiguous_1 = contiguous_2 = None

        # No stacktrace found for following nodes
        _set_grad_enabled_1 = torch._C._set_grad_enabled(True)
        return (None, somefunc_backward_default)
```

The original code of `somefunc_backward` takes a input list of `grad_output`, `input_`, `weight` and `shape`, where `weight` should be shape of `torch.Size([7])`. However, in the graph, `contiguous1` and `contiguous_2` are assigned with `contiguous`, this leads to assertion failure I added in `somefunc_backward`.

## Environment
```log
Collecting environment information...
PyTorch version: 2.5.0a0+git0b7e8df
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.5 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.26.4
Libc version: N/A

Python version: 3.9.19 (main, May  6 2024, 14:39:30)  [Clang 14.0.6 ] (64-bit runtime)
Python platform: macOS-14.5-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M3 Pro

Versions of relevant libraries:
[pip3] numpy==2.0.0
[pip3] optree==0.11.0
[pip3] torch==2.5.0a0+git0b7e8df
[pip3] torchgraph==0.0.1
[conda] numpy                     2.0.0                    pypi_0    pypi
[conda] optree                    0.11.0                   pypi_0    pypi
[conda] torch                     2.5.0a0+git0b7e8df           dev_0    <develop>
[conda] torchgraph                0.0.1                     dev_0    <develop>
```

## How to fix?

I put a naive fix that add the potential args to lift into the used_names. This visits private variables, will fix that if this issue makes sense to you.

@zou3519 @oulgen

Co-authored-by: rzou <zou3519@gmail.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129817
Approved by: https://github.com/zou3519
2024-07-15 13:41:46 +00:00
Animesh Jain
7ea8a3c9b8 [dynamo] Validate check_fn (#118448)
Fixes - https://github.com/pytorch/pytorch/issues/128090

Tracker issue here - https://github.com/pytorch/pytorch/issues/129937

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118448
Approved by: https://github.com/jansel, https://github.com/ezyang
2024-07-05 18:04:12 +00:00
Joel Schlosser
6897631ceb Guard on inner tensor names for traceable wrapper subclasses (#129618)
Fixes #129601

Background: it's possible that a traceable wrapper subclass will have an optional inner tensor constituent (e.g. NJT's cached min / max sequence lengths). To specify this, the subclass's `__tensor_flatten__()` impl should leave out any unspecified optional inner tensors in the returned list of `attrs`.

This PR guards on the list of inner tensor `attrs` returned in `subclass.__tensor_flatten__()[0]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129618
Approved by: https://github.com/anijain2305
2024-06-28 16:30:25 +00:00
Chien-Lin Chen
5e7ac69a67 [Dynamic Shapes] fixed dynamic shape inference (#128807)
Made dynamic dimension indirectly bound to an integer constrained.
After each ShapeEnv._refine_ranges, check if the new ValueRange is singleton, if it is, replace the symbol.

Fixes #122307

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128807
Approved by: https://github.com/ezyang
2024-06-27 22:33:32 +00:00
Brian Hirsh
880e894c39 [Brian's PR #128981] fix dynamo isinstance inlining for nn.Parameter + subclasses (#129162)
This is a copy of Brian's PR https://github.com/pytorch/pytorch/pull/128981, with very small changes to work around numpy related errors.

For discussions, please see Brian's original PR.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129162
Approved by: https://github.com/bdhirsh
2024-06-21 03:48:10 +00:00
Joel Schlosser
31d5753247 Short-term fix to preserve NJT metadata cache in torch.compile (#122836)
Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836
Approved by: https://github.com/soulitzer
2024-06-20 23:15:53 +00:00
rzou
7178b4e987 [Dynamo x torch_function] fix incorrect source (#128980)
Fixes https://github.com/pytorch/pytorch/issues/128964

The problem was that we were installing the source for a type
incorrectly.

Test Plan:
- new tests

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128980
Approved by: https://github.com/mlazos
2024-06-20 14:54:00 +00:00
PyTorch MergeBot
b0d2fe6299 Revert "Short-term fix to preserve NJT metadata cache in torch.compile (#122836)"
This reverts commit 2a41fc0390.

Reverted https://github.com/pytorch/pytorch/pull/122836 on behalf of https://github.com/jbschlosser due to internal test failures with DEBUG=1 asserts ([comment](https://github.com/pytorch/pytorch/pull/122836#issuecomment-2177298245))
2024-06-19 00:28:53 +00:00
Joel Schlosser
2a41fc0390 Short-term fix to preserve NJT metadata cache in torch.compile (#122836)
Idea: close over min / max sequence length in the main NJT view func (`_nested_view_from_jagged`) so that view replay during fake-ification propagates these correctly in torch.compile.

For dynamic shapes support for min / max sequence length, this PR uses a hack that stores the values in `(val, 0)` shaped tensors.

**NB: This PR changes SDPA to operate on real views instead of using `buffer_from_jagged()` / `ViewNestedFromBuffer`, which may impact the internal FIRST model. That is, it undoes the partial revert from #123215 alongside a fix to the problem that required the partial revert. We need to verify that there are no regressions there before landing.**

Differential Revision: [D55448636](https://our.internmc.facebook.com/intern/diff/D55448636)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/122836
Approved by: https://github.com/soulitzer
ghstack dependencies: #127007, #128057
2024-06-17 15:25:09 +00:00
David Berard
f0d68120f4 [subclasses] Handle dynamo inputs that are subclass views with (-1) in the view (#128662)
When handling an input to dynamo that's a view of a subclass, dynamo does some handling to reconstruct the view. Part of this is to construct symints for the input parameters to the view.

Previously, the code would just call `create_symbol()` which by default specifies a _positive_ symint (>= 0); this fails in the case where you have an aten::view that was called with a -1.

Fix: just specify `positive=None` when calling `create_symbol()`, to avoid restricting the symint to >= 0 or <= 0.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128662
Approved by: https://github.com/jbschlosser
2024-06-15 14:58:18 +00:00
PyTorch MergeBot
5efe71f134 Revert "[export] Add print_readable to unflattener (#128617)"
This reverts commit 5d9a609b4f.

Reverted https://github.com/pytorch/pytorch/pull/128617 on behalf of https://github.com/huydhn due to Sorry for reverting your change but another failed test shows up in trunk inductor/test_flex_attention.py where it needs to be updated 5d9a609b4f.  I guess it is easier to revert and reland this ([comment](https://github.com/pytorch/pytorch/pull/128617#issuecomment-2169030779))
2024-06-15 01:46:23 +00:00
angelayi
5d9a609b4f [export] Add print_readable to unflattener (#128617)
Taking inspiration from `GraphModule.print_readable` (aka I copied its [code](17b45e905a/torch/fx/graph_module.py (L824))), I added a `print_readable` to the unflattened module, because it's kind of nontrivial to print the contents of this module.

Example print from `python test/export/test_unflatten.py -k test_unflatten_nested`
```
class UnflattenedModule(torch.nn.Module):
    def forward(self, x: "f32[2, 3]"):
        # No stacktrace found for following nodes
        rootparam: "f32[2, 3]" = self.rootparam

        # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:99 in forward, code: x = x * self.rootparam
        mul: "f32[2, 3]" = torch.ops.aten.mul.Tensor(x, rootparam);  x = rootparam = None

        # No stacktrace found for following nodes
        foo: "f32[2, 3]" = self.foo(mul);  mul = None
        bar: "f32[2, 3]" = self.bar(foo);  foo = None
        return (bar,)

    class foo(torch.nn.Module):
        def forward(self, mul: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child1param: "f32[2, 3]" = self.child1param
            nested: "f32[2, 3]" = self.nested(mul);  mul = None

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:79 in forward, code: return x + self.child1param
            add: "f32[2, 3]" = torch.ops.aten.add.Tensor(nested, child1param);  nested = child1param = None
            return add

        class nested(torch.nn.Module):
            def forward(self, mul: "f32[2, 3]"):
                # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:67 in forward, code: return x / x
                div: "f32[2, 3]" = torch.ops.aten.div.Tensor(mul, mul);  mul = None
                return div

    class bar(torch.nn.Module):
        def forward(self, add: "f32[2, 3]"):
            # No stacktrace found for following nodes
            child2buffer: "f32[2, 3]" = self.child2buffer

            # File: /data/users/angelayi/pytorch2/test/export/test_unflatten.py:87 in forward, code: return x - self.child2buffer
            sub: "f32[2, 3]" = torch.ops.aten.sub.Tensor(add, child2buffer);  add = child2buffer = None
            return sub
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128617
Approved by: https://github.com/zhxchen17, https://github.com/pianpwk
2024-06-15 00:26:04 +00:00
David Berard
f48f7615dc [easy][subclasses] dynamo.reset() in test_subclass_views (#128659)
When we don't dynamo.reset(), we don't recompile on different dynamic shapes.

Also, some of the returned views were tuples - so when we `* 2`, we actually just copy all the inputs twice in the tuple. I changed it so that it would just return one of the values from the return tuple.

Additionally, this exposes a bug that fails with the slice operation, so I skipped it when we're testing with dynamic shapes:

```
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3996, in produce_guards
    sexpr = ShapeGuardPrinter(symbol_to_source, source_ref, self.var_to_sources).doprint(expr)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 292, in doprint
    return self._str(self._print(expr))
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 56, in _print_Add
    t = self._print(term)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in _print_Mul
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 366, in <listcomp>
    a_str = [self.parenthesize(x, prec, strict=False) for x in a]
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/str.py", line 37, in parenthesize
    return self._print(item)
  File "/home/dberard/local/miniconda3/envs/pytorch/lib/python3.10/site-packages/sympy/printing/printer.py", line 331, in _print
    return printmethod(expr, **kwargs)
  File "/home/dberard/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1494, in _print_Symbol
    assert self.symbol_to_source.get(expr), (
AssertionError: s3 (could be from ['<ephemeral: symint_visitor_fn>', '<ephemeral: symint_visitor_fn>']) not in {s0: ["L['x'].a.size()[1]", "L['x'].b.size()[1]", "L['x'].size()[1]", "L['x'].a.size()[1]", "L['x'].b.size()[1]", "L['x'].a.size()[1]", "L['x'].b.size()[1]"], s1: ["L['x'].a.stride()[0]", "L['x'].b.stride()[0]", "L['x'].stride()[0]", "L['x'].a.stride()[0]", "L['x'].b.stride()[0]", "L['x'].a.stride()[0]", "L['x'].b.stride()[0]"], s2: ["L['x'].a.storage_offset()", "L['x'].b.storage_offset()", "L['x'].a.storage_offset()", "L['x'].b.storage_offset()"]}.  If this assert is failing, it could be due to the issue described in https://github.com/pytorch/pytorch/pull/90665
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/128659
Approved by: https://github.com/YuqingJ
2024-06-14 05:18:07 +00:00
Michael Lazos
2129903aa3 Properly detect nested torch function args (#127496)
Dynamo was not detecting nested torch function classes in containers. This was due to pytree compatibility for variable trackers being removed.
Fixes https://github.com/pytorch/pytorch/issues/127174

Pull Request resolved: https://github.com/pytorch/pytorch/pull/127496
Approved by: https://github.com/anijain2305
2024-06-02 03:43:22 +00:00
Edward Z. Yang
9b91c91e64 Don't add to replacements when guard is suppressed (#126210)
Also improve logging when guards are suppressed

Partially addresses https://github.com/pytorch/pytorch/issues/125641

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/126210
Approved by: https://github.com/jbschlosser
2024-05-23 20:10:29 +00:00
Edward Z. Yang
ecd62746e3 Also pull size/stride info from example_value (#125505)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125505
Approved by: https://github.com/jansel
2024-05-05 22:27:46 +00:00
Edward Z. Yang
c511aed27f [Meta Tensor] fix meta inplace set storage (#123880)
Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
2024-05-01 06:53:49 +00:00
PyTorch MergeBot
ae13c7e593 Revert "[Meta Tensor] fix meta inplace set storage (#123880)"
This reverts commit cccae93551.

Reverted https://github.com/pytorch/pytorch/pull/123880 on behalf of https://github.com/izaitsevfb due to breaks cpu_inductor_torchbench (detectron2_fasterrcnn) ([comment](https://github.com/pytorch/pytorch/pull/123880#issuecomment-2083366385))
2024-04-29 18:19:42 +00:00
Yuanjing Shi
cccae93551 [Meta Tensor] fix meta inplace set storage (#123880)
Fixes #123879

Co-authored-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123880
Approved by: https://github.com/ezyang
2024-04-28 17:01:12 +00:00
soulitzer
cf5ca58e7f [NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124343
Approved by: https://github.com/jbschlosser
2024-04-19 23:13:59 +00:00
PyTorch MergeBot
4a0900d04b Revert "[NJT] Inline through torch.nested.nested_tensor_from_jagged instead of graph break (#124343)"
This reverts commit ef93402f61.

Reverted https://github.com/pytorch/pytorch/pull/124343 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/124343#issuecomment-2064937192))
2024-04-18 18:55:48 +00:00