mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Using hasattr for _boxed_call is asking for trouble (#151130)
Summary: There are a number of places in the code checking for the existence of `_boxed_call` instead of checking for a `True` value. This is somewhat dangerous because one would assume that setting it to `None` or `False` would be the same as not setting it (output_code.py does this, for example). Change `hasattr()` to `getattr(..., False)` for these cases. Test Plan: unit tests pass Differential Revision: D72806693 Pull Request resolved: https://github.com/pytorch/pytorch/pull/151130 Approved by: https://github.com/Skylion007
This commit is contained in:
parent
6dddd6520d
commit
1f5af12cd9
|
|
@ -370,7 +370,7 @@ def run_fwd_maybe_bwd(gm, args, only_fwd=False, disable_clone=False):
|
|||
gm.zero_grad(True)
|
||||
|
||||
# TorchInductor returned callable expects lists. So, may need a boxed calling convention.
|
||||
out = gm(args) if hasattr(gm, "_boxed_call") else gm(*args)
|
||||
out = gm(args) if getattr(gm, "_boxed_call", False) else gm(*args)
|
||||
|
||||
if only_fwd:
|
||||
return out
|
||||
|
|
|
|||
|
|
@ -226,7 +226,7 @@ def aot_dispatch_base(
|
|||
# However, RuntimeWrapper does not expect the rng offsets in the
|
||||
# output. So, we have to create another wrapper and take out the offset. As
|
||||
# a result, we have to account for not boxed_call compilers as well.
|
||||
if not hasattr(compiled_fw, "_boxed_call"):
|
||||
if not getattr(compiled_fw, "_boxed_call", False):
|
||||
compiled_fw = make_boxed_func(compiled_fw)
|
||||
|
||||
# Create a wrapper to set up the rng functionalize and fakified out bits
|
||||
|
|
@ -282,7 +282,7 @@ def aot_dispatch_base(
|
|||
runtime_metadata=fw_metadata,
|
||||
)
|
||||
|
||||
if not hasattr(compiled_fw, "_boxed_call"):
|
||||
if not getattr(compiled_fw, "_boxed_call", False):
|
||||
compiled_fw = make_boxed_func(compiled_fw)
|
||||
|
||||
compiled_fn = RuntimeWrapper(
|
||||
|
|
@ -1107,7 +1107,7 @@ def aot_dispatch_autograd(
|
|||
with TracingContext.report_output_strides() as fwd_output_strides:
|
||||
compiled_fw_func = aot_config.fw_compiler(fw_module, adjusted_flat_args)
|
||||
|
||||
if not hasattr(compiled_fw_func, "_boxed_call"):
|
||||
if not getattr(compiled_fw_func, "_boxed_call", False):
|
||||
compiled_fw_func = make_boxed_func(compiled_fw_func)
|
||||
|
||||
if fakified_out_wrapper.needs_post_compile:
|
||||
|
|
|
|||
|
|
@ -258,7 +258,7 @@ def _create_runtime_wrapper(
|
|||
keep_input_mutations: bool,
|
||||
disable_amp: bool,
|
||||
):
|
||||
if not hasattr(compiled_fn, "_boxed_call"):
|
||||
if not getattr(compiled_fn, "_boxed_call", False):
|
||||
compiled_fn = make_boxed_func(compiled_fn)
|
||||
|
||||
# Note [Inputs needed in runtime epilogue after list clearing]
|
||||
|
|
|
|||
|
|
@ -122,7 +122,7 @@ def call_func_at_runtime_with_args(
|
|||
|
||||
context = torch._C._DisableAutocast if disable_amp else nullcontext
|
||||
with context():
|
||||
if hasattr(f, "_boxed_call"):
|
||||
if getattr(f, "_boxed_call", False):
|
||||
out = normalize_as_list(f(args))
|
||||
else:
|
||||
# TODO: Please remove soon
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user