Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 00:20:18 +01:00)
Add env for disabling meta reference on functionalization. (#148822)
Fix: https://github.com/pytorch/xla/issues/8755

This PR introduces the `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE` environment variable. Setting it makes the functionalization kernels skip the meta reference, which is used to propagate expected sizes and strides. Currently, PyTorch/XLA doesn't actually propagate the correct strides to its tensors, and calling these meta functions was shown to incur significant overhead. Running the minimal reproducer from the issue, we see a speedup close to 4.3x:

- Baseline: 0.0747s
- `XLA_DISABLE_FUNCTIONALIZATION=1`: 0.0159s
- `TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1`: 0.0175s

In summary, this PR:

- Creates the `disable_meta_reference()` function, which checks whether the environment variable is set
- Modifies codegen for the functionalization kernels, adding a call to `disable_meta_reference()` to the appropriate conditions
- Creates a new bash function for running `lazy/test_ts_opinfo.py` with the environment variable set

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148822
Approved by: https://github.com/bdhirsh
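For context, a minimal sketch of how the new flag might be exercised from Python. This is illustrative only and not part of the PR: it assumes `torch_xla` is installed and an XLA device is available, and the tensor shapes and ops below are arbitrary choices.

```python
# Illustrative sketch (not from this PR). The environment variable must be set
# before the first functionalization kernel runs, because disable_meta_reference()
# caches the result of std::getenv() in a static local.
import os
os.environ["TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE"] = "1"

import torch
import torch_xla.core.xla_model as xm  # assumes torch_xla is installed

device = xm.xla_device()
x = torch.randn(4, 4, device=device)

# A view followed by an in-place mutation goes through the functionalization
# pass on XLA; with the flag set, the meta reference computation is skipped.
y = x.view(16)
y.add_(1)
print(y.cpu())
```

Because the generated kernels cache the result of `std::getenv`, the variable has to be set before the first functionalized op runs; flipping it later in the process has no effect.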
This commit is contained in:
parent 09029010e5
commit e0d4c43ad1

@@ -314,6 +314,13 @@ test_python() {
   assert_git_not_dirty
 }
 
+test_lazy_tensor_meta_reference_disabled() {
+  export TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE=1
+  echo "Testing lazy tensor operations without meta reference"
+  time python test/run_test.py --include lazy/test_ts_opinfo.py --verbose
+  export -n TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE
+}
+
 
 test_dynamo_wrapped_shard() {
   if [[ -z "$NUM_TEST_SHARDS" ]]; then

@@ -1627,6 +1634,7 @@ elif [[ "${BUILD_ENVIRONMENT}" == *rocm* && -n "$TESTS_TO_INCLUDE" ]]; then
   test_python_shard "$SHARD_NUMBER"
   test_aten
 elif [[ "${SHARD_NUMBER}" == 1 && $NUM_TEST_SHARDS -gt 1 ]]; then
+  test_lazy_tensor_meta_reference_disabled
   test_without_numpy
   install_torchvision
   test_python_shard 1

@@ -94,6 +94,11 @@ inline c10::List<::std::optional<Tensor>> to_meta(const c10::List<::std::optiona
   return outputs;
 }
 
+static bool disable_meta_reference() {
+  static auto env = std::getenv("TORCH_DISABLE_FUNCTIONALIZATION_META_REFERENCE");
+  return env != nullptr && std::strcmp(env, "1") == 0;
+}
+
 
 ${func_definitions}
 

@@ -432,7 +432,7 @@ def emit_view_functionalization_body(
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
       {return_type} reference_tensor_output;
-      if (compute_reference_meta) {{
+      if (compute_reference_meta && !disable_meta_reference()) {{
         {meta_conversion_str}
         at::AutoDispatchSkipFunctionalize func_guard;
         c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);

@@ -447,7 +447,7 @@ def emit_view_functionalization_body(
       // XLA/LTC don't implement the logic to propagate strides correctly, so we need to rely
       // on a reference implementation here (instead of relying on the output from the forward lambda
       // having the correct stride info)
-      if (compute_reference_meta) {{
+      if (compute_reference_meta && !disable_meta_reference()) {{
         at::functionalization::impl::set_sizes_strides_offset({view_tensor_name}, reference_tensor_output);
       }}
       return {view_tensor_name};

@@ -473,7 +473,7 @@ def emit_view_functionalization_body(
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::XLABit) ||
         {view_tensor_name}.key_set().has_backend(c10::BackendComponent::LazyBit);
       {return_type} reference_tensor_output;
-      if (compute_reference_meta) {{
+      if (compute_reference_meta && !disable_meta_reference()) {{
         {meta_conversion_str}
         at::AutoDispatchSkipFunctionalize func_guard;
         c10::impl::ExcludeDispatchKeyGuard guard(exclude_keys_for_meta_dispatch);

@@ -506,7 +506,7 @@ def emit_view_functionalization_body(
       );
       auto out = at::functionalization::impl::create_functional_tensor_with_view_meta(tmp_output, {view_tensor_name}, view_meta);
       // See Note [Propagating strides in the functionalization pass]
-      if (compute_reference_meta) {{
+      if (compute_reference_meta && !disable_meta_reference()) {{
         at::functionalization::impl::set_sizes_strides_offset(out, reference_tensor_output);
       }}
       return out;

@@ -715,7 +715,7 @@ def emit_inplace_functionalization_body(
 
     return f"""
     {dispatcher_sig.defn(name=wrapper_name(f.func), is_redispatching_fn=True)} {{
-      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()}) {{
+      if ({str(not any_storage_args and f.func.kind() == SchemaKind.inplace).lower()} && !disable_meta_reference()) {{
         // Before converting the mutable op to its functional variant, run meta tensors through the original op.
         // This will help us catch shape errors that apply to inplace ops that wouldn't apply to their functional variants.
         // (We can only do this for inplace ops today though, because they technically all support meta tensors).