Our experience with `constraints` / `dynamic_dim` in the existing export API has shown it to be (subjectively) clunky and (objectively) verbose in common cases.
This PR implements a new design for the export API that replaces the use of `constraints` / `dynamic_dim` with a new way of specifying dynamic shapes, involving the following concepts:
* a constructor `Dim` for first-class named dynamic dimensions with ranges (similar to `functorch.dim`, and analogous to internal symbolic sizes)
* a mechanism that uses the above in `export` calls to associate inputs to their dynamic shape specifications (`dynamic_shapes`)
Design doc: https://docs.google.com/presentation/d/168U7XK72C_WSsZpGESP6Cho9udh193fi0gfjxCNcJ4E/edit#slide=id.p (Meta-only). Note that we only implement Option 1 in that doc. An older version of this PR also implemented Option 3, which is an alternative way of specifying dynamic shapes using tensor type annotations on the exported callable; but we have moved that to future work for now.
See the docs for these new features in `torch.export`. The existing `torch.export.export` is modified to use the new API, `torch._export.export__RC__`, whenever `constraints=None`. We have not deprecated the existing API yet, but will do so in a follow-up.
Constraint violation errors arising through use of the new API will now contain suggested fixes phrased in the new API. We no longer need to report all specializations for static dimensions and suggest all constraints over dynamic dimensions to fix such errors. Instead, thanks to the redesign, the suggested fixes are much more concise, involving only modifications to the definitions of the relevant `Dim`s.
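For illustration, a minimal sketch of the new style of specification; the module, input shapes, and dim bounds below are made up for this example and are not from the PR's tests:
```python
import torch
from torch.export import Dim, export

class M(torch.nn.Module):
    def forward(self, x, y):
        return x @ y

# a named, first-class dynamic dimension with a range
batch = Dim("batch", min=2, max=1024)

example_inputs = (torch.randn(4, 8), torch.randn(8, 16))
ep = export(
    M(),
    example_inputs,
    # associate inputs with their dynamic-shape specs by argument name;
    # None marks an input as fully static
    dynamic_shapes={"x": {0: batch}, "y": None},
)
print(ep)
```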
Differential Revision: [D48919204](https://our.internmc.facebook.com/intern/diff/D48919204/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108448
Approved by: https://github.com/suo, https://github.com/gmagogsfm
Fix: #107315
This PR enables dynamo to trace through the `pytree` API by inlining its functions. In
order to do so, a few details of `pytree` had to be changed.
In summary, this PR:
- Introduces `TreeSpecVariable` for representing `TreeSpec` instances
- Specializes the `<type>.__bases__` call, returning a `TupleVariable`
- Enables calling the `id` builtin function on every variable that implements the `as_python_constant` method
- Specializes `ConstantVariable.call_method` for its (un)flatten functions
- Implements `UserDefinedObjectVariable.as_python_constant`
- Modifies `pytree` by:
  - Making `SUPPORTED_NODES` a map of ids (instead of types) to `NodeDef`
  - Removing the `functools.wraps` wrapper, since it can't be inlined
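For context, here is a minimal sketch of the kind of user code this lets dynamo trace without graph breaks; the nested input and the eager backend are arbitrary choices for illustration:
```python
import torch
import torch.utils._pytree as pytree

def f(inputs):
    # flatten the nested container, transform the leaves, and rebuild it;
    # with this PR, tree_flatten/tree_unflatten are inlined by dynamo
    leaves, spec = pytree.tree_flatten(inputs)
    return pytree.tree_unflatten([leaf.sin() for leaf in leaves], spec)

compiled = torch.compile(f, backend="eager", fullgraph=True)
out = compiled({"a": torch.randn(3), "b": (torch.randn(2), [torch.randn(4)])})
```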
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108533
Approved by: https://github.com/ezyang, https://github.com/voznesenskym
ghstack dependencies: #109201
Summary:
The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers. Ordinarily, Dynamo will optimistically try to specialize on integers; however, for KJTs, we know that these integers will definitely vary from run to run. Furthermore, Dynamo would ordinarily also specialize these integers if they are 0/1, yet we frequently expect features in KJTs to be 0/1.
The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through the APS sparse arch, torchrec_dlrm, and some small split-cat examples.
The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we trigger when we wrap a KJT. There probably are other ways to do this, but this was simple and worked.
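As a rough sketch of the pattern this unblocks (it uses torchrec's `KeyedJaggedTensor.from_lengths_sync` constructor and a toy split/sum function purely for illustration; the production models referenced in the test plan are far larger):
```python
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

def per_key_sums(kjt):
    # length_per_key() returns plain Python ints on a synced KJT; with this diff
    # Dynamo traces them as unbacked SymInts instead of specializing (even on 0/1)
    splits = torch.split(kjt.values(), kjt.length_per_key())
    return torch.stack([s.sum() for s in splits])

kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.arange(6, dtype=torch.float32),
    lengths=torch.tensor([1, 0, 2, 3]),  # 0/1 lengths are common for sparse features
)
print(torch.compile(per_key_sums, backend="eager")(kjt))
```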
Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```
from aakhundov
1. first build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
cf https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0
From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/
From ge0405
baseline (without your diff): f477293168
your diff: f477292363
```
buck2 test //caffe2/test/dynamo:test_dynamo_torchrec
buck2 run 'fbcode//mode/opt' fbcode//pytorch/benchmark/fb/test_gpu:run_test_gpu -- 'pytorch.benchmark.fb.test_gpu.test_gpu.TestBenchmarkFbGpu.test_train_blue_reels_vdd_v3_inductor_speedup'
```
Differential Revision: D49236757
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109216
Approved by: https://github.com/voznesenskym
Before this PR, if we run the following code:
```python
import torch
from torch.fx.experimental.proxy_tensor import make_fx
from functorch.experimental.control_flow import cond

def true_fn(x):
    return x - x.cos()

def false_fn(x):
    return x + x.sin()

def foo(x):
    return cond(x.shape[0] == 4, true_fn, false_fn, [x])

gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(3, 4))
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
```
we'll have the following error:
```python
Traceback (most recent call last):
File "/home/yidi/local/pytorch/make_fx.py", line 16, in <module>
gm = make_fx(foo, tracing_mode='symbolic')(torch.ones(4, 5))
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 841, in wrapped
t = dispatch_trace(wrap_key(func, args, fx_tracer, pre_dispatch), tracer=fx_tracer, concrete_args=tuple(phs))
File "/home/yidi/local/pytorch/torch/_compile.py", line 24, in inner
return torch._dynamo.disable(fn, recursive)(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 461, in dispatch_trace
graph = tracer.trace(root, concrete_args)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/external_utils.py", line 17, in inner
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/_symbolic_trace.py", line 817, in trace
(self.create_arg(fn(*args)),),
File "/home/yidi/local/pytorch/torch/fx/experimental/proxy_tensor.py", line 497, in wrapped
out = f(*tensors)
File "/home/yidi/local/pytorch/make_fx.py", line 13, in foo
return control_flow.cond(x.shape[0] == 4, true_fn, false_fn, [x])
File "/home/yidi/local/pytorch/torch/_higher_order_ops/cond.py", line 151, in cond
return torch.compile(cond_op, backend="eager", fullgraph=True)(
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 397, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/eval_frame.py", line 545, in catch_errors
return callback(frame, cache_entry, hooks, frame_state)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 140, in _fn
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 380, in _convert_frame_assert
return _compile(
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 561, in _compile
guarded_code = compile_inner(code, one_graph, hooks, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/utils.py", line 189, in time_wrapper
r = func(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 483, in compile_inner
out_code = transform_code_object(code, transform)
File "/home/yidi/local/pytorch/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
transformations(instructions, code_options)
File "/home/yidi/local/pytorch/torch/_dynamo/convert_frame.py", line 432, in transform
tracer = InstructionTranslator(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2032, in __init__
self.symbolic_locals = collections.OrderedDict(
File "/home/yidi/local/pytorch/torch/_dynamo/symbolic_convert.py", line 2035, in <genexpr>
VariableBuilder(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 808, in wrap_listlike
output = [
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 809, in <listcomp>
VariableBuilder(self.tx, GetItemSource(self.get_source(), i))(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 229, in __call__
vt = self._wrap(value).clone(**self.options())
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 374, in _wrap
return type_dispatch(self, value)
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1040, in wrap_tensor
tensor_variable = wrap_fx_proxy(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1267, in wrap_fx_proxy
return wrap_fx_proxy_cls(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1382, in wrap_fx_proxy_cls
example_value = wrap_to_fake_tensor_and_record(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1652, in wrap_to_fake_tensor_and_record
dynamic_dims, constraint_dims = _automatic_dynamic(
File "/home/yidi/local/pytorch/torch/_dynamo/variables/builder.py", line 1550, in _automatic_dynamic
if dim is not None and e.size()[i] != dim:
File "/home/yidi/local/pytorch/torch/__init__.py", line 352, in __bool__
return self.node.bool_()
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1019, in bool_
return self.guard_bool("", 0)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 1001, in guard_bool
r = self.shape_env.evaluate_expr(self.expr, self.hint, fx_node=self.fx_node)
File "/home/yidi/local/pytorch/torch/fx/experimental/recording.py", line 227, in wrapper
return fn(*args, **kwargs)
File "/home/yidi/local/pytorch/torch/fx/experimental/symbolic_shapes.py", line 3793, in evaluate_expr
assert orig_expr == hint, f"{orig_expr} != {hint}"
AssertionError: False != True
from user code:
Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information
You can suppress this exception and fall back to eager by setting:
import torch._dynamo
torch._dynamo.config.suppress_errors = True
```
This is because we record the SymInt in the frame state in `_automatic_dynamic` the first time we compile the function. Then, the second time, when we are given a SymInt-sized input with different hints, the comparison fails.
Implementation:
This PR determines shape dynamism from the dynamism of the inputs: if a dimension is a SymInt, return DYNAMIC; otherwise return STATIC.
Test Plan:
Add a test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109331
Approved by: https://github.com/ezyang
We could have SymBool inputs for torch.compile, e.g. in the following situation:
```
def f(x: torch.Tensor):
    pred = x.size(0) == 3
    torch.compile(f)(pred, x)

make_fx(f, tracing_mode="symbolic")(x)
```
The idea of this PR (credit to @ezyang) is to support SymBool by re-using the infra we already have for SymInt, so that we don't need to replicate a lot of stuff.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107850
Approved by: https://github.com/ezyang
ghstack dependencies: #107662
The strategy in this PR is pretty straightforward.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries, and outputs).
Note: as outputs can be made simple by how dynamo handles residuals, they could actually be handled as if they were inputs; but for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced).
The plan:
**For tensors w/ a source:**
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2-modified bytecode with the original eager code, we call `register_hook`. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame. This means we can soundly know it will be around to invoke `register_hook` on. As long as we guard on the identity of the lifted function, this is sound to do.
**For tensors w/o a source:**
Graph break - we will support this in a subsequent PR
**Handles:**
An interesting new component here is the creation of a `STORE_FAST` -> `LOAD_FAST` pair associated with the handle, the return value of `register_hook`. If the user code stored the result of `register_hook` in a handle, we need to honor that. We do so by interceding in `STORE_FAST` and recording the name of the local variable as directed by the user code. We then honor that same name in the reconstructed bytecode. If the user did not store the handle, we merely pop the produced value to preserve the stack.
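A minimal sketch of the sourced (input-tensor) case described above; the eager backend and the gradient-doubling hook are arbitrary choices for illustration:
```python
import torch

def double_grad(grad):
    return grad * 2.0

def f(x):
    # x has a source (it is a frame input), so the registration is recorded and
    # replayed via `register_hook` in the residual bytecode after the compiled region
    handle = x.register_hook(double_grad)  # the `handle` local is honored via STORE_FAST/LOAD_FAST
    return (x * x).sum()

x = torch.randn(4, requires_grad=True)
loss = torch.compile(f, backend="eager")(x)
loss.backward()
print(x.grad)  # 4 * x: the 2 * x gradient, doubled by the hook
```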
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108903
Approved by: https://github.com/ezyang
ghstack dependencies: #108846, #109092
Summary:
The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers. Ordinarily, Dynamo will optimistically try to specialize on integers; however, for KJTs, we know that these integers will definitely vary from run to run. Furthermore, Dynamo would ordinarily also specialize these integers if they are 0/1, yet we frequently expect features in KJTs to be 0/1.
The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through the APS sparse arch, torchrec_dlrm, and some small split-cat examples.
The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we trigger when we wrap a KJT. There probably are other ways to do this, but this was simple and worked.
Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```
from aakhundov
1. first build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
cf https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0
From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/
From ge0405
baseline (without your diff): f477293168
your diff: f477292363
Differential Revision: D49019987
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108960
Approved by: https://github.com/voznesenskym
The strategy for supporting functools partials is relatively straightforward.
There are 2 cases we need to support:
**1) Functools partials as input**
In this case, we are seeing the functools partial for the first time and it is guaranteed to have a source. As such, the args, keywords, and func of the functools partial are passed through VariableBuilder. Since this is the first time we are seeing these objects (the partial is an input), we re-enter VariableBuilder with a source referencing the args, keywords, and func as attributes of the input to produce:
- func: A callable VariableTracker (UDF, TorchVariable, etc) depending on the value of `func`
- args: List[VariableTracker] - note, not ListVariableTracker!
- keywords: Dict[str, VariableTracker]
A major benefit of this structure is that it very elegantly matches the args to `call_function`.
We then compose a FunctoolsPartialVariable from the VariableTrackers made above.
**2) Functools partials created within compile**
In this case, we already have all the args as known VTs, and thus just compose a FunctoolsPartialVariable as we do for case (1).
For both (1) and (2), we propagate all guards from the func, args, and keyword VTs to the FunctoolsPartialVariable.
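A minimal sketch covering both cases; the function, backend, and values are illustrative:
```python
import functools
import torch

def scaled_add(x, y, *, scale):
    return x + y * scale

# case (1): the partial arrives as an input, so func/args/keywords get sources
partial_in = functools.partial(scaled_add, scale=0.5)

def f(fn, x, y):
    return fn(x, y)

out1 = torch.compile(f, backend="eager")(partial_in, torch.ones(3), torch.ones(3))

# case (2): the partial is created inside the compiled region from known VTs
def g(x, y):
    fn = functools.partial(scaled_add, scale=2.0)
    return fn(x, y)

out2 = torch.compile(g, backend="eager")(torch.ones(3), torch.ones(3))
```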
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108846
Approved by: https://github.com/ezyang, https://github.com/jansel
Summary:
Original commit changeset: 33650f7cb0fb
Original Phabricator Diff: D48833682
Test Plan: See T162942232 for how we figured out that this diff caused significant numeric difference.
Reviewed By: voznesenskym
Differential Revision: D49082219
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108823
Approved by: https://github.com/xw285cornell
Summary:
Enables dynamo eager mode tracing for the following situation:
1. we have a torch.autograd.Function
2. the input to that function is a tensor subclass which is an intermediary
This is useful for float8 training UX.
Test Plan:
```
python test/dynamo/test_autograd_function.py -k intermediary_input
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/108093
Approved by: https://github.com/bdhirsh, https://github.com/wanchaol
Marks all params/optimizer state as static addresses and adds a finalizer which cleans up the graph attributes when the optimizer goes out of scope.
**Note:** this does not mark grads as static, because that would increase memory usage significantly.
There are two cases:
1. The upstream graph is cudagraphed - this case will work fine OOTB
2. The upstream graph is not cudagraphed - in this case, there will be a lot of copies introduced from the upstream (to copy the grads) into cudagraph-owned memory, unless the user explicitly marks the grads as static. If the user does this, it will also require not deallocating the grads in zero_grad() (either the module or the optimizer version) by setting them to zero instead of None. There is a PR (https://github.com/pytorch/pytorch/pull/107853) in flight to throw an error if zero_grad attempts to set static grads to None.
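A minimal sketch of the compiled-optimizer setup this targets; it assumes a CUDA device, uses `reduce-overhead` to enable cudagraphs, and the model/optimizer are placeholders:
```python
import torch

model = torch.nn.Linear(16, 16, device="cuda")
opt = torch.optim.Adam(model.parameters())

@torch.compile(mode="reduce-overhead")
def opt_step():
    opt.step()

for _ in range(3):
    loss = model(torch.randn(8, 16, device="cuda")).sum()
    loss.backward()
    opt_step()                       # params/optimizer state are marked as static addresses
    opt.zero_grad(set_to_none=True)  # grads are NOT marked static by this PR
```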
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107504
Approved by: https://github.com/eellison
This PR allows dynamo to fakify `FunctionalTensorWrapper` by unwrapping, replacing, and re-wrapping it, so that a `FunctionalTensorWrapper` can be passed as an input to `dynamo.optimize` and we can support something like this:
```python
ff = torch.func.functionalize(f)
torch.compile(ff)(x)
```
This PR doesn't follow the `__tensor_flatten__` and `__tensor_unflatten__` protocol for now because we're not sure of the plan for doing that for FunctionalTensorWrapper (it's implemented in C++).
**Test Plan:**
Add a new test.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107062
Approved by: https://github.com/zou3519
ghstack dependencies: #107042
Adds an API to mark a tensor as a static input.
To make this trigger recompiles properly, I'll need to update the tensor match checks to also check for this new attribute.
An additional concern is memory: the tensors will be kept alive, but this is already the current behavior for nn modules and parameters.
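For illustration only: the description above does not name the API, so the helper and signature below (`torch._dynamo.mark_static_address`) are an assumption about how it is exposed; treat this as a guess rather than the PR's actual interface.
```python
import torch
import torch._dynamo as dynamo

inp = torch.randn(8, 8, device="cuda")
dynamo.mark_static_address(inp)  # hypothetical: tell cudagraphs this tensor's address is stable

@torch.compile(mode="reduce-overhead")
def f(x):
    return x @ x

f(inp)
```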
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107154
Approved by: https://github.com/eellison
RFC: https://github.com/pytorch/rfcs/pull/54
The first commit is the contents of https://github.com/Quansight-Labs/numpy_pytorch_interop/.
We have already been using this in core for the last few months as an external dependency. This PR pulls all of it into core.
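For context, a minimal sketch of the kind of NumPy program this interop lets `torch.compile` trace (the function and shapes are made up):
```python
import numpy as np
import torch

def numpy_fn(x: np.ndarray, y: np.ndarray) -> np.ndarray:
    # traced via the compatibility layer; runs as torch ops under compile
    return np.sum(x * y, axis=1)

compiled = torch.compile(numpy_fn)
x = np.random.randn(8, 4).astype(np.float32)
print(compiled(x, x))
```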
In the next commits, I do a number of things in this order:
- Fix a few small issues
- Make the tests that this PR adds pass
- Bend backwards until lintrunner passes
- Remove the optional dependency on `torch_np` and simply rely on the upstreamed code
- Fix a number of dynamo tests that were passing before (they were not testing anything, I think) and are not passing now.
Missing from this PR (but not blocking):
- Have a flag that deactivates tracing NumPy functions and simply breaks. There used to be one, but it stopped working after the merge and I removed it. @lezcano to investigate.
- https://github.com/pytorch/pytorch/pull/106431#issuecomment-1667079543. @voznesenskym to submit a fix after we merge.
All the tests in `tests/torch_np` take about 75s to run.
This was joint work by @ev-br, @rgommers, @honno, and me. I did not create this PR via ghstack (which would have been convenient) as this is a collaboration, and ghstack doesn't allow for shared contributions.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106211
Approved by: https://github.com/ezyang
Previously, you would get an error like
```
Dynamo input and output is a strict subset of traced input/output
```
now you get
```
Cannot export model which references tensors that are neither
buffers/parameters/constants nor are direct inputs. For each tensor, if you'd
like this tensor to be an explicit input, add it as a dummy argument
to the top-level model definition you are exporting; if you would
like its value to be embedded as an exported constant, wrap its access
in a function marked with @assume_constant_result.
G['bulbous_bouffant'], accessed at:
File "test_export.py", line N, in f
return bulbous_bouffant + y
```
This doesn't handle outputs; I'm going to hit that next.
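A minimal repro sketch of the situation that produces the new message; the global name follows the example above, and the `torch._dynamo.export` entry point/signature shown here is an assumption that may differ by version:
```python
import torch

bulbous_bouffant = torch.randn(3)  # a free tensor: not an input, buffer, parameter, or constant

def f(y):
    return bulbous_bouffant + y

try:
    torch._dynamo.export(f)(torch.randn(3))
except Exception as e:
    print(e)  # points at G['bulbous_bouffant'] and suggests the two fixes above
```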
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/106403
Approved by: https://github.com/tugsbayasgalan
Fix: #105074
This PR makes dynamo handle NumPy global variables the same way as PyTorch tensor global variables, by tracking them as side effects.
In summary, we add `NumpyNdarrayVariable` to the
`VariableBuilder._can_lift_attrs_to_inputs` function.
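A rough sketch of the pattern this targets, assuming a module-level NumPy array that is read and reassigned inside a compiled function (the eager backend is just for illustration):
```python
import numpy as np
import torch

state = np.zeros(3)  # a NumPy global, handled like a tensor global

def f(x):
    global state
    state = state + 1  # tracked as a side effect and applied after the compiled frame runs
    return x + 1

compiled = torch.compile(f, backend="eager")
compiled(torch.randn(3))
print(state)  # array([1., 1., 1.])
```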
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105959
Approved by: https://github.com/ezyang
Currently, exporting a model to ONNX with fake tensor mode requires the user to load the data and model within the `torch.onnx.enable_fake_mode` context, but the actual call to `torch.onnx.dynamo_export` is done outside that context.
With this PR, we enable `torch.onnx.dynamo_export` to be called either
within `torch.onnx.enable_fake_mode` or outside of it. This feature
required changes to the core PyTorch Dynamo, which were greatly
supported by @ezyang.
In future steps we will determine which scenario we are going to
support, but for now we can use either to explore different options and
scenarios and assess their pros and cons.
This PR also creates a separate suite of tests for fake mode specific
scenarios (`TestFxToOnnxFakeTensorWithOnnxRuntime`).
It was done separately to decrease the test time, but we
could merge it with the default `TestFxToOnnxWithOnnxRuntime`. The
additional parameters are `load_checkpoint_during_init` and
`export_within_fake_mode`.
With the newly added support for nested export within fake mode, the following scenarios are now supported:
```python
import torch
with torch.onnx.enable_fake_mode() as fake_context:
    fake_args = create_args()
    fake_kwargs = create_kwargs()
    fake_model = create_model()
    fake_model.load_state_dict(torch.load(tmp_checkpoint_file.name))
    export_options = torch.onnx.ExportOptions(fake_context=fake_context)
    # `torch.onnx.dynamo_export` called WITHIN `torch.onnx.enable_fake_mode`
    export_output = torch.onnx.dynamo_export(
        fake_model,
        *fake_args,
        **fake_kwargs,
        export_options=export_options,
    )

export_output.save("/path/to/model.onnx", model_state_dict=create_model())
```
If we decide to only support scenarios in which `torch._dynamo.export` is called within `FakeTensorMode`, then we can remove the `fake_mode` argument from `torch._dynamo.export` as a follow-up task.
ps: This PR is mostly Edward's https://github.com/pytorch/pytorch/pull/105468 + unit tests after an offline discussion
ps: https://github.com/pytorch/pytorch/issues/105464 tracks pending tasks/limitations from this PR
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105477
Approved by: https://github.com/ezyang, https://github.com/BowenBao
This PR adds initial dynamo support for DTensor, in particular, it:
- allows DTensor to be passed into a compiled function, and allows fakifying DTensor during dynamo tracing by turning the inner local tensor into a meta tensor.
- We use `allow_in_graph` to include `DTensor` and `DTensor.from_local` to be represented as `TorchVariable`
- The dtensor created becomes a normal `TensorVariable` and it would insert any tensor operations to the output graph just like torch.Tensor
- note that DTensor has a new instance method `redistribute` compared to plain tensors, and we currently handle it specially in `TensorVariable`
`from_local` and `redistribute` both accept some non-trivial metadata as arguments (i.e. DeviceMesh, Placement), which fx.Graph does not support. In order to let these two APIs appear in the dynamo-captured graph, we encode the metadata into a new function (like `functools.partial`) that only accepts prim args (i.e. tensors), and then we put a `call_function` node with this new function into the graph. This was suggested by @ezyang. The underlying rationale is that the metadata will not change across graph invocations, so it is safe to encode it.
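The user code corresponding to the captured graph below looks roughly like this; the single-process group init and CPU mesh are placeholders, and in practice this runs under a multi-rank launcher:
```python
import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, DeviceMesh, Replicate, Shard

dist.init_process_group("gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1)
mesh = DeviceMesh("cpu", [0])

def fn(x):
    dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    return dt.redistribute(mesh, [Replicate()]).to_local() + 2

print(torch.compile(fn, backend="eager")(torch.randn(4, 4)))
```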
Captured graph:
```
def forward(self, L_x_ : torch.Tensor):
    l_x_ = L_x_
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:685, code: dt = DTensor.from_local(x, mesh, [Shard(0)], run_check=False)
    prim_from_local = torch__dynamo_variables_torch_prim_from_local(l_x_, run_check = False); l_x_ = None
    # File: /scratch/wanchaol/work/pytorch/test/distributed/_tensor/test_dtensor.py:686, code: return dt.redistribute(mesh, [Replicate()]).to_local() + 2
    prim_redistribute = torch__dynamo_variables_tensor_prim_redistribute(prim_from_local); prim_from_local = None
    to_local = prim_redistribute.to_local(); prim_redistribute = None
    add = to_local + 2; to_local = None
    return (add,)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103146
Approved by: https://github.com/voznesenskym