This word appears often in class descriptions and is not consistently spelled. Update comments and some function names to use the correct spelling consistently. Facilitates searching the codebase.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155944
Approved by: https://github.com/Skylion007
Motivation: for the following script:
```
// demo.py
import torch
import json
from transformers import BertModel, BertConfig
CONFIG = """
{
"architectures": [
"BertForMaskedLM"
],
"attention_probs_dropout_prob": 0.1,
"gradient_checkpointing": false,
"hidden_act": "gelu",
"hidden_dropout_prob": 0.1,
"hidden_size": 768,
"initializer_range": 0.02,
"intermediate_size": 3072,
"layer_norm_eps": 1e-12,
"max_position_embeddings": 512,
"model_type": "bert",
"num_attention_heads": 12,
"num_hidden_layers": 12,
"pad_token_id": 0,
"position_embedding_type": "absolute",
"transformers_version": "4.6.0.dev0",
"type_vocab_size": 2,
"use_cache": true,
"vocab_size": 30522
}
"""
config = json.loads(CONFIG)
bloom_config = BertConfig(**config)
model = BertModel(bloom_config).half().cuda()
torch.compiler.reset()
torch.cuda.empty_cache()
compiled_fn = torch.compile(model)
vocab_size = 30522
for b in range(1, 3):
for s in range(1, 10):
print(f"🚀 {b} {s}")
input_ids = torch.randint(0, vocab_size, (b, s)).cuda()
attention_mask = torch.ones(b, s).cuda()
with torch.no_grad():
out = compiled_fn(input_ids, attention_mask).last_hidden_state
```
when we run it with:
```
time TORCH_LOGS=recompiles python demo.py
```
We can see there are 7 recompilations and it takes 2 mins (fresh build) or 1 min (cached build) in my machine.
One root cause of the recompilations is, there are guards to check the alignments of the inputs (see the patch). So there are unexpected recompilations for `(1, 4)`, `(1, 8)`, `(2, 4)` and `(2, 8)` inputs.
In this patch, we always try to always pad the inputs if we don't know its shape at compilation to avoid the guards on alignment. It is fine to always pad the tensor. It won't change the semantics.
Now there are only 3 recompilations and it takes 1 min (fresh build) and 17s (cached build) in my machine.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150403
Approved by: https://github.com/drisspg
**MOTIVATION**
We recently integrated support for Intel Gaudi devices (identified as 'hpu') into the common_device_type framework via the pull request at https://github.com/pytorch/pytorch/pull/126970. This integration allows tests to be automatically instantiated for Gaudi devices upon loading the relevant library. Building on this development, the current pull request extends the utility of these hooks by adapting selected CUDA tests to operate on Gaudi devices. Additionally, we have confirmed that these modifications do not interfere with the existing tests on CUDA devices.
Other accelerators can also extend the functionality by adding the device in the devices list. ( For eg: xpu )
**CHANGES**
Create a separate class for test functions running on CUDA devices
Extend the functionality of these tests to include HPUs
Use instantiate_device_type_tests with targeted attributes to generate device-specific test instances within the new classes
Apply skipIfHPU decorator to bypass tests that are not yet compatible with HPU devices
Previously we had submitted some changes in https://github.com/pytorch/pytorch/pull/140131 . However, deleted that PR due to merge conflicts and other issues.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/144387
Approved by: https://github.com/ankurneog, https://github.com/EikanWang, https://github.com/yanboliang, https://github.com/guangyey
Fixes a bunch of benchmarks that failed with cudagraph errors including `tlp python benchmarks/dynamo/timm_models.py --device cuda --inductor --accuracy --amp --training --only resmlp_12_224` when `specialize_float=False`
Also brings down number of overall failures (with keep-going) from 108 => 62. I'd estimate >80% of those 62 are wobbly expect tests.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140346
Approved by: https://github.com/ezyang
ghstack dependencies: #140983, #141003
As discussed with @ezyang, this set of diffs are extracting fixes to problems discovered to flipping `specialize_float=False` in https://github.com/pytorch/pytorch/pull/137782. Since these codepaths are exercised in existing tests, I'm going to bias towards shipping speed and put these up with the primary test plan as the global CI. These code paths are all tested via existing tests when `specialize_float=False` and it feels a bit wonky to add more gated tests that only test behavior when this flag is True, especially since these code paths are already covered. That being said, I'm happy to add individual tests if reviewers insist or have a different POV.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/138598
Approved by: https://github.com/ezyang
ghstack dependencies: #138595
Fixes the observed graph breaks in https://github.com/pytorch/pytorch/issues/121349 and https://github.com/pytorch/pytorch/issues/121350.
But there are still graph breaks since a random output is being used as a seed, e.g.
```python
import random
import torch
def fn(x):
seed = random.randint(0, 100)
rand = random.Random(seed)
return x + rand.randrange(10)
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
opt_fn(torch.ones(1))
```
fails with
```
torch._dynamo.exc.InternalTorchDynamoError: UnspecializedPythonVariable() is not a constant
```
when tracing the line
```
rand = random.Random(seed)
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/133725
Approved by: https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/130456
When we mark_unbacked a size, we actually DO have a hint for it
(because we have a real, input tensor) for it, and previously, we were
accidentally putting it into the hint field of SymNode. If marked
unbacked size is zero or one, this can lead to inconsistency between
hint compute and static evaluation compute under guard size oblivious,
since that's the whole point of size oblivious. Answer is to scrub out
hints on mark unbacked ints.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/130483
Approved by: https://github.com/lezcano
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
The `usort` config in `pyproject.toml` has no effect due to a typo. Fixing the typo make `usort` do more and generate the changes in the PR. Except `pyproject.toml`, all changes are generated by `lintrunner -a --take UFMT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/127126
Approved by: https://github.com/kit1980
ghstack dependencies: #127122, #127123, #127124, #127125
The big idea is that floats are treated as Tensors on input/output to the FX graph, but on the inside, we immediately call item() on the synthetic Tensor and record regular float operations on it. Canonicalization to Tensor operations will happen in a standalone FX pass. This behavior is controlled by `specialize_float` config variable when set to False.
The generated graph looks like this for the test `test_unspec_float_output`:
```
def forward(self, L_x_: "f32[3]", L_y_: "f32[]"):
l_x_ = L_x_
l_y_ = L_y_
# File: /data/users/ezyang/a/pytorch/test/dynamo/test_unspec.py:511 in f, code: return x + 1, y * 2
add: "f32[3]" = l_x_ + 1; l_x_ = None
item: "Sym(zf0)" = l_y_.item(); l_y_ = None
mul: "Sym(2*zf0)" = item * 2; item = None
scalar_tensor: "f32[]" = torch.scalar_tensor(mul); mul = None
return (add, scalar_tensor)
```
The ingredients:
* **torch/_dynamo/variables/builder.py** When `specialize_float` is False, we wrap float literals with `wrap_symfloat`. This is an unholy mashup of `wrap_symint` and `wrap_unspecialized_primitive`. The overall strategy is that we first generate a tensor argument (because that's what we want to show up into the FX graph), but then immediately call item() on the tensor argument to get a SymNodeVariable, which we will do the rest of the tracing with. Importantly, this SymNodeVariable is backed with the source of the original float: this means we can guard on the resulting value (something we could NOT do with UnspecializedPythonVariable). This has to be done manually, because if you literally call item() on the tensor, you will end up with an unbacked float. There is a bit of copy paste from wrap_symint and wrap_unspecialized_primitive which we can try to factor out, but this really is its own thing and you should review every line of code in the function.
* **torch/fx/experimental/symbolic_shapes.py** We now can generate guards on float inputs, and these guards are handled inside of ShapeEnv. So we need to be able to allocate (backed!) float symbols, and produce guards for them. Fairly straightforward generalization.
* **torch/_dynamo/codegen.py** I also need to maintain the invariant that there are no float outputs to the FX graph. I chose to do this at codegen time. When we detect a SymNodeVariable on the return stack for a float, we on the fly convert it (via `as_tensor`) to a TensorVariable, which is the true output. We then special case the output bytecode to call item() on it again. The tensor conversion is memoized on SymNodeVariable since we typically run the code generation process twice.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125325
Approved by: https://github.com/lezcano, https://github.com/jansel
Fixes https://github.com/pytorch/pytorch/issues/117268; check this issue for background.
This PR does the following:
* Do not perform a replacement if the expression we're replacing the symbol with has a less refined value range than the original. There's a little bit of trickiness around the handling for values close to INT64_MAX; when checking if a range refines another, I *only* consider the range representable in 64-bit integers. This is enough to prevent us from doing a substitution like `i0 = 10 - i1`, but it appears to still let us do the other substitutions we like, such as `i0 = i1` or `i0 = 12 * i1`
* The test above is order dependent: if we assert an equality BEFORE we have refined a range, we might be willing to do the replacement because there isn't a meaningful range. This means that it's important to mark things as sizes, before you start doing other error checking. `split_with_sizes` is adjusted accordingly. It would be good to raise an error if you get the ordering wrong, but I leave this to future work.
* It turns out this is not enough to fix AOTAutograd, because we lose the size-ness of unbacked SymInts when AOTAutograd retraces the Dynamo graph. So update deferred runtime assert insertion to also insert size-ness and value ranges annotations. Note that, in principle, it shouldn't be necessary to explicitly do the latter; these should just show up as deferred runtime asserts. That's some extra refactoring for a later day.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117356
Approved by: https://github.com/lezcano
Fixes https://github.com/pytorch/pytorch/issues/117268; check this issue for background.
This PR does the following:
* Do not perform a replacement if the expression we're replacing the symbol with has a less refined value range than the original. There's a little bit of trickiness around the handling for values close to INT64_MAX; when checking if a range refines another, I *only* consider the range representable in 64-bit integers. This is enough to prevent us from doing a substitution like `i0 = 10 - i1`, but it appears to still let us do the other substitutions we like, such as `i0 = i1` or `i0 = 12 * i1`
* The test above is order dependent: if we assert an equality BEFORE we have refined a range, we might be willing to do the replacement because there isn't a meaningful range. This means that it's important to mark things as sizes, before you start doing other error checking. `split_with_sizes` is adjusted accordingly. It would be good to raise an error if you get the ordering wrong, but I leave this to future work.
* It turns out this is not enough to fix AOTAutograd, because we lose the size-ness of unbacked SymInts when AOTAutograd retraces the Dynamo graph. So update deferred runtime assert insertion to also insert size-ness and value ranges annotations. Note that, in principle, it shouldn't be necessary to explicitly do the latter; these should just show up as deferred runtime asserts. That's some extra refactoring for a later day.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/117356
Approved by: https://github.com/lezcano
Add a flag setting that controls a threshold of guards involving a symbol, after which we force a symbol to be specialized. The roll out plan is to enable this on OSS but not fbcode, and then roll out to fbcode after we get some telemetry from the previous PR.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/119347
Approved by: https://github.com/lezcano
Fixes#109604
Resubmit gh-109715 + several skips and small fixes to make tests pass.
The main fix here is by @ysiraichi : previously, dynamo did not resume tracing numpy ndarrays after a graph break.
While at it, fix several small issues Yukio's fix uncovers:
- graph break gracefully on numpy dtypes which do not map to torch.dtypes (uint16 etc)
- recognize array scalars in dynamo, treat them as 0D ndarrays
- make sure that iterating over torch.ndarray generates arrays not bare tensors
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110512
Approved by: https://github.com/lezcano
I just ported the C++ torch.tensor implementation to Python, swapping out the inner bits to successively stack tensors together, so that we can trace through `scalar_tensor`.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/109515
Approved by: https://github.com/voznesenskym
ghstack dependencies: #109513