Historically, we worked out `size_hint` on the fly by doing a substitution on the sympy expression with the `var_to_val` mapping. With this change, we also maintain the hint directly on SymNode (in `expr._hint`) and use it in lieu of sympy substitution when it is available (mostly for guards on SymInt, etc.; in particular, idiomatic Inductor code typically manipulates sympy expressions directly, and so has no convenient way to maintain hints).
While it's possible this will give us modest performance improvements, this is not the point of this PR; the goal is to make it easier to carefully handle unbacked SymInts, where hints are expected not to be available. You can now easily test if a SymInt is backed or not by checking `symint.node.hint is None`.
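A minimal sketch of that check (the helper name here is hypothetical; only the `symint.node.hint is None` test comes from this PR):
```python
def is_unbacked(symint: "torch.SymInt") -> bool:
    # backed SymInts carry a hint directly on their SymNode; unbacked ones do not
    return symint.node.hint is None
```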
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94201
Approved by: https://github.com/voznesenskym
Supports the following with dynamic shapes:
```python
for element in tensor:
    # do stuff with element
```
The approach follows what's done when `call_range()` is invoked with dynamic shape inputs: guard on the tensor size and continue tracing with a real size value from `dyn_dim0_size.evaluate_expr()`.
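A hedged end-to-end sketch of the enabled pattern (the `dynamic_shapes` flag name reflects that era's config and is an assumption here):
```python
import torch
import torch._dynamo as dynamo

dynamo.config.dynamic_shapes = True  # assumption: era-specific flag

@dynamo.optimize("eager")
def fn(t):
    acc = torch.zeros_like(t[0])
    for element in t:  # guards on t.size(0), then traces with the real size
        acc = acc + element
    return acc

fn(torch.randn(5, 3))
```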
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94326
Approved by: https://github.com/ezyang
**Problem**: For a tensor `x`, you can assign `x.my_attr = 3.14` and then later access it. Dynamo does not support this right now; it errors out with an AttributeError (it was broken in #91840).
**Fix**: This fixes the problem by catching AttributeErrors in dynamo if we try to access an attr that does not exist on a standard torch.Tensor.
**Tests**: Added tests for accessing and setting attributes to make sure dynamo does not error out.
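A minimal repro sketch of the pattern this fixes (the attribute name is arbitrary; whether it traces fully or graph-breaks gracefully, it should no longer raise):
```python
import torch
import torch._dynamo as dynamo

@dynamo.optimize("eager")
def fn(x):
    x.my_attr = 3.14      # setting a custom attribute on the tensor
    return x + x.my_attr  # reading it back previously raised AttributeError

fn(torch.ones(3))
```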
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94332
Approved by: https://github.com/yanboliang
**Background:** Before this PR, support in dynamo for tensor attributes (e.g. `x.H`, `x.T`, ...) needed to be implemented individually, one by one. This could potentially lead to errors, e.g. if the implementation in [variables/tensor.py](21c7c7c72f/torch/_dynamo/variables/tensor.py (L160)) differed from the implementation of a direct call to the attribute. For attributes that were not special-cased in tensor.py, dynamo tracing would fail. This PR adds generic support for tensor attributes that return tensors, without needing to handle them specially (notably, for `x.real` and `x.imag`, which previously weren't supported).
**In this PR:** This directly creates a proxy node for a `"call_function"` node with `target=getattr`, and feeds it into wrap_fx_proxy. This will produce a TensorVariable for the attribute returned.
This also removes the implementations for H, T, mH, mT which were broken (previously `torch.relu(x.T)` would fail). They now fall back to this default implementation (for which `torch.relu(x.T)` passes).
**Further context**:
* Ed's original suggestion in [90463](https://github.com/pytorch/pytorch/pull/90463#discussion_r1043398340) was to use `torch.Tensor.H.__get__(x)`. I wasn't able to get this to work; fx compilation fails with `getset_descriptor does not have attribute __module__`. Basically, the `__module__` attribute, which is available on most python attributes, is not available on `getset_descriptor` objects (i.e., these are implemented in C++ as attributes on torch.Tensor, so they don't obey some assumptions made by fx).
* Although tensor attributes and methods (like `x.relu()`) both go through this code path, this PR should only handle attributes (e.g. see the `"getset_descriptor"` check in variables/tensor.py). Methods are already handled by GetAttrVariable.
* Prior to this PR, we already returned GetAttrVariables for unsupported attrs: the parent caller would catch the NotImplementedError and fall back to returning a GetAttrVariable. But if this GetAttrVariable was ever passed into a torch.\* function (as it quite possibly could be, since most of these attrs are tensors), it would fail because its proxy node would be missing an [example_value](https://github.com/pytorch/pytorch/blob/master/torch/_dynamo/utils.py#L1017). So: before, for some tensor x, `x.real` would work fine; but `torch.relu(x.real)` would fail.
**Testing**: added tests in test_misc.py for `x.real`, `x.imag`, `x.T`, and `x.real.T`.
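A hedged usage sketch of the behavior the tests cover:
```python
import torch
import torch._dynamo as dynamo

@dynamo.optimize("eager")
def fn(x, z):
    # attribute accesses now trace via a generic getattr proxy, so their
    # results can flow into further torch.* calls
    return torch.relu(x.T) + x.real.T, z.imag

fn(torch.randn(3, 4), torch.randn(2, dtype=torch.complex64))
```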
Pull Request resolved: https://github.com/pytorch/pytorch/pull/91840
Approved by: https://github.com/ezyang
Previously, Dynamo faked support for `item()` when `capture_scalar_outputs` was True by representing it internally as a Tensor. With dynamic shapes, this is no longer necessary; we can represent it directly as a SymInt/SymFloat, so we do. Doing this requires you to use dynamic shapes; in principle we could support scalar outputs WITHOUT dynamic shapes, but I won't do this unless someone hollers for it.
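A hedged sketch of the newly supported pattern (flag names per that era's config; `dynamic_shapes` in particular is an assumption here):
```python
import torch
import torch._dynamo as dynamo

dynamo.config.capture_scalar_outputs = True
dynamo.config.dynamic_shapes = True  # assumption: required at the time

@dynamo.optimize("eager")
def fn(x):
    n = x.sum().item()  # now traced as a SymFloat instead of a fake Tensor
    return x / n

fn(torch.randn(4))
```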
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Differential Revision: [D42885775](https://our.internmc.facebook.com/intern/diff/D42885775)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93150
Approved by: https://github.com/voznesenskym
For some tensor `x`, `x.type(torch.FloatTensor)` does essentially the same thing as `x.to(torch.float)`. `x.type` can be called with at least 3 types of inputs:
* a string, `"torch.FloatTensor"`
* a dtype, `torch.float`
* a tensor type, `torch.FloatTensor`
The third option (`torch.FloatTensor`) fails in fx, because fx cannot trace `torch.FloatTensor` objects. So this PR replaces the `torch.FloatTensor` type with the string `"torch.FloatTensor"`.
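A minimal sketch of the substitution, with a hypothetical helper name (`normalize_type_arg` is not the PR's actual function):
```python
import torch

def normalize_type_arg(arg):
    # map a legacy tensor-type class to its string name before it enters
    # the graph; torch._tensor_classes is populated by python_tensor.cpp
    if isinstance(arg, type) and arg in torch._tensor_classes:
        return f"{arg.__module__}.{arg.__name__}"
    return arg

print(normalize_type_arg(torch.FloatTensor))  # "torch.FloatTensor"
print(normalize_type_arg(torch.float))        # unchanged: torch.float32
```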
Why not fix this in fx? Well, it's possible, but I'm not sure of a nice way to do it. We would want to update [torch.fx.node.BaseArgumentTypes](d88bc38b0c/torch/fx/node.py (L17)) to contain torch.FloatTensor etc. We could hard-code a list of tensor types there (the types vary depending on build type, e.g. whether or not cuda tensors are available), but that's not great in case our hardcoded list differs from the actual list registered by python_tensor.cpp. Another option is to dynamically populate the list of types with `Union[tuple(...)]`, filling the tuple with `torch._tensor_classes` (which is directly populated by python_tensor.cpp), but apparently this breaks most typecheckers.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93043
Approved by: https://github.com/jansel
This reverts commit 9945a78a94.
Reverted https://github.com/pytorch/pytorch/pull/90463 on behalf of https://github.com/ZainRizvi due to This is causing test failures: FAILED inductor/test_torchinductor_opinfo.py::TestInductorOpInfoCUDA::test_comprehensive_linalg_pinv_singular_cuda_float64 - RuntimeError: unexpected success linalg.pinv.singular, torch.float64, cuda
Rewrite inplace `addcdiv` to a `div`, `mul`, and inplace `add` to avoid a graph break.
Rewrite inplace `add` to a `mul` and inplace `add` to avoid a graph break.
Needed to close optimizer graph breaks; a quick numeric sketch of the `addcdiv` rewrite follows.
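The decomposition's numerics, where the rewritten form uses only ops that trace without a graph break:
```python
import torch

p = torch.ones(3)
num, denom, value = torch.full((3,), 2.0), torch.full((3,), 4.0), 0.5

expected = p.clone().addcdiv_(num, denom, value=value)  # p += value * num / denom
rewritten = p.clone().add_(num.div(denom).mul(value))   # div, mul, inplace add
assert torch.allclose(expected, rewritten)
```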
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90330
Approved by: https://github.com/jansel
The old code didn't actually fakeify traceable tensor subclasses at the time they were added as a GraphArg to the module; now we do, by ignoring the subclass during fakeification and relying on Dynamo to simulate the subclass on top. See comments for more details.
BTW, this codepath is super broken, see filed issues linked on the
inside.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90009
Approved by: https://github.com/wconstab, https://github.com/voznesenskym
Fix errors from [7k github models](https://github.com/pytorch/torchdynamo/issues/1884)
```
Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1062, in get_fake_value
    return wrap_fake_exception(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 739, in wrap_fake_exception
    return fn()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1063, in <lambda>
    lambda: run_node(tx.output, node, args, kwargs, nnmodule)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1112, in run_node
    raise RuntimeError(
RuntimeError: Failed running call_function <function einsum at 0x7fd8f246a4c0>(*('i,j->ij', FakeTensor(FakeTensor(..., device='meta', size=(4,)), cpu), FakeTensor(FakeTensor(..., device='meta', size=(2,)), cuda:0)), **{}):
Unhandled FakeTensor Device Propagation for aten.mul.Tensor, found two different devices cpu, cuda:0
(scroll up for backtrace)
```
The root cause is that `tensor.type()` should return `torch.cuda.FloatTensor` rather than `torch.FloatTensor` when the tensor is on GPU.
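A minimal illustration of the corrected behavior:
```python
import torch

x = torch.randn(4)
print(x.type())             # 'torch.FloatTensor'
if torch.cuda.is_available():
    print(x.cuda().type())  # 'torch.cuda.FloatTensor', not 'torch.FloatTensor'
```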
Pull Request resolved: https://github.com/pytorch/pytorch/pull/90021
Approved by: https://github.com/jansel
This is a group of bug fixes for [7k github models](https://github.com/pytorch/torchdynamo/issues/1884); it fixes 30+ model tests.
* Support `tensor.type()`.
* Support `tensor.get_device()`.
* Support `torch.nn.functional._Reduction.get_enum`.
* Support `torch._utils._get_device_index()`.
* Fall back on `tensor.data_ptr()`:
  * `FakeTensor` always returns 0.
  * Without fake tensor propagation, we `clone` the input tensor, so it makes no sense to track the original `data_ptr`. And I don't think this is a very popular API.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89486
Approved by: https://github.com/jansel
Fix bugs in [7k github models](https://github.com/pytorch/torchdynamo/issues/1884).
* Legacy code still uses `tensor.data`; I think we can rewrite it with `tensor.detach()`, though I'm not sure if there is anything I didn't anticipate. A quick sketch of the assumed equivalence follows this list.
* Support `tensor.layout`.
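The `tensor.data` → `tensor.detach()` equivalence assumed above:
```python
import torch

x = torch.randn(3, requires_grad=True)
# both give a view of the same values with requires_grad=False
assert torch.equal(x.data, x.detach())
assert not x.data.requires_grad and not x.detach().requires_grad
```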
The root cause of these issues is that dynamo wraps unimplemented `tensor.x` calls into `GetAttrVariable(TensorVariable, x)`, but this op is not inserted into the FX graph. Hence, during fake tensor propagation, it throws `KeyError: 'example_value'`.
For these two popular attributes, Dynamo should support them anyway. However, whether dynamo should support **all** `tensor.x` calls rather than falling back to `GetAttrVariable` is debatable.
If I turn off fake tensor propagation, it works well even without this fix. So I'm curious whether we should improve fake propagation to cover similar cases. cc @mlazos @soumith @voznesenskym @penguinwu @anijain2305 @EikanWang @jgong5 @Guobing-Chen @chunyuan-w @XiaobingSuper @zhuhaozhe @blzheng @Xia-Weiwen @wenzhe-nrv @jiayisunx @desertfire @jansel @eellison
```
Traceback (most recent call last):
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 404, in _compile
    out_code = transform_code_object(code, transform)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/bytecode_transformation.py", line 341, in transform_code_object
    transformations(instructions, code_options)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/convert_frame.py", line 392, in transform
    tracer.run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 1523, in run
    super().run()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 389, in run
    and self.step()
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 359, in step
    getattr(self, inst.opname)(inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 193, in wrapper
    return inner_fn(self, inst)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 865, in CALL_FUNCTION_KW
    self.call_function(fn, args, kwargs)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/symbolic_convert.py", line 301, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/torch.py", line 407, in call_function
    tensor_variable = wrap_fx_proxy(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/builder.py", line 636, in wrap_fx_proxy
    return wrap_fx_proxy_cls(
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/variables/builder.py", line 676, in wrap_fx_proxy_cls
    example_value = get_fake_value(proxy.node, tx)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1024, in get_fake_value
    args, kwargs = torch.fx.node.map_arg((node.args, node.kwargs), visit)
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 613, in map_arg
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 621, in map_aggregate
    t = tuple(map_aggregate(elem, fn) for elem in a)
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 621, in <genexpr>
    t = tuple(map_aggregate(elem, fn) for elem in a)
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 627, in map_aggregate
    return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 627, in <genexpr>
    return immutable_dict((k, map_aggregate(v, fn)) for k, v in a.items())
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 631, in map_aggregate
    return fn(a)
  File "/scratch/ybliang/work/repos/pytorch/torch/fx/node.py", line 613, in <lambda>
    return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x)
  File "/scratch/ybliang/work/repos/pytorch/torch/_dynamo/utils.py", line 1022, in visit
    return n.meta["example_value"]
KeyError: 'example_value\n\nfrom user code:\n File "./generated/test_BayesWatch_pytorch_prunes.py", line 108, in forward\n return torch.zeros([x.size()[0], self.channels, x.size()[2] // self.spatial, x.size()[3] // self.spatial], dtype=x.dtype, layout=x.layout, device=x.device)\n\nSet torch._dynamo.config.verbose=True for more information\n\n\nYou can suppress this exception and fall back to eager by setting:\n torch._dynamo.config.suppress_errors = True\n'
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89257
Approved by: https://github.com/jansel
This refactor was prompted by challenges handling mixed int/float operations in C++. A previous version of this patch (https://github.com/pytorch/pytorch/pull/87722/) added overloads for each permutation of int/float and was unwieldy. This PR takes a different approach.
The general outline of the patch is to combine the C++ types SymIntNode and SymFloatNode into a single type, SymNode. This type is erased: we no longer statically know in C++ whether we have an int or a float, and have to test it with the is_int()/is_float() virtual methods. This has a number of knock-on effects.
- We no longer have C++ classes to bind to Python. Instead, we take an entirely new approach to our Python API, where we have SymInt/SymFloat classes defined entirely in Python, which hold a SymNode (corresponding to the C++ SymNode). However, SymNode is not pybind11-bound; instead, it lives as-is in Python and is wrapped into the C++ SymNode using PythonSymNode when it crosses into C++. This implies a userland rename.
In principle, it is also possible for the canonical implementation of SymNode
to be written in C++, and then bound to Python with pybind11 (we have
this code, although it is commented out.) However, I did not implement
this as we currently have no C++ implementations of SymNode.
Because we do return SymInt/SymFloat from C++ bindings, the C++ binding
code needs to know how to find these classes. Currently, this is done
just by manually importing torch and getting the attributes.
- Because SymInt/SymFloat are easy Python wrappers, `__sym_dispatch__` now takes SymInt/SymFloat, rather than SymNode, bringing it in line with how `__torch_dispatch__` works.
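A toy sketch of the layering described above (names here are illustrative stand-ins, not the real torch classes):
```python
class ToySymNode:
    """Type-erased node: knows dynamically whether it is an int or a float."""
    def __init__(self, expr: str, is_float: bool):
        self.expr, self._is_float = expr, is_float
    def is_int(self) -> bool:
        return not self._is_float
    def is_float(self) -> bool:
        return self._is_float
    def add(self, other: "ToySymNode") -> "ToySymNode":
        # mixed int/float promotes to float; no per-permutation overloads needed
        return ToySymNode(f"({self.expr} + {other.expr})",
                          self._is_float or other._is_float)

class ToySymInt:
    """User-facing wrapper with magic methods; holds a ToySymNode."""
    def __init__(self, node: ToySymNode):
        self.node = node
    def __add__(self, other):
        res = self.node.add(other.node)
        return ToySymFloat(res) if res.is_float() else ToySymInt(res)

class ToySymFloat(ToySymInt):
    pass
```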
Some miscellaneous improvements:
- SymInt now has a constructor that takes SymNode. Note that this
constructor is ambiguous if you pass in a subclass of SymNode,
so an explicit downcast is necessary. This means toSymFloat/toSymInt
are no more. This is a mild optimization as it means rvalue reference
works automatically.
- We uniformly use the caster for c10::SymInt/SymFloat, rather than
going the long way via the SymIntNode/SymFloatNode.
- Removed some unnecessary toSymInt/toSymFloat calls in normalize_*
functions, pretty sure this doesn't do anything.
- guard_int is now a free function, since to guard on an int you cannot
assume the method exists. A function can handle both int and SymInt
inputs.
- We clean up the magic method definition code for SymInt/SymFloat/SymNode.
ONLY the user classes (SymInt/SymFloat) get magic methods; SymNode gets
plain methods; this is to help avoid confusion between the two types.
Signed-off-by: Edward Z. Yang <ezyang@fb.com>
cc @jansel @mlazos @soumith @voznesenskym @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87817
Approved by: https://github.com/albanD, https://github.com/anjali411
**Introduces symbolic shape guards into dynamo.**
In this PR, we take the existing fake tensor infra and plumbing in dynamo and start passing a shape_env around. This shape_env does not get plumbed down to the middle layers / backend yet; it only collects expressions from frontend invocations at the moment. We then translate these expressions into guards at the point where we take the other guards installed throughout dynamo, and add them to check_fn.
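A hedged sketch of what this looks like from the user side (the `dynamic_shapes` flag name is an assumption from that era's config, and the guard behavior described in the comments is a simplification):
```python
import torch
import torch._dynamo as dynamo

dynamo.config.dynamic_shapes = True  # assumption: era-specific flag

@dynamo.optimize("eager")
def fn(x):
    if x.shape[0] > 2:  # becomes a symbolic expression, then a guard in check_fn
        return x * 2
    return x

fn(torch.randn(4))  # traced with a shape guard along the lines of s0 > 2
```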
Part 1 of https://docs.google.com/document/d/1QJ-M4zfMkD-fjHIqW089RptjLl9EgozZGCceUbvmgfY/edit#
cc @jansel @lezcano @fdrocha @mlazos @soumith @yanboliang @penguinwu @anijain2305
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87570
Approved by: https://github.com/ezyang
Right now, example_value is doing two jobs:
- We use it to propagate metadata (e.g. return type, shapes, etc.) throughout the graph.
- We use it to satisfy queries for the actual value (e.g. torch.cond, `assume_constant_result`).
This is further complicated by the fact that we have two modes: one where `example_value` is a fake tensor, and one where it is a real tensor (this is the `fake_tensor_propagation` config flag).
This leads to scenarios where we don't support every combination of job + mode, e.g. if `fake_tensor_propagation=False`, `assume_constant_result` is broken. This is made worse by the fact that "fake tensor mode" is the default and is required if you want dynamic shapes to work.
So, this PR introduces a `get_real_value` API that just runs the graph up to `node` in order to get a concrete value. This API is orthogonal to `example_value`, so it doesn't care about `fake_tensor_propagation`.
When `fake_tensor_propagation=True`: `example_value` is a fake tensor, and you must use the `get_real_value` API to get a concrete value. This will be the only configuration in the future.
When `fake_tensor_propagation=False`: `example_value` and `get_real_value` produce the same value. This is redundant, but we will be removing this config soon.
To support this, I introduce a cache for computed real values, to memoize the work involved if we're asking for real values a lot. I attached this state to `OutputGraph` because it seems to be what historically managed `example_value` lifetimes, but idk.
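A hedged, simplified sketch of the idea (an illustrative helper over plain torch.fx, not Dynamo's actual implementation, which also caches per-node results on OutputGraph):
```python
import torch
import torch.fx as fx

def run_up_to_node(gm: fx.GraphModule, target: fx.Node, *real_inputs):
    # interpret the graph with real tensors until `target` has been computed
    env, it = {}, iter(real_inputs)
    for node in gm.graph.nodes:
        if node.op == "placeholder":
            env[node] = next(it)
        elif node.op == "call_function":
            args = fx.node.map_arg(node.args, lambda n: env[n])
            kwargs = fx.node.map_arg(node.kwargs, lambda n: env[n])
            env[node] = node.target(*args, **kwargs)
        elif node.op == "call_method":
            obj, *rest = fx.node.map_arg(node.args, lambda n: env[n])
            kwargs = fx.node.map_arg(node.kwargs, lambda n: env[n])
            env[node] = getattr(obj, node.target)(*rest, **kwargs)
        if node is target:
            return env[node]

def f(x, y):
    return (x + y).relu()

gm = fx.symbolic_trace(f)
add = next(n for n in gm.graph.nodes if n.op == "call_function")
# concrete value of the add node, independent of any fake-tensor metadata
print(run_up_to_node(gm, add, torch.tensor([1.0, -3.0]), torch.tensor([1.0, 1.0])))
```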
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87091
Approved by: https://github.com/wconstab