Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
Sympy's implementation of Min/Max displays asymptotically bad behavior on `TORCH_COMPILE_CPROFILE=1 python torchrec/distributed/tests/test_pt2_multiprocess.py TestPt2Train.test_compile_multiprocess`.

Evidence profile: (image not included)

On this test case, we spend 42% of all time compiling the network on `ShapeEnv.replace`, which in turn spends all of its time in `xreplace`. The problem appears to be the `find_localzeros` call. By vendoring the implementations of Min/Max, we can potentially reduce the cost of this operation.

The implementation is copy-pasted from `sympy/functions/elementary/miscellaneous.py`, with some adjustments:

* I deleted logic related to differentiation, `evalf`, and Heaviside, as it's not relevant to PyTorch reasoning.
* There's some massaging to appease PyTorch's linters, including a lot of `noqa` and `type: ignore` (which I could potentially refactor away with substantive changes, but that's better as its own change).
* I deleted the second loop iteration for `is_connected` as an initial optimization (this also simplifies the port, since I can omit some code). I'll add a comment at that point explaining the exact difference.

Before this change, the test in question takes 100s with 40 features; after this change, it takes only 69s.

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133319
Approved by: https://github.com/Skylion007
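The commit message attributes the cost to `find_localzeros`, which prunes dominated arguments of Min/Max by comparing each incoming argument against every argument that has survived so far. The following is a minimal pure-Python sketch of that pruning for Max, loosely modeled on SymPy's approach and not the vendored PyTorch code. `find_localzeros_max`, `GeFn`, and `toy_ge` are hypothetical names, and the three-valued `ge` callback is an assumption standing in for SymPy's connectivity check.

```python
from typing import Callable, Hashable, Iterable, List, Optional

# Hypothetical three-valued comparison standing in for SymPy's connectivity
# check: True if x >= y is provable, False if x < y is provable,
# None if the relation cannot be decided.
GeFn = Callable[[Hashable, Hashable], Optional[bool]]


def find_localzeros_max(values: Iterable[Hashable], ge: GeFn) -> List[Hashable]:
    """Keep only the arguments of Max(...) that are not provably dominated.

    Every incoming value is compared against every surviving candidate,
    so n mutually incomparable arguments cost n*(n-1)/2 comparisons.
    """
    zeros: List[Hashable] = []
    for v in values:
        is_new = True
        for z in list(zeros):  # iterate over a snapshot while mutating
            if v == z:
                is_new = False  # exact duplicate: nothing to add
                continue
            rel = ge(v, z)
            if rel is None:
                continue  # incomparable: both may survive
            is_new = False
            if rel:  # v >= z, so v replaces z
                zeros.remove(z)
                if v not in zeros:
                    zeros.append(v)
            # rel is False: z dominates v, so v is dropped
        if is_new:
            zeros.append(v)
    return zeros


# Toy "expressions": (symbol, offset) means symbol + offset; two terms are
# comparable only when they share a symbol.
def toy_ge(a, b):
    return a[1] >= b[1] if a[0] == b[0] else None


print(find_localzeros_max([("s", 0), ("s", 3), ("t", 1), ("s", 1)], toy_ge))
# [('s', 3), ('t', 1)]
```

Because of the pairwise loop, the number of comparisons grows quadratically with the number of arguments that cannot be ordered. Any work removed from a single comparison, such as the second iteration of `is_connected`, is therefore saved up to n*(n-1)/2 times.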
| Name |
|---|
| _strobelight/ |
| _sympy/ |
| backcompat/ |
| benchmark/ |
| bottleneck/ |
| data/ |
| hipify/ |
| jit/ |
| model_dump/ |
| tensorboard/ |
| viz/ |
| __init__.py |
| _backport_slots.py |
| _config_module.py |
| _config_typing.pyi |
| _content_store.py |
| _contextlib.py |
| _cpp_extension_versioner.py |
| _cxx_pytree.py |
| _device.py |
| _exposed_in.py |
| _foreach_utils.py |
| _freeze.py |
| _get_clean_triton.py |
| _import_utils.py |
| _mode_utils.py |
| _ordered_set.py |
| _python_dispatch.py |
| _pytree.py |
| _stats.py |
| _thunk.py |
| _traceback.py |
| _triton.py |
| _typing_utils.py |
| _zip.py |
| backend_registration.py |
| bundled_inputs.py |
| checkpoint.py |
| collect_env.py |
| cpp_backtrace.py |
| cpp_extension.py |
| deterministic.py |
| dlpack.py |
| file_baton.py |
| flop_counter.py |
| hooks.py |
| mkldnn.py |
| mobile_optimizer.py |
| model_zoo.py |
| module_tracker.py |
| show_pickle.py |
| throughput_benchmark.py |
| weak.py |