At a high level, the current implementation of the constraint functions (`constrain_as_*`) will raise an exception for the following code snippet:
```
import torch
from torch._export.constraints import constrain_as_size  # assumed import path at the time of this change

def f(x):
    a = x.item()
    constrain_as_size(a, 4, 7)
    return torch.empty((a, 4))

inp = torch.tensor([5])
ep = torch._export.export(f, (inp,))
```
The reason is that the current constraint logic:
1) Is purely Python, so it won't survive AOT export (the node is gone after AOT export, since AOT export only keeps aten-level ops).
2) Relies on a side effect to add range constraints to the traced symbol's shape env ([code](9591e52880/torch/fx/experimental/symbolic_shapes.py (L370-L372))).
3) If runtime assertions are turned on (the default), [`_AddRuntimeAssertionsForConstraintsPass`](9591e52880/torch/_export/passes/add_runtime_assertions_for_constraints_pass.py (L98-L100)) will try to append assertion nodes based on the range constraints extracted from each symbol's shape env during another interpretation round.
4) However, because of 1), the range-constraint logic does not run for symbols generated during the AOT export round, so no range-constraint information is available for the assertion round, which causes the issue.
5) As a result, export fails at `torch.empty((a, 4))` (there is no constraint recording that `a` must be positive).
The fix here is to implement the range-constraint logic as a native aten op (with a no-op CPU implementation) so that it survives AOT export.
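A simplified, hedged sketch of the resulting dispatch split (not the actual implementation; the inline range check stands in for `constrain_range_int`, described in the NOTE below):
```python
import torch

def _constrain_range_sketch(a, *, min, max):
    if not isinstance(a, torch.SymInt):
        # Plain Python int: the C++ op would be a no-op, so validate eagerly.
        if not (min <= a <= max):
            raise ValueError(f"{a} is outside the range [{min}, {max}]")
        return
    # SymInt: emit the aten-level op so the constraint survives AOT export.
    torch.ops.aten.sym_constrain_range(a, min=min, max=max)
```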
**NOTE:**
The [logic](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (L350-L365C15)) within [`constrain_range`](2d745b95d7/torch/fx/experimental/symbolic_shapes.py (LL313C74-L313C74)) is split out as `constrain_range_int` to handle the case where a non-`SymInt` is passed in, and it is reused in the new `_constrain_range`. The reason is that when a non-`SymInt` is provided:
* If `sym_constrain_range` is called directly, the C++ version runs, which is a no-op.
* So in this case `constrain_range_int` is called instead, so that we can still catch issues such as the user providing an input whose tensor shape is out of range during export, as in the following variation of the example above:
```
...
inp = torch.tensor([10])
ep = torch._export.export(f, (inp,)) # immediately raise error
```
Differential Revision: [D46734204](https://our.internmc.facebook.com/intern/diff/D46734204)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103346
Approved by: https://github.com/tugsbayasgalan
Fixes #95900
Using the following repro as a guide:
```python
import torch
import torch._dynamo
from torch._subclasses import fake_tensor
from torch.fx.experimental.symbolic_shapes import ShapeEnv
from torch._dynamo.output_graph import config

class Model(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.linear = torch.nn.Linear(2, 2)
        self.linear2 = torch.nn.Linear(2, 2)

    def forward(self, x):
        out = self.linear(x)
        out = self.linear2(out)
        return out

fake_mode = fake_tensor.FakeTensorMode(
    allow_non_fake_inputs=False,
    allow_fallback_kernels=True,
    shape_env=ShapeEnv(
        allow_scalar_outputs=config.capture_scalar_outputs,
        allow_dynamic_output_shape_ops=config.capture_dynamic_output_shape_ops,
        frame_id=0,
    ),
)

# Fakifying input/model before calling torch._dynamo.export
with fake_mode:
    fake_x = torch.rand(5, 2, 2)
    model = Model()

# Calling torch._dynamo.export without an active fake mode
graph_module, guards = torch._dynamo.export(
    model,
    fake_x,
    aten_graph=True,
    fake_mode=fake_mode,
)
graph_module.print_readable()
graph_module.graph.print_tabular()
```
Summary of changes:
* Plumb fake_mode through the torch.export API. When specified, it replaces the creation of a new FakeTensorMode at InstructionTranslator on behalf of OutputGraph.
* Hack FakeTensor.__new__ to prevent a torch.tensor._make_subclass call for inputs that are already fakified by the user. This probably needs to be fixed in a nicer way. Any ideas?
* Removed a few asserts that rejected fake tensors coming from the user script.
* Added torch._subclasses.fake_tensor.FakeTensor to the type list on a few assert checks to allow fake inputs.
The changes above allowed symbolic tracing with both static and dynamic shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100017
Approved by: https://github.com/ezyang
Workaround for https://github.com/pytorch/pytorch/issues/102886
Related to: https://github.com/pytorch/pytorch/issues/102476, https://github.com/pytorch/pytorch/issues/102475, https://github.com/pytorch/pytorch/issues/102474, https://github.com/pytorch/pytorch/issues/102473, https://github.com/pytorch/pytorch/issues/102472
Since 9aaa12e328, the first inductor (CPU) UT fails until the GPU context is correctly initialised, after which the subsequent UTs pass. CUDA observed the same issue, and a workaround was pushed to force initialisation of the CUDA context by declaring an empty tensor (https://github.com/pytorch/pytorch/issues/92627). We have adopted the same approach but opted for `torch.zeros`, which correctly activates the HIP context after the kernel launch.
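A minimal sketch of the workaround as described above (the guard shown is illustrative, not the exact patch):
```python
import torch

# torch.empty only allocates and does not launch a kernel, so on ROCm it does
# not initialise the primary (HIP) context; torch.zeros launches a fill kernel
# and does. Run it once up front so the first inductor UT sees a live context.
if torch.version.hip is not None and torch.cuda.is_available():
    torch.zeros(1, device="cuda")
```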
**Reproducer:**
```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
import argparse

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Swap between torch.empty and torch.randn operations.')
    parser.add_argument('--empty', action='store_true', help='Use torch.empty operation')
    parser.add_argument('--rand', action='store_true', help='Use torch.randn operation')
    args = parser.parse_args()

    torch.cuda.set_device(0)
    if args.empty:
        torch.empty(1, device="cuda")
    elif args.rand:
        torch.rand(1, device="cuda")

    print(f"0: hasPrimaryContext: {torch._C._cuda_hasPrimaryContext(0)}")

    with FakeTensorMode():
        p = torch.randn(4, 2, requires_grad=True, device='cuda')
        x = torch.randn(8, 4, device='cuda')
        y = torch.mm(x, p).square().sum()
        y.backward()
```
**ROCm python repro.py --empty**
0: hasPrimaryContext: False
**ROCm python repro.py --rand**
0: hasPrimaryContext: True
**CUDA python repro.py --empty**
0: hasPrimaryContext: True
**CUDA python repro.py --rand**
0: hasPrimaryContext: True
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103149
Approved by: https://github.com/eellison
- Add get_printoptions and printoptions context manager
- Improve edgeitems handling when it is zero (see the example below)
- Add render_call which can be used to conveniently print command
line arguments of a function call, while suppressing actual
tensor data
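A small, self-contained illustration of the edgeitems bullet (behaviour when summarization triggers with edgeitems=0):
```python
import torch

# With summarization triggered and edgeitems=0, no leading/trailing elements
# are shown, only the ellipsis.
torch.set_printoptions(threshold=5, edgeitems=0)
print(torch.arange(100))
torch.set_printoptions(profile="default")  # restore the defaults
```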
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102623
Approved by: https://github.com/albanD
FakeTensor doesn't normalize device_idx and fails with the test case below.
```
import torch
import habana_frameworks.torch.hpu
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode.push():
    a = torch.empty(1, device="hpu")
    b = torch.empty(1, device="hpu:0")
    result = a + b
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102512
Approved by: https://github.com/albanD
When investigating failures in https://github.com/pytorch/pytorch/pull/100017 I realized that we were re-entering FakeTensorMode even though there was already one on the stack. Although we have tried to assert on these cases in the past (e.g., in https://github.com/pytorch/pytorch/pull/97186), it seems the existing protections were insufficient.
In this particular case, the reapplication of FakeTensorMode was due to an interaction with the NotImplemented multiple-dispatch handling. If proxy tensor mode detects an unrecognized tensor type (this includes FakeTensor, if it is not tracked with a proxy), it returns NotImplemented to give this tensor a chance to unpack itself into proxyable operations. However, this is never the right thing for FakeTensor, where no unpacking is possible. Yet today FakeTensor attempts to reapply the FakeTensorMode, resulting in FakeTensorMode being on the stack twice.
This PR does a number of things:
* It adds an assert in `FakeTensorMode.__torch_dispatch__` that you must not already have this mode on the stack, this is ALWAYS an error
* It modifies `FakeTensor.__torch_dispatch__` to return `NotImplemented` if the mode is already active. This prevents us from re-adding the mode to the stack
* It adds a new logging artifact `not_implemented` which you can use to get debug logs about all of the times a `__torch_dispatch__` handler returned NotImplemented and why it did so. Your subclass has to manually opt into this logging, but I inserted the necessary logs for ProxyTensorMode and FakeTensor(Mode)
* `with fake_mode` now no-ops if the fake mode is already on the stack, which is what users want anyway (see the minimal example after this list)
* I am BREAKING pre-autograd tracing, because it is currently doing something weird with the original C++ mode stack. Brian is going to follow up with a fix next week.
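A minimal illustration of the new re-entry behaviour (assuming the post-PR semantics described in the bullets above):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

fake_mode = FakeTensorMode()
with fake_mode:
    with fake_mode:  # mode is already on the stack: this is now a no-op, not a double push
        x = torch.empty(2, 2)
print(type(x))  # <class 'torch._subclasses.fake_tensor.FakeTensor'>
```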
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102091
Approved by: https://github.com/thiagocrepaldi, https://github.com/eellison, https://github.com/wanchaol, https://github.com/bdhirsh
torch/custom_op.py is getting long, and the autograd pieces are going to
make it even longer. I'm planning on just organizing the files under
a torch/_custom_op folder.
Note that the imports now look a bit crazy (from torch._custom_op.impl
import...) but they will look more OK when we figure out the plan to
make custom_op public (coming later).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101823
Approved by: https://github.com/ezyang, https://github.com/albanD, https://github.com/bdhirsh
FakeTensor has default device logic that wraps meta tensors to the right device after running meta kernels and throws on multiple devices. This logic only ran on the wrapping from meta kernels -> fake. For out variants, where the output of the meta kernel was already a fake tensor (because it was an input), the device logic wasn't running.
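A minimal illustration of the out-variant path described above (a sketch of the scenario, not of the fix):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(4)
    out = torch.empty(4)
    # `out` is already a fake tensor (it is an input), so the meta kernel's
    # result needs no meta -> fake wrapping; the device logic must still run here.
    torch.add(x, 1, out=out)
    print(out.device)  # cpu
```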
Pull Request resolved: https://github.com/pytorch/pytorch/pull/101807
Approved by: https://github.com/ngimel
Previously the error message went through torch.library. This PR changes
it so that on each custom_op.impl_* call:
- we store a (function, location) tuple
- if a (function, location) tuple exists already, then we raise an
error.
This logic already existed for the abstract impl (the impl for meta and
fake tensors), so this PR just extends it to the others.
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/100979
Approved by: https://github.com/bdhirsh, https://github.com/soulitzer
This PR:
- adds an abstract registration API for CustomOp (CustomOp.impl_abstract)
that is used for both FakeTensor and meta tensors
- deletes CustomOp.impl_meta
The user story behind this API is that it is the one-stop shop for
registering implementations for data-less Tensors, i.e. FakeTensor and
Meta tensor.
The abstract implementation provided by the user:
- gets registered as the FakeTensor implementation AND the meta formula
- can be written like a regular meta formula. If the user decides that they need something more special (e.g. a data-dependent output shape), they can query a current context object (FakeTensorImplCtx) that has methods to construct new unbacked symints (see the sketch after this list).
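A hedged sketch of an abstract impl with a data-dependent output shape, modeled on the numpy_nonzero test operator mentioned below; `mylib::my_nonzero` is a hypothetical op, and the `get_ctx` accessor name and import path are assumptions about this private API:
```python
import torch
from torch._custom_op.impl import custom_op, get_ctx  # assumed import path

@custom_op("mylib::my_nonzero")
def my_nonzero(x: torch.Tensor) -> torch.Tensor:
    ...

@my_nonzero.impl_abstract()
def my_nonzero_abstract(x):
    ctx = get_ctx()                      # FakeTensorImplCtx
    nnz = ctx.create_unbacked_symint()   # output length depends on the data
    return x.new_empty((nnz, x.dim()), dtype=torch.long)
```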
Caveats:
- we really need to make FakeTensor/FakeTensorMode public. Otherwise,
there isn't a way for the user to interactively test that their abstract
implementation is correct without running through large pieces of the
PT2 stack (make_fx or torch.compile).
- We do not memoize the symints produced by
ctx.create_unbacked_symint(). It is possible to do this in the
future, but it is difficult to do soundly and I am not convinced of
the utility outside of the nonzero() use case mentioned in #95399
Public API:
- More docs will come when we actually expose this API to users by
putting it in a public namespace, unless you folks want it now.
- The APIs mentioned in `__all__` are the ones that are intended to be
public.
Test Plan:
- Updated existing custom_op_db operators
- Added new numpy_nonzero and numpy_nms operations that test operations
that have data-dependent output shapes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99439
Approved by: https://github.com/ezyang
I got too confused by the FakeTensor printing, so this PR fixes it to
print normally.
Before:
```
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.empty(2, 2, device="cpu")
    print(x)
    # FakeTensor(FakeTensor(..., device='meta', shape=(2, 2)), cpu)
```
After (Tensor printing doesn't print the default device):
```
FakeTensor(..., shape=(2, 2))
```
Test Plan:
- new test
Pull Request resolved: https://github.com/pytorch/pytorch/pull/99205
Approved by: https://github.com/eellison
Summary: This fixes the case when some of the input tensors were
real tensors and fakified in `validate_and_convert_non_fake_tensors`,
but `flat_arg_fake_tensors` would not contain all the inputs
because it was computed before the fakification. We fix this by
recomputing `flat_arg_fake_tensors` after fakification as well.
Test Plan:
python test/dynamo/test_export.py ExportTests.test_mixed_real_and_fake_inputs
Reviewers: Chillee, voznesenskym
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98769
Approved by: https://github.com/voznesenskym
This was left over from when we had more logic in FakeTensor rather than FakeTensorMode, and it wasn't firing correctly. It also makes more sense for it to be in the other validation function.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/97186
Approved by: https://github.com/bdhirsh
This replaces fake_mode_from_tensors, but it preferentially looks for a fake_mode in TracingContext, and then for an active fake mode on the dispatch stack, before groveling through tensors to find one.
This advances PegasusForCausalLM, which was previously failing because we generated a graph that had a parameter (non-fake) and a SymInt, and thus we failed to detect the correct fake mode.
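A hedged usage sketch of the new lookup order; the helper name `detect_fake_mode` and its location in `torch._guards` are assumptions about where this landed:
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._guards import detect_fake_mode  # assumed name/location of the replacement helper

real_param = torch.nn.Parameter(torch.randn(3))  # non-fake, like the PegasusForCausalLM case
fake_mode = FakeTensorMode()
with fake_mode:
    # No fake tensors among the inputs, but the active mode on the dispatch
    # stack should still be found before groveling through the tensors.
    assert detect_fake_mode([real_param, 4]) is fake_mode
```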
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/98321
Approved by: https://github.com/voznesenskym
The purpose of this API is to execute a few large components of work:
1) Refactor all the internals of plumbing dynamic dimension information after dynamo to be stateless
2) Decouple allocation controls around dynamic dimensions from verification
3) For (2), for allocation, create an enum that dictates whether we are in DUCK (the default today), STATIC (aka assume_static_by_default in the past), or DYNAMIC (user constrained, do not duck shape) mode; see the sketch after this list
4) For (2), for verification, we separate the list of dynamic ranges entirely from allocation. This means the shape_env does no tracking of what we verify on; instead, it is the caller's job to invoke produce_guards() with the various things they want verified, specifically the valid ranges. We do use constrained ranges to refine value ranges when doing analysis.
5) We have therefore decided, as an extension of (4), to double down on "late" checks versus "eager" checks, primarily because the mechanism for gathering what actually matters happens during guard creation, and should be the purview of the caller seeking guards, not the shape env. However, for dynamo, these structures are essentially one and the same.
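A hedged sketch of the allocation-policy enum from (3); the name `DimDynamic` and the member ordering are assumptions about where this landed:
```python
from enum import Enum

# Per-dimension allocation policy, decoupled from verification.
class DimDynamic(Enum):
    DYNAMIC = 0  # user constrained; allocate a fresh symbol, do not duck-shape
    DUCK = 1     # duck shaping (the default today)
    STATIC = 2   # assume static (formerly assume_static_by_default)
```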
Pull Request resolved: https://github.com/pytorch/pytorch/pull/96699
Approved by: https://github.com/avikchaudhuri, https://github.com/ezyang
This removes the need to explicitly constrain_unify `x[mask]` and `y[mask]` when mask is a boolean tensor. It's very narrow but it seems to work in practice.
To invalidate the nonzero call when mutation occurs, I use the version counter. I know there are ways to bypass this, but I think it's good enough for now.
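A minimal illustration of the user pattern this covers (eager code; under tracing, the two indexing results now unify automatically):
```python
import torch

def f(x, y, mask):
    a = x[mask]  # length is data-dependent (an unbacked SymInt when traced)
    b = y[mask]  # same mask, so the same length; previously needed an explicit constrain_unify
    return a + b  # requires a.size(0) == b.size(0)

print(f(torch.randn(5), torch.randn(5), torch.tensor([True, False, True, True, False])))
```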
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95399
Approved by: https://github.com/eellison
This takes the strategy described in https://docs.google.com/document/d/1lFRYAJo5nrfxRhwIzGnfi2pbLpU6T4ytSRSuLJ5qebI/edit#
It is essentially https://github.com/pytorch/pytorch/pull/95222 but squashed and with changes that are unnecessary given that we assume nonzero returns > 1.
What's in the PR:
* nonzero now supports meta propagation. When `capture_dynamic_output_shape_ops` is set, it will return a tensor with an unbacked SymInt representing the size in question (see the sketch after this list).
* The unbacked SymInt is UNSOUNDLY assumed to be not equal to 0/1. We will still error if you guard otherwise.
* PrimTorch pointwise operators are updated to use empty_permuted, to avoid guarding on unbacked SymInt from empty_strided (tested in `test_dynamic_pointwise_scalar`)
* Convolution is updated to skip backend selection if batch is unbacked, to avoid guarding on unbacked SymInt (tested in `test_unbacked_batch_resnet`)
* I kept the helper utilities like `definitely_true` for working with possibly unbacked SymInts. They're not used right now but maybe someone will find them useful.
* Added `constrain_unify` to let you specify two unbacked SymInts must have the same value
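A small sketch of the first bullet under a fake tensor mode that allows dynamic output shape ops (the config plumbing is simplified relative to `capture_dynamic_output_shape_ops`):
```python
import torch
from torch._subclasses.fake_tensor import FakeTensorMode
from torch.fx.experimental.symbolic_shapes import ShapeEnv

with FakeTensorMode(shape_env=ShapeEnv(allow_dynamic_output_shape_ops=True)):
    x = torch.randn(8)
    idx = torch.nonzero(x)
    # idx.size(0) is an unbacked SymInt, unsoundly assumed to be neither 0 nor 1
    print(idx.shape)
```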
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/95387
Approved by: https://github.com/voznesenskym
I believe this fixes the AllenaiLongformerBase problem in periodic.
The longer version of the problem: we are currently optimistically converting all item() calls into unbacked SymInt/SymFloat, but sometimes this results in a downstream error due to a data-dependent guard. Fallbacks for this case are non-existent; this will just crash the model. This is bad. So we flag-guard it until we get working fallbacks.
What could these fallbacks look like? One idea I have is to optimistically make data-dependent calls unbacked, but then, if that results in a crash, restart Dynamo analysis with the plan of graph breaking as soon as the item() call happens.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94987
Approved by: https://github.com/Skylion007, https://github.com/malfet
This PR removes the unnecessary == 0 guard when constructing empty tensors, by ensuring that when we create a contiguous tensor we go directly to the C++ torch.empty implementation (instead of indirecting through empty_strided), where we can bypass doing zero tests when computing the size of the storage. This probably also speeds up trace time.
When I did this, I found out that `empty_tensor_restride_symint` was flagrantly wrong (we had never exercised it before because we redirected to `empty_strided` in PrimTorch decomp, which doesn't hit this codepath.) The bugs:
* Stride computation was wrong (only `last_idx` was ever written to)
* Using set_sizes_and_strides with `sym_sizes` input doesn't work, because there is some sort of ordering problem where `clone_symvec` isn't safe when you clone a vector into itself. Probably should fix this.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94512
Approved by: https://github.com/ngimel
Fast path execution of a few binary ops in fake tensor, to speed up trace time. When testing `python benchmarks/dynamo/timm_models.py --accuracy --timing --backend aot_eager --dynamic-shapes --float32 --only hrnet_w18`, I get the following trace speedup.
Before:
```
cuda eval hrnet_w18 PASS
TIMING: entire_frame_compile:53.97591 backend_compile:33.60832
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:89985 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```
After:
```
cuda eval hrnet_w18 PASS
TIMING: entire_frame_compile:40.18931 backend_compile:25.28828
STATS: call_* op count: 1369 | FakeTensor.__torch_dispatch__:4995 | FakeTensorMode.__torch_dispatch__:69478 | attempt fast:4399 | fast is_contiguous:4399 | ProxyTorchDispatchMode.__torch_dispatch__:3010
```
My experiment notebook can be found at https://docs.google.com/document/d/1_dTIQUwjIVnEWmiFAavJQYVF8uzXqD9Dk6b9gGQLF_U/edit#
This is not the "most" optimized version of the code; compared with Horace/Voz roofline experiment:
```
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index e3bf545f3b8..395942c6ffe 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -774,6 +774,10 @@ class FakeTensorMode(TorchDispatchMode):
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
kwargs = kwargs if kwargs else {}
+ with no_dispatch():
+ if func in {aten.mul.Tensor, aten.add.Tensor, aten.sub.Tensor, aten.relu.default}:
+ return FakeTensor(self, torch.empty(args[0].shape, device='meta'), device='cuda')
+
if func == torch.ops.prim.device.default:
assert len(args) == 1 and isinstance(args[0], FakeTensor)
if args[0].fake_mode.in_kernel_invocation:
```
I am still leaving about 5s of trace time improvement on the table (3s of which is attributable to not yet handling relu.)
The implementation here is based off of https://github.com/pytorch/pytorch/pull/93118/ but I modeled the short circuit logic off of TensorIterator's implementation, for ease of code review and correctness verification. However, there are some important divergences:
* Traditional fast setup in TensorIterator only short circuits if the shapes of all input elements are equal. On hrnet_w18, only 5% of fastpath'ed binary operators actually satisfy this. So instead, I compute the broadcasted shape, but then I only allow the fast path if (1) at least one input tensor has a shape that is exactly the output size, and (2) all the tensors are contiguous (or if all the tensors are channels last).
* I had to manually adjust the logic to handle wrapped numbers (which ordinarily are handled by wrapping into tensors). I think I got this right.
Some evidence that this heuristic is correct is here in: https://gist.github.com/ezyang/b22fa7b72b7349137211d8dc7041f758 I exhaustively test all dim=3 tensors with sizes [1, 2] and show that we get the same significant strides between PrimTorch and the new algorithm. In fact, there ARE differences between this algorithm and PrimTorch, but in fact this algorithm agrees with TensorIterator where PrimTorch is wrong (sample case: size=(1, 1, 2), stride=(1, 1, 1), stride=(1, 1, 1))
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/94047
Approved by: https://github.com/eellison