Fixes #123698
This PR makes `TensorImpl::has_symbolic_sizes_strides` return true for NestedTensors:
1. It passes in the actual sizes when we call `_make_wrapper_subclass` - this is the change that makes the subclass register as `has_symbolic_sizes_strides() == True`.
2. It adds a field to `_make_wrapper_subclass` through which an explicit `numel` can be provided. This allows us to skip the numel computation for the storage, which previously failed due to arithmetic on NestedInts.
3. It implements `aten::numel` for NJT - this is separate from the overridden `numel` in `_make_wrapper_subclass` for now. Note also that this means we keep `dispatch_sizes_strides_policy="sizes"`, so that we call into the custom `numel` implementation (as well as `sizes` and `strides`), because `numel` cannot currently be computed from `sizes` for NJT.
Note also that this depends on #121361, because calling `TensorImpl::set_sizes_and_strides()` tries to clone the sizes into the tensor, which means that we need `clone` to be implemented on NestedInt.
Differential Revision: [D57225736](https://our.internmc.facebook.com/intern/diff/D57225736)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124687
Approved by: https://github.com/albanD
Fix part of https://github.com/pytorch/pytorch/issues/123603.
Example traceback on branch https://github.com/pytorch/vision/compare/main...wwen/custom_ops_test:
```
running my_custom_op!
Traceback (most recent call last):
File "/data/users/williamwen/torchvision/playground.py", line 13, in <module>
print(opt_fn1(torch.randn(3, 3)))
File "/data/users/williamwen/pytorch2/torch/_dynamo/eval_frame.py", line 387, in _fn
return fn(*args, **kwargs)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 977, in catch_errors
return callback(frame, cache_entry, hooks, frame_state, skip=1)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 818, in _convert_frame
result = inner_convert(
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 411, in _convert_frame_assert
return _compile(
File "/data/users/williamwen/pytorch2/torch/_utils_internal.py", line 70, in wrapper_function
return function(*args, **kwargs)
File "/data/users/williamwen/py310-env/lib/python3.10/contextlib.py", line 79, in inner
return func(*args, **kwds)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 700, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 266, in time_wrapper
r = func(*args, **kwargs)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 568, in compile_inner
out_code = transform_code_object(code, transform)
File "/data/users/williamwen/pytorch2/torch/_dynamo/bytecode_transformation.py", line 1116, in transform_code_object
transformations(instructions, code_options)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 173, in _fn
return fn(*args, **kwargs)
File "/data/users/williamwen/pytorch2/torch/_dynamo/convert_frame.py", line 515, in transform
tracer.run()
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 2237, in run
super().run()
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 875, in run
while self.step():
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 790, in step
self.dispatch_table[inst.opcode](self, inst)
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 492, in wrapper
return inner_fn(self, inst)
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 1260, in CALL_FUNCTION
self.call_function(fn, args, {})
File "/data/users/williamwen/pytorch2/torch/_dynamo/symbolic_convert.py", line 730, in call_function
self.push(fn.call_function(self, args, kwargs))
File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/torch.py", line 747, in call_function
tensor_variable = wrap_fx_proxy(
File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builder.py", line 1425, in wrap_fx_proxy
return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
File "/data/users/williamwen/pytorch2/torch/_dynamo/variables/builder.py", line 1510, in wrap_fx_proxy_cls
example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1804, in get_fake_value
raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1736, in get_fake_value
ret_val = wrap_fake_exception(
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1251, in wrap_fake_exception
return fn()
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1737, in <lambda>
lambda: run_node(tx.output, node, args, kwargs, nnmodule)
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1872, in run_node
raise RuntimeError(make_error_message(e)).with_traceback(
File "/data/users/williamwen/pytorch2/torch/_dynamo/utils.py", line 1854, in run_node
return node.target(*args, **kwargs)
File "/data/users/williamwen/pytorch2/torch/_ops.py", line 870, in __call__
return self_._op(*args, **(kwargs or {}))
torch._dynamo.exc.TorchRuntimeError: Failed running call_function torchvision.my_custom_op1(*(FakeTensor(..., size=(3, 3)),), **{}):
The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
If you're using torch.compile/export/fx, it is likely that we are erroneously tracing into a custom kernel. To fix this, please wrap the custom kernel into an opaque custom op. Please see the following for details: https://docs.google.com/document/d/1W--T6wz8IY8fOI0Vm8BF44PdBgs283QvpelJZWieQWQ
from user code:
File "/data/users/williamwen/torchvision/playground.py", line 5, in fn1
return torch.ops.torchvision.my_custom_op1(x)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/124240
Approved by: https://github.com/zou3519
Previously, we'd just check `has_symbolic_sizes_strides()` to know whether a tensor has symbolic sizes or strides; if it does, we skip some profiler logic. But sometimes `has_symbolic_sizes_strides()` returns false even though the tensor does actually have symbolic sizes or strides.
So in this change, we add `may_have_symbolic_sizes_strides()`, which should never return false if the tensor has symbolic sizes or strides.
Why not change `has_symbolic_sizes_strides()`? There is preexisting logic that assumes "if `has_symbolic_sizes_strides()` is true, then this tensor is guaranteed to have symbolic sizes or strides". In this case, we have Python-implemented sizes or strides, which should follow a different code path.
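As a rough illustration of the intended contract (a toy sketch with a stub class, not the actual `TensorImpl` or profiler code), a caller would gate its fast path on the new, conservative query:
```cpp
#include <iostream>

// Toy stand-in for TensorImpl, only to illustrate the two queries.
struct ToyImpl {
  bool sizes_are_symints = false;    // C++-side symbolic sizes/strides
  bool python_custom_sizes = false;  // sizes/strides implemented in Python

  // True only when we *know* the sizes/strides are symbolic SymInts.
  bool has_symbolic_sizes_strides() const { return sizes_are_symints; }

  // Conservative: may return true in more cases, but never returns false
  // when the tensor could have symbolic (or Python-implemented) sizes.
  bool may_have_symbolic_sizes_strides() const {
    return sizes_are_symints || python_custom_sizes;
  }
};

int main() {
  ToyImpl t;
  t.python_custom_sizes = true;  // e.g. a subclass with Python-side sizes
  // Profiler-style logic: only touch concrete sizes on the fast path.
  if (!t.may_have_symbolic_sizes_strides()) {
    std::cout << "record concrete sizes\n";
  } else {
    std::cout << "skip size recording\n";
  }
}
```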
Differential Revision: [D55947660](https://our.internmc.facebook.com/intern/diff/D55947660/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123696
Approved by: https://github.com/aaronenyeshi, https://github.com/soulitzer
Fixes #115331.
This PR increases the number of valid GPU devices to 512 (from 64) in order to future-proof PyTorch for providers that offer [single nodes with a large device count](https://www.tensorwave.com/). Until now, `DeviceIndex` was an `int8_t`, thus multiple changes were necessary:
- `DeviceIndex` changed to `int16_t`. Updated consumers that assume it to be an `int8_t`.
- Updated bounds checking for `torch.device()` in the Python frontend. Right now, we allow funny things like `torch.device('cpu', 200).index == -56`, which is undefined behavior. I inserted some checks to only allow values between 0 and `c10::Device::MAX_NUM_DEVICES - 1`.
- Updated the `ArgumentInfo` struct as it hardcodes the device index as an 8-bit field [^1]. Might be a breaking change; not sure if users rely on this.
- Introduced `c10::Device::MAX_NUM_DEVICES` as a replacement for the old `C10_COMPILE_TIME_MAX_GPUS`
[^1]: This field was unsigned, so I guess this has also been undefined behavior the whole time? Our default device index is -1, so it always wrapped around to 255 when written to the `ArgumentInfo` struct. When I switched `DeviceIndex` to `int16_t`, it actually stayed 255 after unpacking from `ArgumentInfo` again, as the `DeviceIndex` was now wide enough that it didn't wrap back to -1.
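To make the wraparound in the footnote concrete, here is a small standalone sketch (plain C++, independent of the PyTorch types) of what happens when an index like 200 is squeezed into a signed 8-bit value, and when -1 is stored in an unsigned 8-bit field:
```cpp
#include <cstdint>
#include <iostream>

int main() {
  // Index 200 does not fit in int8_t: it wraps to -56 (mirroring
  // torch.device('cpu', 200).index == -56 before the new bounds checks).
  int8_t narrow = static_cast<int8_t>(200);
  std::cout << static_cast<int>(narrow) << "\n";  // -56

  // The old ArgumentInfo field was an unsigned 8-bit field, so the
  // default device index -1 was stored as 255.
  uint8_t unsigned_field = static_cast<uint8_t>(int8_t{-1});
  std::cout << static_cast<int>(unsigned_field) << "\n";  // 255

  // With a 16-bit DeviceIndex, 255 read back from the 8-bit field no
  // longer wraps around to -1; it stays 255.
  int16_t wide = static_cast<int16_t>(unsigned_field);
  std::cout << wide << "\n";  // 255
}
```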
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119639
Approved by: https://github.com/cyyever, https://github.com/albanD, https://github.com/huydhn
We found a scenario where ``tensor.device().is_privateuseone()`` is used to determine whether a tensor is a PrivateUse1 tensor, but it fails.
In the ``Autograd`` code, for example:
```
::std::tuple<at::Tensor,at::Tensor,at::Tensor> native_batch_norm(c10::DispatchKeySet ks, const at::Tensor & input, const c10::optional<at::Tensor> & weight, const c10::optional<at::Tensor> & bias, const c10::optional<at::Tensor> & running_mean, const c10::optional<at::Tensor> & running_var, bool training, double momentum, double eps) {
  auto& input_ = unpack(input, "input", 0);
  [[maybe_unused]] auto _any_requires_grad = compute_requires_grad( input, weight, bias );
  [[maybe_unused]] auto _any_has_forward_grad_result0 = (isFwGradDefined(input) || isFwGradDefined(weight) || isFwGradDefined(bias));
  check_no_requires_grad(running_mean, "running_mean", "native_batch_norm");
  check_no_requires_grad(running_var, "running_var", "native_batch_norm");
  std::shared_ptr<NativeBatchNormBackward0> grad_fn;
  if (_any_requires_grad) {
    grad_fn = std::shared_ptr<NativeBatchNormBackward0>(new NativeBatchNormBackward0(), deleteNode);
    grad_fn->set_next_edges(collect_next_edges( input, weight, bias ));
    grad_fn->eps = eps;
    grad_fn->input_ = SavedVariable(input, false);
    grad_fn->running_mean_ = SavedVariable(running_mean, false);
    grad_fn->running_var_ = SavedVariable(running_var, false);
    grad_fn->training = training;
    grad_fn->weight_ = SavedVariable(weight, false);
  }
  ...
}
```
When ``weight`` is ``None``, an empty tensor is automatically generated and passed on to the backward computation:
c7e12c7427/torch/csrc/autograd/saved_variable.cpp (L121-L128)
At the beginning of the backward computation in our scenario, we need to determine whether the input tensor is ``PrivateUse1``. However, if we use ``tensor.device().is_privateuseone()``, we get the error ``"tensor does not have a device"``:
c7e12c7427/c10/core/TensorImpl.h (L1223-L1235)
I think this part of the code can be optimized; what do you think?
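For illustration only (a minimal standalone sketch, not the code touched by this PR), the failure mode and the obvious defensive check look roughly like this:
```cpp
#include <ATen/ATen.h>
#include <iostream>

int main() {
  at::Tensor weight;  // undefined tensor, as when `weight` is None

  // Calling .device() on a tensor that carries no device information
  // throws ("tensor does not have a device"), so this check is unsafe:
  // bool unsafe = weight.device().is_privateuseone();

  // A defensive variant first checks that the tensor is defined:
  bool is_privateuse1 = weight.defined() && weight.device().is_privateuseone();
  std::cout << std::boolalpha << is_privateuse1 << "\n";  // false
}
```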
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113421
Approved by: https://github.com/albanD
Currently whenever the sizes or strides are modified for a `TensorImpl` we
eagerly recompute the numel and memory format flags. This is fine for static
shapes as it's all fast C++ code, but for symbolic shapes it runs slow Python code.
This instead changes the `SymbolicShapeMeta` object to compute the derived
quantities lazily at the first request. This has the added benefit that we can
now pass assumptions in `empty_tensor_restride` which remove the need to compute
some contiguity flags at all.
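As a rough sketch of the lazy-caching idea (a generic illustration, not the actual `SymbolicShapeMeta` code), derived quantities can be computed on first request and memoized, then invalidated when the sizes change:
```cpp
#include <cstdint>
#include <iostream>
#include <optional>
#include <vector>

// Toy shape metadata: numel is derived from sizes, but only when asked for.
class LazyShapeMeta {
 public:
  explicit LazyShapeMeta(std::vector<int64_t> sizes) : sizes_(std::move(sizes)) {}

  int64_t numel() const {
    if (!numel_) {     // first request: do the (possibly slow) work
      int64_t n = 1;
      for (auto s : sizes_) n *= s;
      numel_ = n;      // memoize for subsequent requests
    }
    return *numel_;
  }

  void set_sizes(std::vector<int64_t> sizes) {
    sizes_ = std::move(sizes);
    numel_.reset();    // invalidate instead of eagerly recomputing
  }

 private:
  std::vector<int64_t> sizes_;
  mutable std::optional<int64_t> numel_;  // lazily computed cache
};

int main() {
  LazyShapeMeta meta({2, 3, 4});
  std::cout << meta.numel() << "\n";  // 24, computed on first access
  meta.set_sizes({5, 5});
  std::cout << meta.numel() << "\n";  // 25, recomputed only because we asked
}
```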
Pull Request resolved: https://github.com/pytorch/pytorch/pull/112785
Approved by: https://github.com/ezyang
ghstack dependencies: #112689, #112890
In cudagraph trees, we invalidate tensors at some point and drop their storage. Then, when they are accessed with `.data_ptr()`, a custom error message is thrown. Previously, this invalidation didn't also make `untyped_storage()`/`storage()` error, which could result in a segfault.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109750
Approved by: https://github.com/zou3519
Adding a non-zero offset to a null pointer is undefined behavior and a bad habit. This PR fixes several such cases; a minimal sketch of the pattern is shown after the list below.
- When `lapackEig` is called to estimate a workspace size, do not add the matrix size to the W pointer.
- When `unpack_pivots_cpu_kernel` is called with zero `dim_size`, exit early.
- When `topk_impl_loop` is called with `k` equal to zero, exit right away, as the output tensors are empty anyway.
- Do not add a non-zero storage offset in `TensorImpl::data_ptr_impl_impl`, which can be the case if a tensor is created as `torch.empty(3)[4:]`.
- In `s_addmm_out_sparse_dense_worker`, do not call `axpy` over an empty vector.
- In `_sparse_binary_op_intersection_kernel_impl`, skip computing `ptr_indices_dim` when `sparse_dim` is empty.
- Exit the `grid_sample` forward/backward kernels earlier if either `input` or `grid` is an empty tensor.
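As a generic illustration (a standalone sketch with hypothetical names, not the actual kernels changed here), the problematic pattern and a guarded version look roughly like this:
```cpp
#include <cstddef>
#include <iostream>

// Hypothetical helper: `data` may be null when the caller only wants the
// required size (workspace query), and `n` may be zero for empty tensors.
void process(float* data, std::size_t offset, std::size_t n) {
  // UB pattern: `data + offset` is undefined when data == nullptr and
  // offset != 0, even if the result is never dereferenced.
  // float* begin = data + offset;

  // Guarded version: bail out before doing pointer arithmetic on a
  // possibly-null pointer or iterating over empty inputs.
  if (data == nullptr || n == 0) {
    return;  // nothing to do for empty inputs / size queries
  }
  float* begin = data + offset;
  for (std::size_t i = 0; i < n; ++i) {
    begin[i] *= 2.0f;
  }
}

int main() {
  process(nullptr, 20, 0);  // size-query style call: now safe, no UB
  float buf[8] = {1, 2, 3, 4, 5, 6, 7, 8};
  process(buf, 4, 4);       // operates on buf[4..7]
  std::cout << buf[4] << "\n";  // 10
}
```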
Found by ASAN in clang-12.
Before the change, the UBSan report looked as follows:
```
ASAN_SYMBOLIZER_PATH=/usr/lib/llvm-12/bin/llvm-symbolizer UBSAN_OPTIONS=print_stacktrace=1 LD_PRELOAD=/usr/lib/llvm-12/lib/clang/12.0.1/lib/linux/libclang_rt.asan-x86_64.so python test_fx_experimental.py -v -k test_normalize_operator_exhaustive_linalg_eig_cpu_float32
Test results will be stored in test-reports/python-unittest/test_fx_experimental
Running tests...
----------------------------------------------------------------------
test_normalize_operator_exhaustive_linalg_eig_cpu_float32 (__main__.TestNormalizeOperatorsCPU) ... /opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:111: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
torch.has_cuda,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:112: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
torch.has_cudnn,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:118: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
torch.has_mps,
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/overrides.py:119: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
torch.has_mkldnn,
/var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17: runtime error: applying non-zero offset 20 to null pointer
#0 0x7f2025794888 in void at::native::lapackEig<float, float>(char, char, int, float*, int, float*, float*, int, float*, int, float*, int, float*, int*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9945888)
#1 0x7f20257da256 in void at::native::(anonymous namespace)::apply_linalg_eig<float>(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998b256)
#2 0x7f20257d902d in at::native::(anonymous namespace)::linalg_eig_kernel(at::Tensor&, at::Tensor&, at::Tensor&, at::Tensor const&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x998a02d)
#3 0x7f20257b5b3d in at::native::linalg_eig_out_info(at::Tensor const&, at::Tensor&, at::Tensor&, at::Tensor&, bool) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9966b3d)
#4 0x7f20257b4770 in at::native::linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9965770)
#5 0x7f20280710e6 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU_out_linalg_eig_out(at::Tensor const&, at::Tensor&, at::Tensor&))>, std::tuple<at::Tensor&, at::Tensor&>, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor&, at::Tensor&> >, std::tuple<at::Tensor&, at::Tensor&> (at::Tensor const&, at::Tensor&, at::Tensor&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc2220e6)
#6 0x7f202727a045 in at::_ops::linalg_eig_out::call(at::Tensor const&, at::Tensor&, at::Tensor&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42b045)
#7 0x7f20257b7e29 in at::native::linalg_eig(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x9968e29)
#8 0x7f2028070bf0 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&), &(at::(anonymous namespace)::(anonymous namespace)::wrapper_CPU__linalg_eig(at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xc221bf0)
#9 0x7f2026b1f787 in std::tuple<at::Tensor, at::Tensor> c10::Dispatcher::redispatch<std::tuple<at::Tensor, at::Tensor>, at::Tensor const&>(c10::TypedOperatorHandle<std::tuple<at::Tensor, at::Tensor> (at::Tensor const&)> const&, c10::DispatchKeySet, at::Tensor const&) const (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xacd0787)
#10 0x7f20273230a7 in at::_ops::linalg_eig::redispatch(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb4d40a7)
#11 0x7f202c3cc32d in torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057d32d)
#12 0x7f202c3cba96 in c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&), &(torch::autograd::VariableType::(anonymous namespace)::linalg_eig(c10::DispatchKeySet, at::Tensor const&))>, std::tuple<at::Tensor, at::Tensor>, c10::guts::typelist::typelist<c10::DispatchKeySet, at::Tensor const&> >, std::tuple<at::Tensor, at::Tensor> (c10::DispatchKeySet, at::Tensor const&)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0x1057ca96)
#13 0x7f20272798e0 in at::_ops::linalg_eig::call(at::Tensor const&) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so+0xb42a8e0)
#14 0x7f2043d97ae3 in torch::autograd::THPVariable_linalg_eig(_object*, _object*, _object*) (/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/lib/libtorch_python.so+0x23feae3)
#15 0x5072d6 in cfunction_call /usr/local/src/conda/python-3.9.17/Objects/methodobject.c:543:19
...
SUMMARY: UndefinedBehaviorSanitizer: undefined-behavior /var/lib/jenkins/workspace/aten/src/ATen/native/BatchLinearAlgebra.cpp:937:17 in
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106354
Approved by: https://github.com/huydhn, https://github.com/lezcano