Colin Peppler 13a25c647f [export] improve binary op fast path broadcast check (#121546)
# Context
I believe we have an incorrect guard being created during FakeTensor's binary op fast path.

Consider this case
```
# op.shape: (10, 192); final_shape: (s0, 10, 192)
# Guard Ne(s0, 10) is created when we create SymBool(10 == s0)
if isinstance(op, torch.Tensor) and op.shape == final_shape:
    break
```

As of right now, `op.shape == final_shape` checks whether one of the binary op's operand shapes is the same as the binary op's output shape.
* If one of the shapes is dynamic, then the comparison creates a guard via `SymBool` creation (i.e. `s0 == 10`).
* If the `SymBool` expr resolves to `False`, then we'll create the guard `Ne(s0, 10)`.

This is a problem when the number of dimensions differs between `op.shape` and `final_shape`. Take the case above: `op.shape: (10, 192); final_shape: (s0, 10, 192)`. Although the shapes aren't the same, that doesn't necessarily mean `s0 != 10` -- the comparison already fails on rank alone, so no guard on `s0` is warranted.
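To see why rank mismatch alone shouldn't guard on `s0`: a lower-rank operand broadcasts into the output shape regardless of the leading symbolic dim. Below is a minimal sketch of the right-aligned broadcasting rule in plain Python (the helper `broadcastable_to` is hypothetical, not the FakeTensor implementation):

```python
def broadcastable_to(op_shape, final_shape):
    # Broadcasting aligns shapes from the right; a trailing dim is
    # compatible if it equals the target dim or is 1. Missing leading
    # dims are implicitly filled in, so they never constrain the target.
    if len(op_shape) > len(final_shape):
        return False
    for op_dim, final_dim in zip(reversed(op_shape), reversed(final_shape)):
        if op_dim != final_dim and op_dim != 1:
            return False
    return True

# (10, 192) broadcasts into (s0, 10, 192) for ANY value of s0 --
# including s0 == 10 -- so guarding Ne(s0, 10) is unnecessary.
print(broadcastable_to((10, 192), (5, 10, 192)))    # True
print(broadcastable_to((10, 192), (10, 10, 192)))   # True
print(broadcastable_to((9000, 192), (5, 10, 192)))  # False
```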

Some thoughts (feel free to ignore). What if the number of dimensions is equal but one of the shapes contains symbols? Here are three cases:
  1. `op.shape: (9000, 10, 192); final_shape: (s0, 10, 192)` -- not broadcastable.
  2. `op.shape: (1, 10, 192); final_shape: (s0, 10, 192)` -- 0/1 specialization wins?
  3. `op.shape: (100, 10, 192); final_shape: (s0, 10, 192) where s0 = 100` -- Ask user to mark `s0` as a constant.
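One conservative way to sidestep all three cases is to take the fast path only when the dim-by-dim comparison cannot touch a symbol, deferring to the regular path otherwise so no guard gets created. A sketch in plain Python, with a hypothetical `Sym` stand-in for a symbolic size (this is not the actual `SymInt` machinery):

```python
class Sym:
    """Hypothetical stand-in for a symbolic size like s0."""
    def __init__(self, name):
        self.name = name

def fast_path_ok(op_shape, final_shape):
    # Rank mismatch: shapes can't be equal, and no per-dim
    # comparison (hence no guard) is needed to know that.
    if len(op_shape) != len(final_shape):
        return False
    for a, b in zip(op_shape, final_shape):
        if isinstance(a, Sym) or isinstance(b, Sym):
            # A symbol is involved: comparing would create a guard,
            # so bail out to the slow path instead.
            return False
        if a != b:
            return False
    return True

s0 = Sym("s0")
print(fast_path_ok((100, 10, 192), (s0, 10, 192)))  # False: slow path, no guard
print(fast_path_ok((10, 192), (s0, 10, 192)))       # False: rank differs
print(fast_path_ok((10, 192), (10, 192)))           # True: safe fast path
```

The trade-off is losing the fast path in case 3 (`s0 = 100`), where the shapes do match at runtime; asking the user to mark `s0` constant recovers it.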

# Test
```
$ TORCHDYNAMO_VERBOSE=1 PYTORCH_TEST_WITH_DYNAMO=1 pytest -s test/dynamo/test_dynamic_shapes.py -k test_export_fast_binary_broadcast_check_dynamic_shapes

torch.fx.experimental.symbolic_shapes.ConstraintViolationError: Constraints violated (dim0)! For more information, run with TORCH_LOGS="+dynamic".
  - Not all values of dim0 = L['a'].size()[0] in the specified range 3 <= dim0 <= 1024 satisfy the generated guard Ne(L['a'].size()[0], 3).
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/121546
Approved by: https://github.com/aakhundov
2024-03-09 01:49:42 +00:00