Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Previously, when we applied a replacement, a SymInt that was previously an unbacked SymInt would transmute into whatever we replaced it with (e.g., a constant). This has a major downside: we often look at the SymInts associated with FX nodes (e.g., the meta of an x.item() return) to find out where an unbacked SymInt was allocated. If we replace it, we can no longer find out where, e.g., u1 was allocated! But we need to know this so we can generate deferred runtime asserts like u1 == s0.

To solve this problem, I add a special mode for replace, resolve_unbacked=False, which lets you disable substitutions on unbacked SymInts. When reporting node.expr, we preferentially avoid applying unbacked SymInt substitutions. To understand whether we might accidentally reapply the substitution later, before we have reached the deferred runtime assert, we must study the calls to simplify() in ShapeEnv. My audit turns up these sites:

* `produce_guards`: this is fine; deferred runtime asserts never show up here, and we must NOT have unbacked SymInts show up here. Similarly `get_nontrivial_guards`.
* `_maybe_evaluate_static`: this is fine; we are using this to determine whether it is necessary to produce a guard/runtime assert. We don't want to reissue a runtime assert if we've already asserted on it, and replacements help us understand whether this has occurred.
* `_simplify_floor_div`: this is a legitimate bug; it needs to be `resolve_unbacked=False`.
* `_refine_ranges`: this is fine; a refined range doesn't affect what runtime asserts we issue.
* `_update_divisible`: this updates the `self.divisible` set, which specifies when we can simplify away divisibility constraints. Since this affects replacements only, it won't cause us to oversimplify a user-provided expression.

There are some situations where we DO want to always apply the substitution, specifically when we have the duplicate symbol problem (we retrace an item() call and get u0 and u1, which refer to the same thing.)
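Setting that duplicate-symbol case aside for a moment, the skip-unbacked substitution described above can be sketched in plain sympy. All names here (`replacements`, `is_unbacked`, `replace`) are illustrative stand-ins, not the real ShapeEnv API, which lives in torch/fx/experimental/symbolic_shapes.py:

```python
import sympy

# Hypothetical stand-in for ShapeEnv.replacements: maps a symbol to what
# it has been refined to (e.g., a constant learned from an equality guard).
replacements = {sympy.Symbol("u0"): sympy.Integer(5)}

def is_unbacked(sym: sympy.Symbol) -> bool:
    # In the real ShapeEnv, unbacked symbols are the ones allocated for
    # data-dependent values (e.g., the result of x.item()); this sketch
    # fakes that check with the "u" name prefix used in the text above.
    return sym.name.startswith("u")

def replace(expr: sympy.Expr, resolve_unbacked: bool = True) -> sympy.Expr:
    # Apply the known replacements, optionally skipping unbacked symbols
    # so a deferred runtime assert can still name the original symbol.
    subs = {
        s: v for s, v in replacements.items()
        if resolve_unbacked or not is_unbacked(s)
    }
    return expr.xreplace(subs)

u0, s0 = sympy.Symbol("u0"), sympy.Symbol("s0")
expr = u0 + s0
print(replace(expr))                          # → s0 + 5
print(replace(expr, resolve_unbacked=False))  # → s0 + u0 (u0 survives)
```

With `resolve_unbacked=False`, an assert like `u0 == s0` can still be emitted in terms of `u0`, because the substitution that would have erased it is suppressed.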
I don't want two symbols in this case, so a special `rename_unbacked_to` is provided which sets up the unconditional renaming.

Along the way, I make a refinement to `_update_var_to_range`: if you update a var range for a size-like unbacked SymInt, you are no longer allowed to set its lower bound below 2, because if you could, our size-oblivious tests for it would be inconsistent. Actually, I think there is still some inconsistency: if you assert `u0 == 0`, we will still end up with this in the deferred runtime asserts, and we will then use it to simplify these statements to True everywhere else. Maybe we should forbid this kind of refinement; that is not done in this PR.

Fixes https://github.com/pytorch/pytorch/issues/119689
Fixes https://github.com/pytorch/pytorch/issues/118385

Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/120816
Approved by: https://github.com/lezcano
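The range-refinement rule can be sketched as a toy model. `ValueRange` and `update_var_to_range` here are simplified stand-ins for the real `ValueRanges` and `ShapeEnv._update_var_to_range`:

```python
from dataclasses import dataclass

@dataclass
class ValueRange:
    lower: int
    upper: int

def update_var_to_range(var_to_range, sym, new_range, size_like):
    # Toy version of the refinement described above: for a size-like
    # unbacked symbol, never let the lower bound drop below 2; otherwise
    # size-oblivious reasoning (which treats the size as >= 2) would
    # disagree with the recorded range.
    if size_like:
        new_range = ValueRange(max(new_range.lower, 2), new_range.upper)
    old = var_to_range.get(sym)
    if old is not None:
        # Refinement only narrows: intersect with the existing range.
        new_range = ValueRange(max(old.lower, new_range.lower),
                               min(old.upper, new_range.upper))
    var_to_range[sym] = new_range

ranges = {}
update_var_to_range(ranges, "u0", ValueRange(0, 100), size_like=True)
print(ranges["u0"])  # ValueRange(lower=2, upper=100)
```

The clamp is what makes the inconsistency noted above possible: a deferred runtime assert `u0 == 0` can coexist with a recorded lower bound of 2, which is why forbidding that refinement is floated as future work.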