Add env for disabling meta reference on functionalization. (#148822)

Fix: https://github.com/pytorch/xla/issues/8755

This PR introduces the `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE`
environment variable. When it is set to `1`, the functionalization
kernels skip the meta reference, which is otherwise used to propagate
the expected sizes and strides.
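
As a usage illustration (not part of the PR itself): the flag is read with
`getenv` and cached on first use, so it should be set before the first
functionalized op runs, either in the shell or at the very top of the script.

```python
import os

# Set the flag before running any ops: the functionalization kernels read the
# environment variable once and cache the result (see disable_meta_reference()).
os.environ["TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE"] = "1"

import torch  # noqa: E402  (imported after setting the variable, to be safe)
```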

Currently, PyTorch/XLA doesn't actually propagate the correct strides
to its tensors. It has also been shown that calling these meta
functions can incur significant overhead.

Running the minimal reproducer from the issue shows a speedup close
to 4.3x (a rough timing sketch follows the numbers below):

- Baseline: 0.0747s
- `XLA_DISABLE_FUNCTIONALIZATION=1`: 0.0159s
- `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1`: 0.0175s
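
This is not the reproducer from the issue; the sketch below only illustrates
the kind of comparison involved, assuming a working `torch_xla` installation,
and the shapes and ops are placeholders.

```python
import time

import torch
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed

device = xm.xla_device()
x = torch.randn(128, 128, device=device)

start = time.perf_counter()
for _ in range(100):
    # View ops go through the functionalization pass on XLA, which is where
    # the meta reference is (or, with the new flag set, is not) invoked.
    y = x.view(64, 256).transpose(0, 1).reshape(128, 128)
    x = x + y
xm.mark_step()  # materialize the pending lazy computation before stopping the clock
print(f"elapsed: {time.perf_counter() - start:.4f}s")
```

Running it once as-is and once with
`TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1` exported in the shell gives
the kind of before/after numbers listed above.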

In summary, this PR:

- Creates the `disable_meta_reference()` function, which checks whether
  the environment variable is set
- Modifies the codegen for functionalization kernels, adding a call to
  `disable_meta_reference()` to the appropriate conditions
- Creates a new bash function for running `lazy/test_ts_opinfo.py` with
  the environment variable set

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148822
Approved by: https://github.com/bdhirsh
Commit e0d4c43ad1 (parent 09029010e5), authored by Yukio Siraichi on 2025-03-10 19:53:27 -03:00 and committed by PyTorch MergeBot. 3 changed files with 18 additions and 5 deletions.


@@ -314,6 +314,13 @@ test_python() {
assert_git_not_dirty
}
test_lazy_tensor_meta_reference_disabled() {
export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
echo "Testing lazy tensor operations without meta reference"
time python test/run_test.py --include lazy/test_ts_opinfo.py --verbose
export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
}
test_dynamo_wrapped_shard() {
if [[ -z "$NUM_TEST_SHARDS" ]]; then
@@ -1627,6 +1634,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *rocm* && -n "$TESTS_TO_INCLUDE" ]]; then
test_python_shard "$SHARD_NUMBER"
test_aten
elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
test_lazy_tensor_meta_reference_disabled
test_without_numpy
install_torchvision
test_python_shard 1


@@ -94,6 +94,11 @@ inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optiona
return outputs;
}
static bool disable_meta_reference() {
static auto env = std::getenv("TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE");
return env != nullptr && std::strcmp(env, "1") == 0;
}
${func_definitions}


@@ -432,7 +432,7 @@ def emit_view_functionalization_body(
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
-if (compute_reference_meta) {{
+if (compute_reference_meta && !disable_meta_reference()) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
@@ -447,7 +447,7 @@ def emit_view_functionalization_body(
// XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
// on a reference implementation here (instead of relying on the output from the forward lambda
// having the correct stride info)
-if (compute_reference_meta) {{
+if (compute_reference_meta && !disable_meta_reference()) {{
at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
}}
return {view_tensor_name};
@@ -473,7 +473,7 @@ def emit_view_functionalization_body(
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
{view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
{return_type} reference_tensor_output;
-if (compute_reference_meta) {{
+if (compute_reference_meta && !disable_meta_reference()) {{
{meta_conversion_str}
at::AutoDispatchSkipFunctionalize func_guard;
c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);
@@ -506,7 +506,7 @@
);
auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
// See Note [Propagating strides in the functionalization pass]
-if (compute_reference_meta) {{
+if (compute_reference_meta && !disable_meta_reference()) {{
at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
}}
return out;
@@ -715,7 +715,7 @@ def emit_inplace_functionalization_body(
return f"""
{dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
-if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
+if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()} && !disable_meta_reference()) {{
// Before converting the mutable op to its functional variant, run meta tensors through the original op.
// This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
// (We can only do this for inplace ops today though, because they technically all support meta tensors).