pytorch/torch/fx
Wei-Sheng Chin 86b7aa26f0 Fix FakeTensorProp on Module with Parameters or Buffers (#88700)
In `FakeTensorMode.__torch_dispatch__`, the output is not always computed by a meta kernel:
```python
        try:
            with in_kernel_invocation_manager(self):
                r = func(*args, **kwargs)  # <----- "r" can be a real tensor.
        except NotImplementedError as not_implemented_error:
            # no meta kernel registered, fallback to kernel for the device
            if not self.allow_fallback_kernels:
                raise not_implemented_error
            return run_fallback_kernel(self, func, args, kwargs, not_implemented_error)

        return self.wrap_meta_outputs_with_default_device_logic(r, func, args, kwargs)
```
For example, I observed that a CPU tensor is generated when executing `aten.addmm` while running `FakeTensorProp`. Therefore, I'd like to allow `FakeTensorMode` to wrap a real tensor as a `FakeTensor` during the computation. Does this PR look like a good direction for fixing this problem? If yes, I can go ahead and add some tests.
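The control flow described above can be sketched in plain Python (a toy model, not the real PyTorch implementation — `FakeTensor`, `meta_kernel`, `real_kernel`, and `dispatch` here are hypothetical stand-ins): try the meta kernel first, fall back to the real-device kernel on `NotImplementedError`, and wrap any real output as a fake tensor so downstream propagation never sees a real one.

```python
class FakeTensor:
    """Hypothetical stand-in for torch._subclasses.FakeTensor."""
    def __init__(self, value):
        self.value = value

    def __repr__(self):
        return f"FakeTensor({self.value!r})"


def meta_kernel(x):
    # Simulate an op with no meta kernel registered (e.g. the aten.addmm
    # case observed above).
    raise NotImplementedError("no meta kernel for this op")


def real_kernel(x):
    # Simulate the real-device (e.g. CPU) kernel; returns a real value.
    return x * 2


def dispatch(x, allow_fallback_kernels=True):
    try:
        r = meta_kernel(x)
    except NotImplementedError:
        if not allow_fallback_kernels:
            raise
        r = real_kernel(x)  # "r" is a real tensor here
    # The gist of the fix: wrap a real output as FakeTensor instead of
    # letting it leak out of the fake-tensor mode.
    return r if isinstance(r, FakeTensor) else FakeTensor(r)


print(dispatch(3))  # FakeTensor(6)
```

The key invariant is the last line of `dispatch`: regardless of which kernel produced the output, the caller always receives a `FakeTensor`.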
Pull Request resolved: https://github.com/pytorch/pytorch/pull/88700
Approved by: https://github.com/eellison, https://github.com/ezyang
2022-11-11 03:49:29 +00:00
experimental Symbolic shape: sym_floor , sym_sqrt, sym_int (#88760) 2022-11-10 23:41:33 +00:00
passes Fix FakeTensorProp on Module with Parameters or Buffers (#88700) 2022-11-11 03:49:29 +00:00
__init__.py Refactor FX codegen into extensible Codegen object (#72566) 2022-02-11 18:13:29 +00:00
__init__.pyi
_compatibility.py
_pytree.py
_symbolic_trace.py [torch.fx.wrap] Use callable / function.__name__ instead of function.__code__.co_name (#84373) 2022-09-09 05:44:29 +00:00
annotate.py
graph_module.py [fx] Fix GraphModule.print_readable() (#88730) 2022-11-09 21:39:48 +00:00
graph.py Handle case when candidate is empty (#88359) 2022-11-05 17:19:40 +00:00
immutable_collections.py Add __all__ to torch.{fx, distributed, backends} submodules (#85079) 2022-09-20 12:51:08 +00:00
interpreter.py prepare removal of deprecated functionality in torch.testing (#87969) 2022-11-02 14:04:48 +00:00
node.py propagate .meta info when replacing subgraphs in fx (#87255) 2022-11-02 14:36:46 +00:00
operator_schemas.py Add __all__ to torch.{autograd, fx, cuda} submodules (#85343) 2022-10-09 14:46:54 +00:00
OVERVIEW.md prepare removal of deprecated functionality in torch.testing (#87969) 2022-11-02 14:04:48 +00:00
proxy.py better error message fix (#86422) 2022-10-08 00:06:05 +00:00
subgraph_rewriter.py PatternMatcher supports matching list-typed args (#88656) 2022-11-08 21:05:18 +00:00
tensor_type.py Improve getitem syntax for TensorType (#84555) 2022-09-06 18:36:24 +00:00
traceback.py Preserve stack trace for backward nodes over AOTAutograd (#83558) 2022-08-18 22:13:04 +00:00