Using hasattr for _boxed_call is asking for trouble (#151130)

Summary:
There are a number of places in the code checking for the existence of `_boxed_call` instead of checking for a `True` value. This is somewhat dangerous because one would assume that setting it to `None` or `False` would be the same as not setting it (output_code.py does this, for example).

Change `hasattr()` to `getattr(..., False)` for these cases.

Test Plan: unit tests pass

Differential Revision: D72806693

Pull Request resolved: https://github.com/pytorch/pytorch/pull/151130
Approved by: https://github.com/Skylion007
This commit is contained in:
Aaron Orenstein 2025-04-14 18:36:30 +00:00 committed by PyTorch MergeBot
parent 6dddd6520d
commit 1f5af12cd9
4 changed files with 6 additions and 6 deletions

View File

@ -370,7 +370,7 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
gm.zero_grad(True)
# TorchInductor returned callable expects lists. So, may need a boxed calling convention.
out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)
out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args)
if only_fwd:
return out

View File

@ -226,7 +226,7 @@ def aot_dispatch_base(
# However, RuntimeWrapper does not expect the rng offsets in the
# output. So, we have to create another wrapper and take out the offset. As
# a result, we have to account for not boxed_call compilers as well.
if not hasattr(compiled_fw, "_boxed_call"):
if not getattr(compiled_fw, "_boxed_call", False):
compiled_fw = make_boxed_func(compiled_fw)
# Create a wrapper to set up the rng functionalize and fakified out bits
@ -282,7 +282,7 @@ def aot_dispatch_base(
runtime_metadata=fw_metadata,
)
if not hasattr(compiled_fw, "_boxed_call"):
if not getattr(compiled_fw, "_boxed_call", False):
compiled_fw = make_boxed_func(compiled_fw)
compiled_fn = RuntimeWrapper(
@ -1107,7 +1107,7 @@ def aot_dispatch_autograd(
with TracingContext.report_output_strides() as fwd_output_strides:
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
if not hasattr(compiled_fw_func, "_boxed_call"):
if not getattr(compiled_fw_func, "_boxed_call", False):
compiled_fw_func = make_boxed_func(compiled_fw_func)
if fakified_out_wrapper.needs_post_compile:

View File

@ -258,7 +258,7 @@ def _create_runtime_wrapper(
keep_input_mutations: bool,
disable_amp: bool,
):
if not hasattr(compiled_fn, "_boxed_call"):
if not getattr(compiled_fn, "_boxed_call", False):
compiled_fn = make_boxed_func(compiled_fn)
# Note [Inputs needed in runtime epilogue after list clearing]

View File

@ -122,7 +122,7 @@ def call_func_at_runtime_with_args(
context = torch._C._DisableAutocast if disable_amp else nullcontext
with context():
if hasattr(f, "_boxed_call"):
if getattr(f, "_boxed_call", False):
out = normalize_as_list(f(args))
else:
# TODO: Please remove soon