Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Fixes #123698

This PR makes `TensorImpl::has_symbolic_sizes_strides` return true for NestedTensors.

1. It passes the actual sizes when calling `_make_wrapper_subclass`. This is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`.
2. It adds a field to `_make_wrapper_subclass` through which an explicit `numel` can be provided. This lets us skip the numel computation for the storage, which previously failed because it required arithmetic on NestedInts.
3. It implements `aten::numel` for NJT. For now, this is separate from the overridden numel in `_make_wrapper_subclass`.

This also means we keep `dispatch_sizes_strides_policy="sizes"`, so that calls go to the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT.

This PR depends on #121361: `TensorImpl::set_sizes_and_strides()` tries to clone the sizes into the tensor, so `clone` must be implemented on NestedInt.

Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687
Approved by: https://github.com/albanD
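The following sketch is plain Python and does not use PyTorch. It illustrates why an explicit `numel` is needed once a size is symbolic: taking the product of the sizes breaks as soon as one entry is a ragged NestedInt, while an explicitly supplied element count sidesteps that arithmetic. `NestedIntSketch`, `WrapperMetaSketch`, and `explicit_numel` are hypothetical stand-ins, not PyTorch APIs. The sketch also assumes that an NJT's element count equals the number of elements in its packed values buffer (the sum of the ragged lengths times the inner dimensions).

```python
from dataclasses import dataclass
from math import prod
from typing import Optional, Sequence, Union


class NestedIntSketch:
    """Hypothetical stand-in for a NestedInt: a symbolic size for the
    ragged dimension that refuses plain arithmetic with ints."""

    def __init__(self, name: str) -> None:
        self.name = name

    def __mul__(self, other):
        raise TypeError(f"cannot multiply symbolic size {self.name}")

    __rmul__ = __mul__

    def __repr__(self) -> str:
        return self.name


Size = Union[int, NestedIntSketch]


@dataclass
class WrapperMetaSketch:
    """Hypothetical tensor metadata: sizes plus an optional explicit numel."""

    sizes: Sequence[Size]
    explicit_numel: Optional[int] = None

    def numel(self) -> int:
        # Use the explicitly supplied count when present; otherwise fall
        # back to the product of sizes, which raises for symbolic sizes.
        if self.explicit_numel is not None:
            return self.explicit_numel
        return prod(self.sizes)

    def has_symbolic_sizes_strides(self) -> bool:
        # Sizes register as symbolic if any entry is not a concrete int.
        return any(isinstance(s, NestedIntSketch) for s in self.sizes)


# A jagged batch of 3 sequences with lengths 2, 5, 1 and an inner dim of 4.
lengths = [2, 5, 1]
j0 = NestedIntSketch("j0")
meta = WrapperMetaSketch(sizes=(len(lengths), j0, 4),
                         explicit_numel=sum(lengths) * 4)

# Without an explicit numel, the product-of-sizes fallback hits j0.
try:
    WrapperMetaSketch(sizes=(len(lengths), j0, 4)).numel()
    fallback_failed = False
except TypeError:
    fallback_failed = True

# A regular dense shape still works through the product of sizes.
dense = WrapperMetaSketch(sizes=(2, 3))
```

Here `meta.numel()` gives 32 and `meta.has_symbolic_sizes_strides()` is true, while the dense metadata reports 6 elements and no symbolic sizes. Keeping the count separate from the sizes mirrors the reason the PR retains a custom `numel` instead of deriving it from `sizes`.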
| Name |
|---|
| benchmark |
| core |
| cuda |
| hip |
| macros |
| mobile |
| test |
| util |
| xpu |
| BUCK.oss |
| BUILD.bazel |
| build.bzl |
| CMakeLists.txt |
| ovrsource_defs.bzl |