This PR exposes `torch._higher_order_ops.cond` as `torch.cond`.
1. We need to add `#noqa: F811` to the `_check` calls in torch/__init__.py to silence a confusing linter error, "Redefinition of unused 'cond'". Only one `cond` is imported, and the lines that trigger the error do not define `cond`; they just pass it as an argument.
2. We also add `cond` to the list of functions allowed to be traced through, so that dynamo triggers the CondHigherOrder logic instead of creating a TorchVariable.
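For illustration, a minimal usage sketch of the newly exposed API under `torch.compile` (illustrative only; it follows the documented `torch.cond(pred, true_fn, false_fn, operands)` contract):
```python
import torch

def true_fn(x):
    return x.sin()

def false_fn(x):
    return x.cos()

@torch.compile
def f(x):
    # Both branches are traced; the predicate is a boolean tensor.
    return torch.cond(x.sum() > 0, true_fn, false_fn, (x,))

print(f(torch.randn(4)))
```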
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
Triplet Margin Loss takes a Callable `distance_function` parameter, which is not supported as an argument on the fx graph. See the previous error:
> File "/scratch/eellison/work/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
>     self.push(fn.call_function(self, args, kwargs))
> File "/scratch/eellison/work/pytorch/torch/_dynamo/variables/torch.py", line 723, in call_function
>     *proxy_args_kwargs(args, kwargs),
> File "/scratch/eellison/work/pytorch/torch/_dynamo/utils.py", line 504, in proxy_args_kwargs
>     f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
> File "/scratch/eellison/work/pytorch/torch/_dynamo/exc.py", line 143, in unimplemented
>     raise Unsupported(msg)
> torch._dynamo.exc.Unsupported: call_function args: TensorVariable() TensorVariable() TensorVariable() ConstantVariable(float) NNModuleVariable()
This is fixable by just inlining into `triplet_margin_loss` and continuing to compile it. This required support for `has_torch_function_variadic`.
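For reference, a minimal repro-style sketch of the pattern this unblocks. It uses `triplet_margin_with_distance_loss`, the functional variant that accepts a `distance_function`; the module-based distance function corresponds to the `NNModuleVariable()` in the error above (illustrative, not the exact test case):
```python
import torch
import torch.nn.functional as F

dist_fn = torch.nn.PairwiseDistance(p=2)  # a Callable nn.Module distance function

@torch.compile(fullgraph=True)
def loss_fn(anchor, positive, negative):
    # The Callable distance_function cannot be placed on the fx graph as an
    # argument, so dynamo inlines the loss function and keeps compiling.
    return F.triplet_margin_with_distance_loss(
        anchor, positive, negative, distance_function=dist_fn
    )

a, p, n = (torch.randn(8, 16) for _ in range(3))
print(loss_fn(a, p, n))
```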
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110302
Approved by: https://github.com/mlazos
Fix: #107315
This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.
In summary, this PR:
- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes `<type>.__bases__` call, returning a `TupleVariable`
- Enables calling the `id` builtin function on every variable that implements the `as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
  - Making `SUPPORTED_NODES` a map from ids (instead of types) to `NodeDef`
  - Removing the `functools.wraps` call, since it can't be inlined
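Putting this together, a small sketch of the kind of code dynamo can now trace through without a graph break (`fullgraph=True` just turns any residual break into an error):
```python
import torch
import torch.utils._pytree as pytree

@torch.compile(fullgraph=True)
def f(x):
    # tree_flatten / tree_unflatten are now inlined by dynamo instead of
    # triggering a graph break.
    leaves, spec = pytree.tree_flatten({"a": x, "b": [x + 1, x * 2]})
    leaves = [leaf.sin() for leaf in leaves]
    return pytree.tree_unflatten(leaves, spec)

out = f(torch.randn(3))
```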
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
RFC: https://github.com/pytorch/rfcs/pull/54
First commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/
We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.
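For context, the kind of usage the upstreamed code supports, with NumPy calls traced through `torch.compile` via the vendored `torch._numpy` layer (a minimal sketch):
```python
import numpy as np
import torch

@torch.compile
def numpy_fn(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # The NumPy calls below are traced and mapped onto torch ops.
    return np.sum(x * y, axis=0)

out = numpy_fn(np.ones((3, 4)), np.ones((3, 4)))
print(out)
```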
In the next commits, I do a number of things in this order
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (I suspect they were not testing anything) and are not passing now.
Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one, but it stopped working after the merge, so I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.
All the tests in `tests/torch_np` take about 75s to run.
This was joint work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient) because this is a collaboration, and ghstack doesn't allow for shared contributions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
This PR adds initial dynamo support for DTensor. In particular, it:
- allows DTensor to be passed into a compiled function, and allows fakifying DTensor during dynamo tracing by turning the inner local tensor into a meta tensor.
- We use `allow_in_graph` to let `DTensor` and `DTensor.from_local` be represented as `TorchVariable`
- The DTensor created becomes a normal `TensorVariable`, and tensor operations on it are inserted into the output graph just like for torch.Tensor
- note that DTensor has an extra instance method, `redistribute`, compared to a plain tensor, and we currently special-case it in `TensorVariable`
`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement) which fx.Graph does not support. In order to let these two APIs appear in the dynamo captured graph, we encode the metadata into a new function (like `functools.partial`), where the new function only accepts prim args (i.e. tensors), and then we emit a `call_function` node targeting this new function in the graph. This was suggested by @ezyang. The underlying rationale is that the metadata does not change across graph invocations, so it is safe to encode it.
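A rough sketch of the partial-application idea (all names here are illustrative stand-ins, not the actual dynamo helpers):
```python
import functools
import torch

# Illustrative stand-in for DTensor.from_local: a tensor argument plus
# non-tensor metadata that fx.Graph cannot represent as node arguments.
def from_local(local_tensor, device_mesh, placements, run_check=False):
    return local_tensor  # placeholder body for the sketch

mesh, placements = object(), [object()]  # hypothetical DeviceMesh / Placement

# Bake the metadata into a new callable; the captured graph then only needs a
# call_function node whose arguments are plain tensors (and constants).
prim_from_local = functools.partial(
    from_local, device_mesh=mesh, placements=placements, run_check=False
)
out = prim_from_local(torch.ones(4))
```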
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
    prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
    to_local = prim_redistribute.to_local(); prim_redistribute = None
    add = to_local + 2; to_local = None
    return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym
Original PR #99988
The problem was that we added `wrap` to torch._ops, which actually puts
it at `torch.ops.wrap`, a namespace that can be open-registered
to. The fix is that we now move `wrap` into a new file.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100544
Approved by: https://github.com/voznesenskym
Summary:
This diff reverts D45387167.
D45387167 (Basic dynamo support for traceable collectives, #94440, by wconstab) has been identified as causing test or build failures (internal).
If you believe this diff has been generated in error you may Commandeer and Abandon it.
Test Plan: NA
Reviewed By: s4ayub
Differential Revision: D45448312
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100424
Approved by: https://github.com/rohan-varma, https://github.com/kumpera
This PR introduces a `wrap(body_fn, *args)` higher order operator.
The semantics of `wrap(body_fn, *args)` is simply to run `body_fn(*args)`.
Underneath Dynamo, this PR makes it so that we rewrite calls to
`wrap(body_fn, *args)` as `wrap(new_fn, *new_args)`, where `new_fn` has
no free variables. This PR does not update cond/map to use the new
mechanism yet (we do not support nn.Modules yet; that will come in the future).
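A small sketch of what the rewrite means at the user level (the import path is how `wrap` is exposed around this stack and may differ across versions; the rewrite itself happens inside dynamo, not in user code):
```python
import torch
from torch._higher_order_ops.wrap import wrap  # path may vary across versions

y = torch.ones(3)

def fn(x):
    # body_fn closes over the free variable `y`.
    def body_fn(x):
        return x + y

    # Semantically this just runs body_fn(x). Under dynamo, the call is
    # rewritten to wrap(new_fn, x, y), where new_fn(x, y) has no free variables.
    return wrap(body_fn, x)

out = torch.compile(fn, backend="eager")(torch.randn(3))
```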
The design we take is:
- OutputGraph represents the graph being built by Dynamo that may be
compiled and executed.
- OutputGraph owns a root SubgraphTracer, where it builds the FX graph.
- OutputGraph may own multiple nested SubgraphTracers.
- When we need to trace the body function of a HigherOrderOperator, we
construct a new SubgraphTracer to build the graph of the body function.
Mechanically, when Dynamo sees a new `wrap` HigherOrderOperator with a
body function, it:
- Creates a new SubgraphTracer via OutputGraph.new_subtracer
- Executes the body function
This captures the body function into the graph on the new
SubgraphTracer while modifying the state of the OutputGraph. For
example, the OutputGraph may receive new GraphArgs, new guards, and new
side effects.
If capture of the body function fails, then Dynamo graph breaks on the
HigherOrderOperator.
Test Plan:
- added test/dynamo/test_higher_order_ops.py
Future:
- We're not actually able to tell Dynamo to completely graph break on the
HigherOrderOperator. Instead, when we do graph break, Dynamo begins
introspecting `HigherOrderOperator.__call__`. It should probably not do
this.
- Ideally we would error out on new SideEffects. I don't know how to do
this yet.
- We don't support dealing with nn.Modules yet (e.g. calling nn.Modules
or accessing attributes of tracked nn.Modules from a body_fn). There's
an open question about what should actually happen here.
- Ideally we would rewrite map/cond to use the new mechanism but we need
to fix the previous bullet point before we can get there.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99988
Approved by: https://github.com/voznesenskym, https://github.com/anijain2305
Make traceable collectives work with torchdynamo,
bypassing problems with tracing the AsyncTensor subclass.
Accept a suboptimal solution for now and optimize it later:
for now, the wait happens immediately, which generally forces an early sync.
Later, find a way, either in dynamo or the AOT stack, to handle
AsyncCollectiveTensor so the wait lands in the optimal place.
Notes on the implementation (a usage sketch follows the list):
- Dynamo traces the 'user-level' functional collective APIs, which are designed to behave
  differently in eager vs. compiled mode. In eager, there is work-object registration and
  a wrapper subclass inserts a 'wait' call at the appropriate time.
  In compile/trace mode, wait is called immediately, and work-object
  registration must be handled by the compile backend at runtime.
- Dynamo needs to trace into some of the helper functions in the 'user-level'
  API, such as `_expand_group`, which is essentially a constant transformation.
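A rough sketch of the user-level pattern this targets (module path and call signature assumed from the functional collectives work at the time; treat as illustrative, not the definitive API):
```python
import torch
import torch.distributed as dist
import torch.distributed._functional_collectives as funcol  # assumed path

@torch.compile
def allreduce_plus_one(x: torch.Tensor, group) -> torch.Tensor:
    # In eager, the result defers its wait via a wrapper subclass; under
    # dynamo tracing (this PR), the wait is issued immediately instead.
    y = funcol.all_reduce(x, "sum", group)
    return y + 1

# Usage (inside an initialized process group):
# out = allreduce_plus_one(torch.ones(8), dist.group.WORLD)
```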
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94440
Approved by: https://github.com/kumpera
Testing whether the minor change breaks other test cases.
For the added test case, TorchDynamo graph-breaks on `torch.ops.foo.custom` but then starts running again on the recursively invoked frame, `foo_cpu` on L48 of the test file. This raises an assertion like this:
~~~
Traceback (most recent call last):
File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 65, in test_disallow_in_graph_for_custom_op
res = opt_fn(x)
File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 252, in _fn
return fn(*args, **kwargs)
File "/scratch/anijain/work/pytorch/test/dynamo/test_decorators.py", line 56, in fn
b = torch.ops.foo.custom(a)
File "/scratch/anijain/work/pytorch/torch/_ops.py", line 646, in __call__
return self._op(*args, **kwargs or {})
File "/scratch/anijain/work/pytorch/torch/_dynamo/eval_frame.py", line 401, in catch_errors
return callback(frame, cache_size, hooks, frame_state)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 495, in _convert_frame
result = inner_convert(frame, cache_size, hooks, frame_state)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 122, in _fn
return fn(*args, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 331, in _convert_frame_assert
return _compile(
File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 169, in time_wrapper
r = func(*args, **kwargs)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 401, in _compile
out_code = transform_code_object(code, transform)
File "/scratch/anijain/work/pytorch/torch/_dynamo/bytecode_transformation.py", line 1000, in transform_code_object
transformations(instructions, code_options)
File "/scratch/anijain/work/pytorch/torch/_dynamo/convert_frame.py", line 371, in transform
tracer = InstructionTranslator(
File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1890, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/scratch/anijain/work/pytorch/torch/_dynamo/symbolic_convert.py", line 1893, in <genexpr>
VariableBuilder(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 165, in __call__
return self._wrap(value).clone(**self.options())
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 290, in _wrap
return type_dispatch(self, value)
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 776, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 923, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 983, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1213, in wrap_to_fake_tensor_and_record
fake_e = wrap_fake_exception(
File "/scratch/anijain/work/pytorch/torch/_dynamo/utils.py", line 835, in wrap_fake_exception
return fn()
File "/scratch/anijain/work/pytorch/torch/_dynamo/variables/builder.py", line 1214, in <lambda>
lambda: tx.fake_mode.from_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 1434, in from_tensor
return self.fake_tensor_converter(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 329, in __call__
return self.from_real_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/fake_tensor.py", line 283, in from_real_tensor
out = self.meta_converter(
File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 531, in __call__
r = self.meta_tensor(
File "/scratch/anijain/work/pytorch/torch/_subclasses/meta_utils.py", line 184, in meta_tensor
assert not torch._C._dispatch_tls_local_exclude_set().has(
AssertionError:
~~~
It seems `_dynamo.disable` is the right option for custom ops added by `torch.library`.
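For reference, a hypothetical sketch of the pattern the test exercises: a custom op registered via `torch.library` whose Python implementation is kept out of dynamo with `_dynamo.disable` (names like `foo::custom` and `foo_cpu` mirror the test but are illustrative here):
```python
import torch
import torch._dynamo

lib = torch.library.Library("foo", "DEF")
lib.define("custom(Tensor x) -> Tensor")

@torch._dynamo.disable
def foo_cpu(x):
    # Dynamo will not introspect this frame when it is invoked recursively.
    return x.sin()

lib.impl("custom", foo_cpu, "CPU")

def fn(x):
    return torch.ops.foo.custom(x) + 1

opt_fn = torch.compile(fn)
print(opt_fn(torch.randn(3)))
```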
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99600
Approved by: https://github.com/jansel
Summary:
Replace `_dynamo.config` with an object instead of a module.
Current usage patterns of setting and reading fields on config will work
unchanged.
Only two changes are needed going forward:
1. `import torch._dynamo.config` will not work. However, just doing
   `import torch._dynamo` is sufficient to access dynamo config
   as `torch._dynamo.config`.
2. Files inside the `_dynamo` folder need to access config via
   `from torch._dynamo.config_util import config` instead of
   `from torch._dynamo import config`, because `_dynamo/__init__.py`
   imports some of those files and this would otherwise be a circular import.
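A tiny sketch of the two access patterns described above (the `config_util` module name is taken from this PR's description and may not match later refactors; `verbose` is just an example field):
```python
# Outside torch/_dynamo: importing the package is enough.
import torch._dynamo

torch._dynamo.config.verbose = True   # setting a field works unchanged
print(torch._dynamo.config.verbose)   # reading a field works unchanged

# Inside torch/_dynamo modules (per this PR), to avoid the circular import:
# from torch._dynamo.config_util import config
```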
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96455
Approved by: https://github.com/williamwen42
As found in #92709 (thanks to @ngimel and @jansel), `torch.Tensor.fn` currently maps to a `UserDefinedObjectVariable` rather than a `TorchVariable`. The root cause is described in https://github.com/pytorch/pytorch/pull/92709#pullrequestreview-1273357406. To prevent this, we build a `TorchVariable` for `torch.Tensor.fn` that points to `torch.ops.aten.fn`.
This issue causes a graph break for `torch.Tensor.fn` with `nopython=True`:
```python
import torch
import torch._dynamo as dynamo
#op = torch.ops.aten.abs_ # no graph break
op = torch.Tensor.abs_ # graph break
args = torch.empty(10)
def foo(args):
    return op(args)
opt_foo = dynamo.optimize("inductor", nopython=True)(foo)
y_ = opt_foo(args)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/93243
Approved by: https://github.com/jansel