pytorch/torch/_dynamo
Shunting Zhang e545caa50f dynamo/torchxla integration: trace on xla rather than eager (#88904)
In #87741 we added inference support for the dynamo/torchxla integration. Later, in #88449, we attempted to add training support. That attempt did not go smoothly because:
- we tried to do two things at once:
   1. let dynamo trace the model on XLA rather than eager
   2. enable training
- it turned out that neither task is trivial.

Furthermore, item 2 (enabling training) depends on item 1 (tracing on XLA). We enable training via AOTAutograd, which lifts all model parameters/buffers as graph inputs. Without item 1 in place, we would need to copy all graph inputs (including model parameters/buffers) from the eager device to the XLA device, which hurts performance a lot. Keeping a cache that maps each eager parameter to its XLA counterpart does not solve the problem either, since an update to one copy is not automatically synced to the other; the two easily go out of sync. A sketch of what parameter lifting means is shown below.
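
As an illustration, here is a minimal sketch of what parameter lifting means and why eager-side tracing would force per-call copies. The `Toy` module and `lifted_forward` are hypothetical, not code from this PR:
```
import torch

class Toy(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = torch.nn.Parameter(torch.randn(4, 4))

    def forward(self, x):
        return x @ self.weight

# AOTAutograd functionalizes the module: parameters/buffers stop being
# attributes and become explicit graph inputs, roughly like this:
def lifted_forward(weight, x):
    return x @ weight

# If dynamo traced on the eager device, running lifted_forward on XLA
# would mean copying every lifted input (weight included) from the eager
# device to the XLA device on each call -- the cost described above.
```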

This PR lets dynamo trace the model on XLA rather than eager, as a preparation step for enabling training.

Tracing on XLA also makes data movement more efficient: we see a 1.5x geomean speedup, up from the previous 1.38x.
```
+-------------------------+--------------------+-------------------------+
| Model                   |   XLA (trace once) |  XLA (trace every time) |
+=========================+====================+=========================+
| resnet18                |            1.38    |                 1.008   |
+-------------------------+--------------------+-------------------------+
| resnet50                |            1.227   |                 0.998   |
+-------------------------+--------------------+-------------------------+
| resnext50_32x4d         |            1.544   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| alexnet                 |            1.085   |                 1.045   |
+-------------------------+--------------------+-------------------------+
| mobilenet_v2            |            2.028   |                 1.013   |
+-------------------------+--------------------+-------------------------+
| mnasnet1_0              |            1.516   |                 0.995   |
+-------------------------+--------------------+-------------------------+
| squeezenet1_1           |            0.868   |                 1.01    |
+-------------------------+--------------------+-------------------------+
| vgg16                   |            1.099   |                 1.008   |
+-------------------------+--------------------+-------------------------+
| BERT_pytorch            |            3.26    |                 1.027   |
+-------------------------+--------------------+-------------------------+
| timm_vision_transformer |            2.182   |                 1.015   |
+-------------------------+--------------------+-------------------------+
| geomean                 |            1.50389 |                 1.01261 |
+-------------------------+--------------------+-------------------------+
```

Example command:
```
GPU_NUM_DEVICES=1 python benchmarks/dynamo/torchbench.py --randomize-input --performance --trace-on-xla --only resnet18 --backend=torchxla_trace_once
```
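
For reference, a minimal sketch of exercising the same backend from Python (this assumes torch_xla is installed and an XLA device is available; the toy model is illustrative, and the backend name comes from the command above):
```
import torch
import torch._dynamo as dynamo
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = torch.nn.Linear(8, 8).to(device)

# torchxla_trace_once: dynamo traces the model on the XLA device once
# and reuses the compiled computation on subsequent calls.
@dynamo.optimize("torchxla_trace_once")
def fn(x):
    return model(x)

out = fn(torch.randn(2, 8, device=device))
```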

Pull Request resolved: https://github.com/pytorch/pytorch/pull/88904
Approved by: https://github.com/wconstab, https://github.com/JackCaoG, https://github.com/jansel
2022-11-22 03:57:04 +00:00
Name                       | Last commit | Date
optimizations              | dynamo/torchxla integration: trace on xla rather than eager (#88904) | 2022-11-22 03:57:04 +00:00
variables                  | [Dynamo] Fix bugs when calling tensor.data and tensor.layout (#89257) | 2022-11-21 22:44:01 +00:00
__init__.py                | [dynamo][reland] API Support for nn.Module (#89113) | 2022-11-17 02:03:48 +00:00
allowed_functions.py       | Graph-break on FSDP in dynamo (#87420) | 2022-10-25 17:07:44 +00:00
bytecode_analysis.py       | Fix line numbers bug (#87247) | 2022-10-19 22:44:01 +00:00
bytecode_transformation.py | Fix line numbers bug (#87247) | 2022-10-19 22:44:01 +00:00
codegen.py                 | [dynamo] Port all pytorch/dynamo and test/dynamo pieces over from symbolic-shapes branch (#88768) | 2022-11-13 04:50:21 +00:00
config.py                  | Rewrite assert statement with torch._assert under config (#88246) | 2022-11-17 19:49:31 +00:00
convert_frame.py           | [reland][dynamo] fixes dict changed during runtime error (#88877) | 2022-11-13 16:20:45 +00:00
debug_utils.py             | [dynamo][reland] API Support for nn.Module (#89113) | 2022-11-17 02:03:48 +00:00
eval_frame.py              | Add support for dynamic kwarg to torch._dynamo.optimize (#89290) | 2022-11-19 23:51:02 +00:00
exc.py                     | Fix all references to torchdynamo from the merge (#87731) | 2022-10-31 06:51:07 +00:00
guards.py                  | [dynamo] Port all pytorch/dynamo and test/dynamo pieces over from symbolic-shapes branch (#88768) | 2022-11-13 04:50:21 +00:00
logging.py                 | Fix CODE level usage in dynamo config.py (#87522) | 2022-10-25 22:47:54 +00:00
mutation_guard.py          | |
output_graph.py            | Don't iterate over graph when adding graph input (#89084) | 2022-11-16 00:08:34 +00:00
profiler.py                | |
replay_record.py           | |
resume_execution.py        | |
side_effects.py            | [dynamo] mutable local caching to make dynamo faster at tracing mutation (#89170) | 2022-11-19 01:47:48 +00:00
skipfiles.py               | Graph-break on FSDP in dynamo (#87420) | 2022-10-25 17:07:44 +00:00
source.py                  | |
symbolic_convert.py        | [dynamo] mutable local caching to make dynamo faster at tracing mutation (#89170) | 2022-11-19 01:47:48 +00:00
test_case.py               | [dynamo] Unify raise_on_* config to suppress_errors and raise by default (#87440) | 2022-10-21 17:03:29 +00:00
test_minifier_common.py    | Add comprehensive minifier tests (#88022) | 2022-11-17 02:02:29 +00:00
testing.py                 | [dashboard][huggingface] skip accuracy checks for really large models… (#89273) | 2022-11-19 00:22:45 +00:00
utils.py                   | dynamo/torchxla integration: trace on xla rather than eager (#88904) | 2022-11-22 03:57:04 +00:00