pytorch/torch/nested
David Berard d548417d95 [NJT] throw an exception if nested_tensor_from_jagged is fx-traced without being fx.wrapped (#130702)
The NJT constructor can't be safely fx-traced because of the dummy NT it uses:

774ca93fd2/torch/nested/_internal/nested_tensor.py (L501-L508)

The error doesn't surface immediately. It only appears when you try to move a module containing an fx-traced NJT constructor to a different device, or try to serialize it. Let's throw an error when the NJT constructor is fx-traced, so users know to wrap the call with fx.wrap.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/130702
Approved by: https://github.com/jbschlosser, https://github.com/soulitzer
2024-07-16 19:21:10 +00:00
_internal [Nested Tensor][easy] Add softmax backward support (#130602) 2024-07-16 00:07:42 +00:00
__init__.py [NJT] throw an exception if nested_tensor_from_jagged is fx-traced without being fx.wrapped (#130702) 2024-07-16 19:21:10 +00:00