Previously, when codegening ops like `zeros_` or `ones_`, we'd hit a `Code below assumes there is at least one tensor arg` error. The check behind that assumption is not entirely correct, which is what causes the error to be thrown: ops like the ones mentioned pass in a `device` parameter that can be used in place of the "first tensor".
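For illustration, here is a minimal sketch of the idea behind the fix, assuming a hypothetical `resolve_device` helper (this is not the actual generated code): when no tensor argument is present, the explicit `device` argument can supply the device instead of the "first tensor".
```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>

// Hypothetical helper sketching the fix: prefer the first tensor argument's
// device, and fall back to an explicit `device` argument when the op has no
// tensor arguments (as for the factory-style ops mentioned above).
at::Device resolve_device(
    const c10::optional<at::Tensor>& first_tensor_arg,
    const c10::optional<at::Device>& device_arg) {
  if (first_tensor_arg.has_value()) {
    return first_tensor_arg->device();
  }
  if (device_arg.has_value()) {
    return *device_arg;
  }
  return at::Device(at::kCPU);  // arbitrary default, just for the sketch
}
```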
CC: @wconstab @desertfire @henrytwo @ke1337
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76917
Approved by: https://github.com/desertfire
Partially fixes #69813
This PR mainly does three things:
1. Introduces new methods for the `MetaBase` API:
- `set_output_strided`: creates proxy tensors with exact strides, if strides don't match
- `set_output_contiguous`: alias for `set_output_strided` with contiguous strides
- `set_output_raw_strided`: does not create proxy tensors
2. Modifies codegen for handling proxy tensors:
- Creates a new field for out-of-place kernels: `proxy_output_`
- Implements `set_output_strided` by creating a proxy tensor if necessary
- Passes the proxy tensor to the `IMPL` function
- Copies the result back to the real output at the end, whenever a proxy was created (see the sketch after this list)
3. Replaces `set_output` with `set_output_raw_strided` for `TensorIterator*`
- Needed, since it overrides `set_output`
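A minimal sketch of the proxy-output mechanism from item 2, under the assumption that the generated class roughly takes this shape (only `set_output_strided` and `proxy_output_` come from the description above; the rest is illustrative and not the actual codegen output):
```cpp
#include <ATen/ATen.h>
#include <c10/util/Optional.h>
#include <vector>

struct structured_example_out {
  void set_output_strided(
      int64_t output_idx,
      at::IntArrayRef sizes,
      at::IntArrayRef strides,
      const at::TensorOptions& options) {
    const at::Tensor& real_out = outputs_[output_idx];
    if (!real_out.strides().equals(strides)) {
      // Strides don't match: create a proxy tensor with the exact strides and
      // let the IMPL function write into it instead of the real output.
      proxy_output_[output_idx] = at::empty_strided(sizes, strides, options);
    }
  }

  // The IMPL function receives the proxy when one exists, otherwise the real output.
  const at::Tensor& maybe_get_output(int64_t output_idx) {
    return proxy_output_[output_idx].has_value()
        ? *proxy_output_[output_idx]
        : outputs_[output_idx];
  }

  // At the end, whenever a proxy was created, copy its contents back into the
  // real output.
  void copy_proxies_back() {
    for (size_t i = 0; i < outputs_.size(); ++i) {
      if (proxy_output_[i].has_value()) {
        outputs_[i].copy_(*proxy_output_[i]);
      }
    }
  }

  std::vector<at::Tensor> outputs_;
  std::vector<c10::optional<at::Tensor>> proxy_output_;
};
```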
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76096
Approved by: https://github.com/ezyang
Unfortunately, the built-in pprint module supports pretty-printing of dataclasses only from Python 3.10. The code I wrote in the `__str__` method of OpInfo should do the same job and should also work for any dataclass. For now I've put it there, but we could turn it into a standalone function and put it somewhere accessible to other dataclasses as well. Also, the maximum width (80) is currently hardcoded, but it would ideally be a parameter of that function.
When you call `print` on an OpInfo you get:
```
OpInfo(name = '__getitem__',
ref = None,
aliases = (),
variant_test_name = '',
op = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
method_variant = <slot wrapper '__getitem__' of 'torch._C._TensorBase' objects>,
inplace_variant = None,
skips = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
decorators = (<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbca90>,
<torch.testing._internal.common_methods_invocations.DecorateInfo object at 0x7f463acbcae0>),
sample_inputs_func = <function sample_inputs_getitem at 0x7f463acc6af0>,
reference_inputs_func = None,
error_inputs_func = None,
sample_inputs_sparse_coo_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6b80>,
sample_inputs_sparse_csr_func = <function _DecoratorContextManager.__call__.<locals>.decorate_context at 0x7f463acc6c10>,
dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypes = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfCUDA = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
backward_dtypesIfROCM = {torch.int16,
torch.float64,
torch.int32,
torch.int64,
torch.complex64,
torch.float16,
torch.bfloat16,
torch.uint8,
torch.complex128,
torch.bool,
torch.float32,
torch.int8},
supports_out = False,
supports_autograd = True,
supports_gradgrad = True,
supports_fwgrad_bwgrad = True,
supports_inplace_autograd = False,
supports_forward_ad = True,
gradcheck_wrapper = <function OpInfo.<lambda> at 0x7f463a7a40d0>,
check_batched_grad = True,
check_batched_gradgrad = True,
check_batched_forward_grad = True,
check_inplace_batched_forward_grad = True,
gradcheck_nondet_tol = 0.0,
gradcheck_fast_mode = None,
aten_name = '__getitem__',
decomp_aten_name = None,
aten_backward_name = None,
assert_autodiffed = False,
autodiff_nonfusible_nodes = ['aten::__getitem__'],
autodiff_fusible_nodes = [],
supports_sparse = False,
supports_scripting = False,
supports_sparse_csr = False,
test_conjugated_samples = True,
test_neg_view = True,
assert_jit_shape_analysis = False,
supports_expanded_weight = False)
```
cc @ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76810
Approved by: https://github.com/ezyang
Summary: Currently, OpKind is stored as an object field called op_ for each IR node, and one use of op_ is to avoid dynamic_cast in NodeCast when we need to downcast a base-node pointer into a concrete sub-node pointer. As a result, we need to construct and pass in an op when downcasting nodes, which becomes quite annoying once we start to implement trie-based IR node reuse. More importantly, the op for each subclass should be unique to that subclass, so making it a const static field is a more logical design.
In this PR, we still keep the object-level op_ for easier XLA adoption. As future work, we can come back to remove op_, make the op() method virtual, and get rid of OpKind in all the node constructors.
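A toy sketch of the design point, assuming a simplified node hierarchy (names such as `ClassOpKind` are illustrative stand-ins, not the exact torch::lazy API): the op becomes a static property of each subclass, so downcasting no longer requires constructing and passing an `OpKind`.
```cpp
#include <cstring>

// Simplified stand-in for torch::lazy::OpKind.
struct OpKind {
  const char* name;
  bool operator==(const OpKind& other) const {
    return std::strcmp(name, other.name) == 0;
  }
};

// Base IR node keeps the object-level op_ (as this PR still does).
struct Node {
  explicit Node(OpKind op) : op_(op) {}
  virtual ~Node() = default;
  OpKind op() const { return op_; }
  OpKind op_;
};

// Each subclass exposes its op as a class-level (static) property, so it no
// longer has to be constructed and passed in at every downcast site.
struct AddNode : public Node {
  static OpKind ClassOpKind() { return OpKind{"aten::add"}; }
  AddNode() : Node(ClassOpKind()) {}
};

// Downcast without dynamic_cast: compare against the subclass's static op.
template <typename T>
const T* NodeCast(const Node* node) {
  if (!(node->op() == T::ClassOpKind())) {
    return nullptr;
  }
  return static_cast<const T*>(node);
}
```
Usage would then look like `const AddNode* add = NodeCast<AddNode>(node);`, with no `OpKind` argument at the call site.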
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76711
Approved by: https://github.com/wconstab, https://github.com/JackCaoG
This PR turns the previously introduced `ITensorList` into a more general `IList`
class. It is a container wrapper for arbitrary types (given their appropriate
implementations).
In summary, I have:
- Renamed `ITensorList` (and, for consistency, its iterators and macros) to `IList`
- Made `IList` a class template over an arbitrary type `T`, given that the type:
- Specializes `IListTagImpl<T, Tag>` for all `IListTag`
- Introduced type aliases (for both list and iterator types):
- `at::ITensorList` -> `c10::IList<at::Tensor>`
- `at::IOptTensorRefList` -> `c10::IList<at::OptionalTensorRef>`
- Added support for `Tensor?[]` in the structured codegen
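A toy sketch of the container-wrapper idea, heavily simplified and purely illustrative: the real `c10::IList` dispatches through `IListTagImpl<T, Tag>` specializations, whereas this sketch hard-codes two backing representations for a single element type.
```cpp
#include <cstddef>
#include <cstdio>
#include <functional>
#include <vector>

enum class IListTag { Unboxed, Boxed };

// Illustrative wrapper over two possible backing list representations; the
// real IList generalizes this per type T via IListTagImpl<T, Tag>.
template <typename T>
class IList {
 public:
  // Wrap an "unboxed" contiguous list (think ArrayRef-like view).
  IList(const std::vector<T>& unboxed)
      : tag_(IListTag::Unboxed), unboxed_(&unboxed) {}
  // Wrap a "boxed" list of references (think List<optional<Tensor>>-like).
  IList(const std::vector<std::reference_wrapper<const T>>& boxed)
      : tag_(IListTag::Boxed), boxed_(&boxed) {}

  size_t size() const {
    return tag_ == IListTag::Unboxed ? unboxed_->size() : boxed_->size();
  }
  const T& operator[](size_t i) const {
    return tag_ == IListTag::Unboxed ? (*unboxed_)[i] : (*boxed_)[i].get();
  }

 private:
  IListTag tag_;
  const std::vector<T>* unboxed_ = nullptr;
  const std::vector<std::reference_wrapper<const T>>* boxed_ = nullptr;
};

// A kernel written against IList<T> works with either representation.
double sum(const IList<double>& xs) {
  double total = 0;
  for (size_t i = 0; i < xs.size(); ++i) total += xs[i];
  return total;
}

int main() {
  std::vector<double> unboxed{1.0, 2.0, 3.0};
  double a = 4.0, b = 5.0;
  std::vector<std::reference_wrapper<const double>> boxed{a, b};
  std::printf("%.1f %.1f\n", sum(unboxed), sum(boxed));  // prints "6.0 9.0"
}
```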
Pull Request resolved: https://github.com/pytorch/pytorch/pull/69606
Approved by: https://github.com/ezyang
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76203
Request for comments:
This change adds extra code generator support to generate out variant wrappers for operators with unstructured kernels.
The current version generates 105 new out variant wrappers in addition to the existing 136 auto-generated out variant wrappers.
This change shows that a simple tweak can increase the generated op coverage to 16% (241/1559) of all native ops described in native_functions.yaml, whether they are structured or not.
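For context, here is a rough hand-written illustration of what an out variant buys Static Runtime (a conceptual sketch only, not the generated wrapper): the functional form allocates a fresh output on every call, while the out variant writes into a preallocated tensor whose storage can be reused across iterations.
```cpp
#include <ATen/ATen.h>

// Functional form: a new output tensor is allocated on every call.
at::Tensor functional_call(const at::Tensor& a, const at::Tensor& b) {
  return at::add(a, b);
}

// Out variant form: the output is allocated once and then reused, which is
// what the generated out variant wrappers enable for Static Runtime.
void out_variant_call(const at::Tensor& a, const at::Tensor& b, at::Tensor& out) {
  if (!out.defined()) {
    out = at::empty_like(a);  // first run: allocate once
  }
  at::add_out(out, a, b);     // later runs: write into the existing storage
}
```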
Command to generate out variant wrappers:
```
buck run //caffe2/torch/fb/jit:gen_static_runtime_ops
```
- AFTER this change
```
total grouped native ops: 1559
structured grouped native ops: 545
generated grouped native ops: 241
```
- BEFORE this change
```
total grouped native ops: 1503
structured grouped native ops: 540
generated grouped native ops: 136
```
To enable CI tests and make it easier to review, the generated ops are added in a separate diff: D35945633
More details:
We added a block list to prevent the generation of around 10 operations that are deprecated or whose unit tests would fail. All generated ops *compile*, but the compiled unit tests may not pass due to the lack of hand-picked test input values for certain ops. Among the 42 ops whose unit tests did not pass: 1 ("index_select") duplicates an existing op; 32 were fixed; and 9 were removed and blocked from generation, either because they are not commonly used in internal models (e.g. "cholesky", "linalg_householder_product", the sparse kernel "sspaddmm") or because they cause errors in Static Runtime (e.g. "conj_physical" leads to an error in the memory planner, and "binary_cross_entropy").
Test Plan:
OP generation:
```buck run //caffe2/torch/fb/jit:gen_static_runtime_ops```
Test generated ops:
```buck run mode/opt //caffe2/benchmarks/static_runtime:static_runtime_cpptest```
Reviewed By: tenpercent
Differential Revision: D34913736
fbshipit-source-id: a6f408321653c3589ae1c76826177fc403d59c44
(cherry picked from commit 6f4501730478dbaeeea7f3ad4f9d29bf6787e7c1)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76471
Makes node_base_ctor_call produce the entire node base constructor call.
Previously it produced only the beginning of the call, which was unintended.
Addresses part of https://github.com/pytorch/xla/issues/3472
Test Plan: Imported from OSS
Reviewed By: qihqi, ngimel
Differential Revision: D35980436
Pulled By: wconstab
fbshipit-source-id: a443cf593ac7c35b2b65e72b82907e88e1e71c7a
(cherry picked from commit 360ad6d82a7e8303b8a60e61b177dabf0131ea8b)
This **roughly** corresponds to Goal 3.2 in https://docs.google.com/document/d/1iiLNwR5ohAsw_ymfnOpDsyF6L9RTUaHMpD8YLw-jxEw/edit#
Namely, it adds the following:
* SymbolicIntNode interface
* LazySymbolicIntNode implementation
* Lazy `narrow_copy` implementation
* Support for SymInt in codegen
* Test (below)
```cpp
TEST(LazyDynamicOpsTest, NarrowCopy) {
auto x = torch::rand({5, 10, 10}).to(kLazy);
const size_t Y_DIM = 3;
const size_t X_DIM_INDEX = 2;
auto y = torch::rand({Y_DIM}).to(kLazy);
auto ly = torch::lazy::TryGetLtcTensor(y);
auto dim_node = MakeNode<SizeNode>(ly->GetIrValue(), 0);
auto lmn = new torch::lazy::SymbolicIntNode(dim_node);
auto z = x.narrow_copy(X_DIM_INDEX, 0, lmn->toSymInt());
AllClose(z.cpu(), x.cpu().narrow_copy(X_DIM_INDEX, 0, Y_DIM));
}
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/75759
Approved by: https://github.com/wconstab