[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)

Audit: To prevent future issues with functools.partial or callable
objects.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707
Approved by: https://github.com/Lucaskabela
This commit is contained in:
Animesh Jain 2025-10-20 22:51:57 -07:00 committed by PyTorch MergeBot
parent 9f9ab881b2
commit 1290b077f2

View File

@ -200,9 +200,10 @@ class SuperVariable(VariableTracker):
and not (args or kwargs)
):
with do_not_convert_to_tracable_parameter():
return variables.UserFunctionVariable(
unpatched_nn_module_init, source=source
).call_function(tx, [self.objvar] + args, kwargs)
fn_vt = VariableTracker.build(
tx, unpatched_nn_module_init, source=source
)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
else:
unimplemented_v2(
gb_type="Unsupported super().__init__() call",
@ -230,9 +231,8 @@ class SuperVariable(VariableTracker):
elif isinstance(inner_fn, staticmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
return variables.UserFunctionVariable(
inner_fn.__func__, source=source
).call_function(tx, args, kwargs)
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
return fn_vt.call_function(tx, args, kwargs)
elif isinstance(inner_fn, classmethod) and isinstance(
inner_fn.__func__, types.FunctionType
):
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
tx, self.objvar.value_type, cls_source
)
return variables.UserFunctionVariable(
inner_fn.__func__, source=AttrSource(source, "__func__")
).call_function(tx, [cls_variable, *args], kwargs)
fn_vt = VariableTracker.build(
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
)
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
elif isinstance(inner_fn, types.FunctionType):
return variables.UserFunctionVariable(
inner_fn, source=source
).call_function(tx, [self.objvar] + args, kwargs)
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
elif isinstance(inner_fn, types.MethodType):
return variables.UserMethodVariable(
inner_fn.__func__, self.objvar, source=source
@ -574,10 +574,8 @@ class ComptimeVariable(VariableTracker):
from ..comptime import comptime
# To support the comptime.print_graph convenience accessors
from .functions import UserFunctionVariable
return UserFunctionVariable(
getattr(comptime, name), source=AttrSource(self.source, name)
return VariableTracker.build(
tx, getattr(comptime, name), source=AttrSource(self.source, name)
)
def call_function(
@ -771,9 +769,8 @@ class AutogradFunctionVariable(VariableTracker):
sig = inspect.signature(fn)
if len(args) - 1 == len(sig._parameters):
args = args[1:] # Don't use context
return variables.UserFunctionVariable(fn, source=source).call_function(
tx, args, kwargs
)
fn_vt = VariableTracker.build(tx, fn, source=source)
return fn_vt.call_function(tx, args, kwargs)
elif isinstance(fn, types.MethodType):
return variables.UserMethodVariable(
fn.__func__,
@ -799,9 +796,8 @@ class AutogradFunctionVariable(VariableTracker):
assert isinstance(fn, types.FunctionType)
fn_source = AttrSource(self.source, "backward")
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
tx, args, kwargs
)
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
return fn_vt.call_function(tx, args, kwargs)
def call_function(self, tx: "InstructionTranslator", args, kwargs):
return AutogradFunctionVariable(self.fn_cls)
@ -1026,10 +1022,12 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
assert tx.one_graph or tx.error_on_graph_break, (
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
)
return variables.UserFunctionVariable(
fn_vt = VariableTracker.build(
tx,
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
source=self.source,
).call_function(
)
return fn_vt.call_function(
tx,
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
kwargs,