This PR adds a parametrized test for `cond`. It tests that `cond` can be traced with valid inputs. Specifically, a valid input is a combination of the following (see the sketch after this list):
- pred (python boolean, boolean tensor, int tensor, scalar tensor)
- true_fn/false_fn (func, obj, nn_module)
- operands (0 or more tensor inputs), tested with 0 and 2
- closures (0 or more tensor closures), tested with 0 and 2
- nested_level (no nesting or level-2 nested cond)
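A minimal sketch of one such combination (boolean-tensor pred, plain functions as branches, two operands, no closures, no nesting); this is illustrative, not the actual test code:
~~~python
import torch

def true_fn(x, y):
    return x + y

def false_fn(x, y):
    return x - y

def f(pred, x, y):
    return torch.cond(pred, true_fn, false_fn, (x, y))

pred = torch.tensor(True)              # boolean-tensor pred
x, y = torch.randn(3), torch.randn(3)  # two tensor operands
# Tracing should succeed for this combination of inputs.
compiled = torch.compile(f, fullgraph=True)
print(compiled(pred, x, y))
~~~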
What this test doesn't cover:
- pred: symbolic boolean expression as predicate
- true_fn/false_fn: functions that mutate intermediate tensors
- operands: non-tensor operands such as float, int
- closures: nn_module attribute closures, python constant closures
- nested_level: 3+
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110727
Approved by: https://github.com/zou3519
In this PR:
- Adds support for strides for jagged tensors (design doc coming soon)
- NestedTensor skips automatic dynamic
- Make use of @bdhirsh's subclass fakification logic by adding the `__tensor_flatten__` and `__tensor_unflatten__` functions (a sketch of this protocol follows this list).
- Additional logic for fakification: since the existing subclass fakification logic does not handle the case where the outer tensor has an additional dimension, we insert one-off logic to (1) insert an extra SingletonSymInt onto the fakified NestedTensor and (2) make sure we call track_symint on the sizes of both the inner and outer tensors during guard creation.
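A hedged sketch of the `__tensor_flatten__`/`__tensor_unflatten__` protocol as applied to a jagged NestedTensor-like subclass (the class and attribute names are illustrative, and the exact signatures have varied across versions):
~~~python
import torch

class JaggedTensor(torch.Tensor):
    @staticmethod
    def __new__(cls, values, offsets):
        # The outer tensor wraps the packed values; a real implementation
        # would give it the jagged (SingletonSymInt) sizes and strides.
        return torch.Tensor._make_wrapper_subclass(cls, values.shape, dtype=values.dtype)

    def __init__(self, values, offsets):
        self._values = values    # packed data
        self._offsets = offsets  # ragged-dimension boundaries

    def __tensor_flatten__(self):
        # Names of the inner tensor attributes, plus any non-tensor
        # context needed to rebuild the subclass (none here).
        return ["_values", "_offsets"], None

    @staticmethod
    def __tensor_unflatten__(inner_tensors, ctx):
        # Rebuild the subclass from (possibly fakified) inner tensors.
        return JaggedTensor(inner_tensors["_values"], inner_tensors["_offsets"])
~~~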
Remaining things that are weird:
- Still need to skip some logic in meta utils for some reason (I was going to write this up more, but decided not to since we're not able to do this anyway for an immediate reason: we cannot arbitrarily compare singleton ints. For now I'm just following Brian's advice from [here](https://github.com/pytorch/pytorch/pull/109171#discussion_r1328137070))
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109171
Approved by: https://github.com/ezyang, https://github.com/bdhirsh
The main thrust of the initial effort here was to capture `register_hook` calls on tensors in compile regions. The first part of this was done in https://github.com/pytorch/pytorch/pull/108903, wherein we added support for `register_hook` on input tensors.
The distinction between inputs and intermediaries is due to implementation differences.
There are 2 kinds of hooks:
1) Hooks on objects with sources (inputs, params)
2) Hooks on objects w/o sources (intermediaries and outputs).
Note: since dynamo's handling of residuals makes outputs simple to deal with, they could actually be handled as if they were inputs; but for the sake of this PR, we will refer to hooks as either hooks on inputs (sourced) or hooks on intermediaries (not sourced).
**The plan:**
For tensors w/ a source: (The PR above)
We record registered hooks, store them as a global, and associate them with the tensor in residuals. This means that when dynamo goes to create the frame, where we produce bytecode to stitch together our PT2-modified bytecode with the original eager code, we call register_hook. This registration of hooks in residuals is sound because (a) it happens right after a PT2 frame region ends and (b) we know that the tensor is alive in f_locals, f_globals, or a module in the user's invoking frame. This means we can soundly know it will be around to invoke register_hook on. As long as we guard on the identity of the lifted function, this is sound to do.
For tensors w/o a source: (This PR)
Ostensibly, the most correct and complete solution would be to smuggle hooks into a runtime wrapper in aot_autograd, where all the items the hooks close over are lifted to inputs as necessary and passed alongside the user provided function. This is necessary so that we can properly trace out and capture all the mutations within the user defined hook at backwards time.
This is too complicated, so we limited the scope of this initial PR to a simple subset of hooks:
- Hooks must have a source (i.e., be known to us already; not a lambda or an intermediary-defined function)
- We must be tracing under compiled autograd
**The flow**:
We use the HOP added in https://github.com/pytorch/pytorch/pull/109690/files, referred to as the HOP below.
1) We intercept register_hook calls and wrap the user defined fn in the HOP
2) We write a `_register_hook_trampoline` to the graph that is a local no-arg function invoked as a `call_function` in the dynamo graph
3) aot_autograd inlines through it during its trace, and sees the HOP
4) the HOP preserves itself in the graph - it does not get traced into
5) During backwards, compiled_autograd installs the HOP under a hook call
6) When compiled_autograd enters compilation over its generated graph, dynamo traces the contents of the hook
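Putting the flow together, a hedged sketch of the user-facing pattern this enables (the compiled-autograd enabling API shown here is assumed and version-dependent):
~~~python
import torch

def scale_grad(grad):
    # The hook must have a source (a named function, not a lambda or a
    # function defined inside the compiled region).
    return grad * 2

@torch.compile
def fn(x):
    y = x * x                    # y is an intermediary: it has no source
    y.register_hook(scale_grad)  # intercepted and wrapped in the HOP
    return y.sum()

x = torch.randn(4, requires_grad=True)
# Backward must run under compiled autograd so the HOP is installed and
# later traced by dynamo when the generated graph is compiled.
with torch._dynamo.compiled_autograd.enable(
    lambda gm: torch.compile(gm, backend="eager")
):
    fn(x).backward()
~~~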
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109537
Approved by: https://github.com/ezyang
We want to get to a point where most `UserError`s link to `exportdb` examples. This PR makes passing case names non-optional to make this intent clearer and encourage developers who raise `UserError`s to make or point to examples that make fixing such errors more obvious for users.
In addition, sometimes there are multiple examples that are relevant to an error. Thus this PR also enables passing multiple case names.
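A hedged sketch of what raising such an error looks like (the message and case name here are illustrative, and the exact parameter for passing multiple case names may differ):
~~~python
from torch._dynamo.exc import UserError, UserErrorType

raise UserError(
    UserErrorType.DYNAMIC_CONTROL_FLOW,
    "Detected data-dependent control flow; consider rewriting with torch.cond",
    case_name="cond_operands",  # now required; links the error to exportdb
)
~~~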
Retry of #110733 which was reverted due to a landrace.
Differential Revision: [D50087148](https://our.internmc.facebook.com/intern/diff/D50087148/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110878
Approved by: https://github.com/gmagogsfm, https://github.com/tugsbayasgalan
The motivation for removing this is already present in the pre-PR comments; copying it here:
~~~
# NB - SuperSource is a weird one.
# it is our only source with 2 bases, so we use the object
# as the base, rather than the type, since an invocation
# like super(Foo, foo) is represented here, the source object base is more spiritually
# aligned with the instance, rather than the type.
# This whole construction is questionable tho, and we should probably find a way to
# avoid this exception to our otherwise nice source parentage invariant.
~~~
Instead of using `super(a, b)`, we can use `type(b).__mro__[index]`.
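An illustrative sketch of the rewrite (class names are hypothetical):
~~~python
class Base:
    def greet(self):
        return "base"

class Child(Base):
    def greet(self):
        # Equivalent to super(Child, self).greet(), but expressed through
        # the MRO: find Child in type(self).__mro__ and take the next entry.
        mro = type(self).__mro__
        parent = mro[mro.index(Child) + 1]
        return parent.greet(self)

print(Child().greet())  # prints "base"
~~~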
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110475
Approved by: https://github.com/jansel
This PR exposes `torch._higher_order_ops.cond` as `torch.cond`.
1. We need to add `# noqa: F811` to the `_check` calls in `torch/__init__.py` to suppress a confusing linter error, "Redefinition of unused 'cond'". Only one `cond` is imported, and the flagged lines don't define `cond`; they merely take it as an argument (illustrated after this list).
2. Also add `cond` to the list of functions that dynamo is allowed to trace through, so that dynamo triggers the `CondHigherOrder` logic instead of creating a `TorchVariable`.
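A simplified illustration of why the linter is confused (not the actual lines from `torch/__init__.py`):
~~~python
from torch._higher_order_ops.cond import cond

def _check(cond, message=None):  # noqa: F811
    # The linter reports "Redefinition of unused 'cond'", but the
    # parameter only shadows the import; it does not redefine it.
    ...
~~~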
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110293
Approved by: https://github.com/zou3519
We want to get to a point where most `UserError`s link to `exportdb` examples. This PR makes passing case names non-optional to make this intent clearer and encourage developers who raise `UserError`s to make or point to examples that make fixing such errors more obvious for users.
In addition, sometimes there are multiple examples that are relevant to an error. Thus this PR also enables passing multiple case names.
Differential Revision: [D50020465](https://our.internmc.facebook.com/intern/diff/D50020465/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110733
Approved by: https://github.com/zhxchen17
Summary: The runtime assertions inserted by the `_AddRuntimeAssertionsForInlineConstraintsPass` in `torch._export.export` lead to errors in AOT Inductor like #109884. In `torch._export.aot_compile`, export and AOT compilation are run consecutively, which would lead to the above issue if any assertions are inserted.
In this PR, we're adding a new parameter / flag to `torch._export.aot_compile`, `remove_runtime_assertions`, to remove the assertions inserted during export before AOT compilation. The flag is set to `False` for BC.
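A hedged sketch of the intended usage (the function and example args are illustrative, and the surrounding `aot_compile` signature may differ across versions):
~~~python
import torch
import torch._export

def f(x):
    return x.relu()

example_args = (torch.randn(4, 8),)

# remove_runtime_assertions=True (added in this PR) strips the assertions
# inserted during export before AOT compilation; it defaults to False for BC.
so_path = torch._export.aot_compile(
    f, example_args, remove_runtime_assertions=True
)
~~~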
Additionally, we remove the flag `add_runtime_assertions_for_inline_constraints` recently added to `torch._dynamo.config`, as it can lead to undesirable `torch._export` behavior and is no longer required for AOT Inductor testing purposes.
Test Plan: CI
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110710
Approved by: https://github.com/zhxchen17, https://github.com/chenyang78
Ideally all `_dynamo.exc.UserError`s should have "case names", i.e., link to examples in `exportdb`.
This PR adds case names to several instances of `_dynamo.exc.UserError`. In particular, looking at coverage based on `UserErrorType`:
* `DYNAMIC_CONTROL_FLOW`, `ANTI_PATTERN`, and `STANDARD_LIBRARY` are fully covered.
* `CONSTRAINT_VIOLATION` and `DYNAMIC_DIM` have no coverage. We don't seem to have any dedicated examples of specifying dynamic shapes in `exportdb` (although dynamic shapes are used in some other examples without explanation, to avoid specialization that would make such examples moot); see the sketch after this list.
* `INVALID_INPUT` is only partly covered. Frankly this is tedious to cover via examples.
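For reference, a hedged sketch of specifying a dynamic dimension with the constraint API of this era (since superseded by `dynamic_shapes`); violating such a constraint is what raises `CONSTRAINT_VIOLATION`/`DYNAMIC_DIM` errors:
~~~python
import torch
from torch._export import dynamic_dim, export

def f(x):
    return x.sum(dim=0)

x = torch.randn(4, 8)
# Mark dim 0 of x as dynamic and require it to be greater than 2; an
# incompatible input or guard raises a CONSTRAINT_VIOLATION / DYNAMIC_DIM
# UserError at export time.
ep = export(f, (x,), constraints=[dynamic_dim(x, 0) > 2])
~~~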
Differential Revision: [D49928518](https://our.internmc.facebook.com/intern/diff/D49928518/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110555
Approved by: https://github.com/angelayi, https://github.com/ydwu4
When mapping between the original signature of a program and the graph-captured signature of its exported program, we emit errors when we see unexpected original or graph-captured inputs or outputs.
These errors can arise because of various reasons, e.g.:
1. some input or output has been lifted because of mutation
2. some type is not pytree-registered for flattening / unflattening
3. some type cannot be realized with graph operations
(This is probably not an exhaustive list.)
Previously, we emitted errors based on a vanilla id-based membership check between the two sides, mostly anticipating (1) as the reason for errors. But this does not do justice to errors caused by (2) or (3).
This PR emits a different error when it finds (3) to be a probable cause. Specifically, it considers only Tensor and Sym* types to be "supported": no other type seems to be realizable by graph operations.
When (2) is a probable cause, we sometimes also hit the same error, because we expect the supported types to show through upon registration. But this kind of error may need some more work in the future.
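For case (2), a hedged sketch of registering a custom container type with pytree so its tensors show through in the graph-captured signature (the type is illustrative, and the registration helper's name has varied across versions):
~~~python
from dataclasses import dataclass

import torch
import torch.utils._pytree as pytree

@dataclass
class Pair:
    a: torch.Tensor
    b: torch.Tensor

# Once registered, export can flatten a Pair into its tensor leaves and
# unflatten graph outputs back into a Pair.
pytree.register_pytree_node(
    Pair,
    lambda p: ((p.a, p.b), None),           # flatten: (children, context)
    lambda children, ctx: Pair(*children),  # unflatten
)
~~~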
Differential Revision: [D49885828](https://our.internmc.facebook.com/intern/diff/D49885828/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110472
Approved by: https://github.com/ydwu4
Triplet margin loss takes a Callable `distance_function` parameter, which is not supported as an argument on the FX graph. See the previous error:
~~~
  File "/scratch/eellison/work/pytorch/torch/_dynamo/symbolic_convert.py", line 562, in call_function
    self.push(fn.call_function(self, args, kwargs))
  File "/scratch/eellison/work/pytorch/torch/_dynamo/variables/torch.py", line 723, in call_function
    *proxy_args_kwargs(args, kwargs),
  File "/scratch/eellison/work/pytorch/torch/_dynamo/utils.py", line 504, in proxy_args_kwargs
    f"call_function args: {typestr(*args)} {typestr(*list(kwargs.values()))}"
  File "/scratch/eellison/work/pytorch/torch/_dynamo/exc.py", line 143, in unimplemented
    raise Unsupported(msg)
torch._dynamo.exc.Unsupported: call_function args: TensorVariable() TensorVariable() TensorVariable() ConstantVariable(float) NNModuleVariable()
~~~
This is fixable by just inlining into `triplet_margin_loss` and continuing to compile it. This required support for `has_torch_function_variadic`.
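A hedged sketch of the pattern that now compiles (the custom distance function is illustrative; the functional form is used here, though the module form hits the same path):
~~~python
import torch
import torch.nn.functional as F

def pairwise_l2(x, y):
    # A Callable distance_function, previously unsupported as an fx graph
    # argument; dynamo now inlines through the loss function instead.
    return torch.norm(x - y, dim=-1)

@torch.compile
def loss_fn(anchor, positive, negative):
    return F.triplet_margin_with_distance_loss(
        anchor, positive, negative, distance_function=pairwise_l2
    )

a, p, n = (torch.randn(8, 16) for _ in range(3))
print(loss_fn(a, p, n))
~~~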
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110302
Approved by: https://github.com/mlazos