mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[dynamo][misc] Replace UserFunctionVariable with VariableTracker build (#165707)
Audit: To prevent future issues with functools.partial or callable objects. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165707 Approved by: https://github.com/Lucaskabela
This commit is contained in:
parent
9f9ab881b2
commit
1290b077f2
|
|
@ -200,9 +200,10 @@ class SuperVariable(VariableTracker):
|
|||
and not (args or kwargs)
|
||||
):
|
||||
with do_not_convert_to_tracable_parameter():
|
||||
return variables.UserFunctionVariable(
|
||||
unpatched_nn_module_init, source=source
|
||||
).call_function(tx, [self.objvar] + args, kwargs)
|
||||
fn_vt = VariableTracker.build(
|
||||
tx, unpatched_nn_module_init, source=source
|
||||
)
|
||||
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
|
||||
else:
|
||||
unimplemented_v2(
|
||||
gb_type="Unsupported super().__init__() call",
|
||||
|
|
@ -230,9 +231,8 @@ class SuperVariable(VariableTracker):
|
|||
elif isinstance(inner_fn, staticmethod) and isinstance(
|
||||
inner_fn.__func__, types.FunctionType
|
||||
):
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn.__func__, source=source
|
||||
).call_function(tx, args, kwargs)
|
||||
fn_vt = VariableTracker.build(tx, inner_fn.__func__, source=source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
elif isinstance(inner_fn, classmethod) and isinstance(
|
||||
inner_fn.__func__, types.FunctionType
|
||||
):
|
||||
|
|
@ -255,13 +255,13 @@ class SuperVariable(VariableTracker):
|
|||
tx, self.objvar.value_type, cls_source
|
||||
)
|
||||
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn.__func__, source=AttrSource(source, "__func__")
|
||||
).call_function(tx, [cls_variable, *args], kwargs)
|
||||
fn_vt = VariableTracker.build(
|
||||
tx, inner_fn.__func__, source=AttrSource(source, "__func__")
|
||||
)
|
||||
return fn_vt.call_function(tx, [cls_variable, *args], kwargs)
|
||||
elif isinstance(inner_fn, types.FunctionType):
|
||||
return variables.UserFunctionVariable(
|
||||
inner_fn, source=source
|
||||
).call_function(tx, [self.objvar] + args, kwargs)
|
||||
fn_vt = VariableTracker.build(tx, inner_fn, source=source)
|
||||
return fn_vt.call_function(tx, [self.objvar] + args, kwargs)
|
||||
elif isinstance(inner_fn, types.MethodType):
|
||||
return variables.UserMethodVariable(
|
||||
inner_fn.__func__, self.objvar, source=source
|
||||
|
|
@ -574,10 +574,8 @@ class ComptimeVariable(VariableTracker):
|
|||
from ..comptime import comptime
|
||||
|
||||
# To support the comptime.print_graph convenience accessors
|
||||
from .functions import UserFunctionVariable
|
||||
|
||||
return UserFunctionVariable(
|
||||
getattr(comptime, name), source=AttrSource(self.source, name)
|
||||
return VariableTracker.build(
|
||||
tx, getattr(comptime, name), source=AttrSource(self.source, name)
|
||||
)
|
||||
|
||||
def call_function(
|
||||
|
|
@ -771,9 +769,8 @@ class AutogradFunctionVariable(VariableTracker):
|
|||
sig = inspect.signature(fn)
|
||||
if len(args) - 1 == len(sig._parameters):
|
||||
args = args[1:] # Don't use context
|
||||
return variables.UserFunctionVariable(fn, source=source).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
fn_vt = VariableTracker.build(tx, fn, source=source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
elif isinstance(fn, types.MethodType):
|
||||
return variables.UserMethodVariable(
|
||||
fn.__func__,
|
||||
|
|
@ -799,9 +796,8 @@ class AutogradFunctionVariable(VariableTracker):
|
|||
assert isinstance(fn, types.FunctionType)
|
||||
|
||||
fn_source = AttrSource(self.source, "backward")
|
||||
return variables.UserFunctionVariable(fn, source=fn_source).call_function(
|
||||
tx, args, kwargs
|
||||
)
|
||||
fn_vt = VariableTracker.build(tx, fn, source=fn_source)
|
||||
return fn_vt.call_function(tx, args, kwargs)
|
||||
|
||||
def call_function(self, tx: "InstructionTranslator", args, kwargs):
|
||||
return AutogradFunctionVariable(self.fn_cls)
|
||||
|
|
@ -1026,10 +1022,12 @@ class AutogradEngineVariable(UserDefinedObjectVariable):
|
|||
assert tx.one_graph or tx.error_on_graph_break, (
|
||||
"queue_callback() is only supported when Compiled Autograd is enabled with fullgraph=True"
|
||||
)
|
||||
return variables.UserFunctionVariable(
|
||||
fn_vt = VariableTracker.build(
|
||||
tx,
|
||||
torch._dynamo.external_utils.FakeCompiledAutogradEngine.queue_callback,
|
||||
source=self.source,
|
||||
).call_function(
|
||||
)
|
||||
return fn_vt.call_function(
|
||||
tx,
|
||||
(tx.output.side_effects.get_ca_final_callbacks_var(), *args),
|
||||
kwargs,
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user