pytorch/torch/_dynamo/variables/torch.py
Peter Bell a3a2486be8 [dynamo] Avoid eager imports of classes with custom VariableTrackers (#112319)
Custom VariableTrackers exist for some classes that live outside of pytorch.
For these cases dynamo currently imports the module eagerly to get the class
object to compare against.

This instead uses `sys.modules.get("module.path")` such that the module is never
imported by dynamo itself, but if the user has imported the module then we will
still access the module and grab the type we need to compare against.

I noticed this issue because importing `KeyedJaggedTensor` fails half-way
through if `fbgemm_gpu` has been built with an incompatible PyTorch version, in
which case dynamo retries the import each time!
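
A minimal sketch of the lazy-lookup pattern (the helper names and the torchrec
module path are illustrative assumptions, not part of the actual change):

    import sys

    def _lookup_keyed_jagged_tensor_class():
        # Never import the module here; only find it if the user already imported it.
        mod = sys.modules.get("torchrec.sparse.jagged_tensor")  # assumed module path
        return getattr(mod, "KeyedJaggedTensor", None) if mod is not None else None

    def is_keyed_jagged_tensor(obj) -> bool:
        cls = _lookup_keyed_jagged_tensor_class()
        return cls is not None and isinstance(obj, cls)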

Pull Request resolved: https://github.com/pytorch/pytorch/pull/112319
Approved by: https://github.com/lezcano, https://github.com/ezyang
2023-11-07 22:45:54 +00:00


import collections
import inspect
import logging
import math
import re
import types
from typing import Dict, List
from torch._streambase import _StreamBase
from ..guards import install_guard
try:
import numpy as np
except ModuleNotFoundError:
np = None
import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators
from torch._dynamo.variables import UserFunctionVariable
from .. import config, variables
from ..allowed_functions import torch_get_name
from ..device_interface import device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..utils import (
check_constant_args,
check_unspec_python_args,
has_torch_function,
is_rng_state_getter_or_setter,
istype,
product,
proxy_args_kwargs,
tensortype_to_dtype,
)
from .base import VariableTracker
from .ctx_manager import (
AutocastModeVariable,
NullContextVariable,
TorchFunctionDisableVariable,
)
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function
log = logging.getLogger(__name__)
# TODO(voz): Maybe rename these later
tensor_dunder_fns = [
torch.Tensor.__rmatmul__,
torch.Tensor.__rmod__,
torch.Tensor.__rpow__,
torch.Tensor.__rsub__,
torch.Tensor.__rdiv__,
torch._C.TensorBase.__radd__,
torch._C.TensorBase.__rmul__,
torch._C.TensorBase.__ror__,
torch._C.TensorBase.__rxor__,
torch._C.TensorBase.__rand__,
]
torch_special_class_types = (torch._C.Generator,)
REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
torch.onnx.operators.shape_as_tensor,
torch._shape_as_tensor,
]
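# Functions that dynamo may call directly at trace time when all of their
# arguments are constants, replacing the call with the resulting constant value.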
constant_fold_functions = [
torch._assert,
torch._utils._get_device_index,
torch.cuda.is_available,
torch.device,
torch.distributed.is_available,
torch.finfo,
torch.get_autocast_gpu_dtype,
torch.get_default_dtype,
torch.iinfo,
torch.is_autocast_cache_enabled,
torch.is_autocast_cpu_enabled,
torch.is_autocast_enabled,
torch.is_complex,
torch.is_floating_point,
torch.nn.functional._Reduction.get_enum,
torch.promote_types,
torch._C._get_privateuse1_backend_name,
]
if torch.distributed.is_available():
constant_fold_functions.extend(
[
torch.distributed.is_initialized,
torch.distributed.get_rank,
torch.distributed.get_world_size,
]
)
# TODO(voz): perhaps a decorator? This is rather readable for now though, and not a public API.
def remap_as_fn___radd__(*args):
return torch._C.TensorBase.__radd__(*args)
def remap_as_fn___rmul__(*args):
return torch._C.TensorBase.__rmul__(*args)
def remap_as_fn___ror__(*args):
return torch._C.TensorBase.__ror__(*args)
def remap_as_fn___rxor__(*args):
return torch._C.TensorBase.__rxor__(*args)
def remap_as_fn___rand__(*args):
return torch._C.TensorBase.__rand__(*args)
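# TorchVariable.__init__ uses this mapping to swap the C-implemented reflected
# dunder methods for the plain Python wrappers defined above.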
tensor_dunder_fns_remap = {
torch._C.TensorBase.__radd__: remap_as_fn___radd__,
torch._C.TensorBase.__rmul__: remap_as_fn___rmul__,
torch._C.TensorBase.__ror__: remap_as_fn___ror__,
torch._C.TensorBase.__rxor__: remap_as_fn___rxor__,
torch._C.TensorBase.__rand__: remap_as_fn___rand__,
}
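# Reconstructs `value` at runtime by caching it under a mangled global name and
# emitting bytecode that loads that global.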
def torch_reconstruct(codegen, value):
name = torch_get_name(value, f"allowed_fn_{id(value)}")
unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
return codegen.setup_globally_cached(unique_var_name, value, False)
class TorchCtxManagerClassVariable(VariableTracker):
"""Points to a context manager class in torch.* that dynamo has implementations"""
@classmethod
def create_with_source(cls, value, source):
install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
return TorchCtxManagerClassVariable(
value,
source=source,
)
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
self.value = value
def reconstruct(self, codegen):
return torch_reconstruct(codegen, self.value)
def python_type(self):
return type(self.value)
def as_python_constant(self):
return self.value
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import GradModeVariable, InferenceModeVariable, StreamVariable
if self.value is torch.no_grad:
if len(args) == 1 and isinstance(
args[0], variables.functions.BaseUserFunctionVariable
):
ctx = GradModeVariable.create(tx, False, initialized=False)
return ctx.call_function(tx, args, kwargs)
else:
return GradModeVariable.create(tx, False)
elif self.value is torch.enable_grad:
if len(args) == 1 and isinstance(
args[0], variables.functions.BaseUserFunctionVariable
):
ctx = GradModeVariable.create(tx, True, initialized=False)
return ctx.call_function(tx, args, kwargs)
return GradModeVariable.create(tx, True)
elif self.value is torch.set_grad_enabled and len(args) == 1:
return GradModeVariable.create(tx, args[0].as_python_constant())
elif self.value is torch.inference_mode:
return InferenceModeVariable.create(tx, args[0].as_python_constant())
elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
from torch._dynamo.variables.builder import wrap_fx_proxy_cls
return wrap_fx_proxy_cls(
StreamVariable,
tx,
tx.output.create_proxy(
"call_function",
self.value,
(),
{},
),
)
elif self.value in [
torch.amp.autocast_mode.autocast,
torch.cuda.amp.autocast,
torch.cpu.amp.autocast,
]:
return AutocastModeVariable.create(self.value, args, kwargs)
elif self.value in (
torch.profiler.profile,
torch.profiler.record_function,
torch.autograd.profiler.profile,
torch.autograd.profiler.record_function,
):
log.warning("Profiler function %s will be ignored", self.value)
return NullContextVariable()
elif self.value is torch._C.DisableTorchFunctionSubclass:
assert not (args or kwargs)
return TorchFunctionDisableVariable.create(tx)
class TorchVariable(VariableTracker):
"""Points to a module or method in torch.*"""
def __init__(self, value, **kwargs):
super().__init__(**kwargs)
if (
isinstance(value, collections.abc.Hashable)
and value in tensor_dunder_fns_remap
):
value = tensor_dunder_fns_remap[value]
self.value = value
assert not isinstance(
value, (torch.dtype, torch.device)
), "should use ConstantVariable"
# the remainder of this is just optional debug checks
try:
self_should_be_none = getattr(self.value, "__self__", None)
except RuntimeError as e:
assert "No such operator" in str(e), str(e)
self_should_be_none = None
except AssertionError as e:
assert "Unknown attribute" in str(e), str(e)
self_should_be_none = None
# assert "_ntuple.<locals>.parse" not in str(value)
if self_should_be_none is None:
pass
elif isinstance(self_should_be_none, types.ModuleType):
# weird ones like torch.nn.functional.avg_pool2d have __self__
name = self_should_be_none.__name__
assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
elif isinstance(
self_should_be_none, type(torch._C._get_tracing_state.__self__)
):
# some _C functions have __self__ as a null capsule
pass
elif isinstance(self_should_be_none, torch_special_class_types):
pass
else:
raise AssertionError(f"{value} found with __self__ set")
def __repr__(self):
return f"TorchVariable({self.value})"
def call_hasattr(self, tx, name):
result = hasattr(self.value, name)
return variables.ConstantVariable.create(result)
def reconstruct(self, codegen):
return torch_reconstruct(codegen, self.value)
def as_proxy(self):
return self.value
def python_type(self):
if isinstance(self.value, (torch.Tensor, torch.nn.Module, torch.device)):
return type(self.value)
if isinstance(self.value, type):
return type
return super().python_type()
def as_python_constant(self):
return self.value
def can_constant_fold_through(self):
if self.value in constant_fold_functions:
return True
return getattr(self.value, "__module__", None) == "math"
def call_function(
self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
) -> "VariableTracker":
from . import (
ConstantVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
GradModeVariable,
StreamContextVariable,
SymNodeVariable,
TensorVariable,
UserDefinedObjectVariable,
)
from .builder import wrap_fx_proxy, wrap_fx_proxy_cls
constant_args = check_constant_args(args, kwargs)
unspec_python_args = check_unspec_python_args(args, kwargs)
if self.value is torch._functorch.vmap.vmap_impl:
return TorchHigherOrderOperatorVariable.make(
self.value,
source=self.source,
).call_function(tx, args, kwargs)
if self.value is torch.overrides.get_default_nowrap_functions:
# [Note: __torch_function__] we return empty here because we restrict
# the set of functions that we trace __torch_function__ on to
# functions outside of the actual set. Implementing this properly will require
# variable types that track and compare tensor getset descriptors
from .builder import SourcelessBuilder
return SourcelessBuilder()(
tx, torch.overrides.get_default_nowrap_functions()
)
elif self.value in config.constant_functions:
assert not args and not kwargs
# See: https://github.com/pytorch/pytorch/issues/110765
if self.value in [
torch._utils.is_compiling,
torch._dynamo.external_utils.is_compiling,
]:
tx.mark_inconsistent_side_effects()
return ConstantVariable.create(config.constant_functions[self.value])
elif self.value is torch._functorch.eager_transforms.grad_impl:
op = TorchHigherOrderOperatorVariable.make(
self.value,
source=self.source,
).call_function(tx, args, kwargs)
return op
elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
# constant fold
return ConstantVariable.create(
self.as_python_constant()(
*[x.as_python_constant() for x in args],
**{k: v.as_python_constant() for k, v in kwargs.items()},
),
)
elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
if self.value is torch.nn.CrossEntropyLoss:
return self._call_cross_entropy_loss(tx, args, kwargs)
else:
return variables.UserDefinedClassVariable(
self.value, source=self.source
).call_function(tx, args, kwargs)
elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
assert len(args) == 1
if isinstance(args[0], TensorVariable) or (
self.value is torch.overrides.is_tensor_like
and isinstance(args[0], UserDefinedObjectVariable)
and hasattr(args[0].value, "__torch_function__")
):
return ConstantVariable.create(True)
else:
return ConstantVariable.create(False)
elif self.value in (
torch.is_floating_point,
torch.is_complex,
):
input_arg = None
if args:
input_arg = args[0]
else:
assert "input" in kwargs
input_arg = kwargs["input"]
if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
if self.value is torch.is_floating_point:
return ConstantVariable.create(input_arg.dtype.is_floating_point)
elif self.value is torch.is_complex:
return ConstantVariable.create(input_arg.dtype.is_complex)
else:
raise AssertionError(f"calling {self.value}")
elif (
self.value is torch.numel
and isinstance(args[0], TensorVariable)
and args[0].size is not None
):
return ConstantVariable.create(product(args[0].size))
elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
assert len(args) == 1
assert isinstance(args[0], TensorVariable)
return args[0].call_method(tx, "size", [], {})
elif self.value in (
torch.nn.modules.utils._single,
torch.nn.modules.utils._pair,
torch.nn.modules.utils._triple,
torch.nn.modules.utils._quadruple,
torch.nn.modules.utils._ntuple,
):
return self._call_ntuple(tx, args, kwargs)
elif self.value is torch.is_grad_enabled:
assert not (args or kwargs)
install_guard(GradModeVariable._guards_singleton)
return ConstantVariable.create(torch.is_grad_enabled())
elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
return DeterministicAlgorithmsVariable.create(
tx, args[0].as_python_constant()
)
elif self.value is torch.are_deterministic_algorithms_enabled:
assert not (args or kwargs)
install_guard(DeterministicAlgorithmsVariable._guards_singleton)
return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())
elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
assert len(args) == 1
return DisabledSavedTensorsHooksVariable.create(
tx, args[0].as_python_constant()
)
elif self.value is torch._C._is_torch_function_enabled:
assert not (args or kwargs)
install_guard(TorchFunctionDisableVariable._guards_singleton)
return ConstantVariable.create(tx.output.torch_function_enabled)
elif self.value in (
torch.overrides.has_torch_function_variadic,
torch.overrides.has_torch_function_unary,
):
assert not kwargs
return ConstantVariable.create(
any(has_torch_function(a) for a in args),
)
elif any(
self.value is method
for method in [
interface_elem.stream for interface_elem in device_interfaces.values()
]
):
assert len(args) == 1
return StreamContextVariable.create(tx, args[0])
elif self.value is torch.from_numpy:
if not config.trace_numpy:
unimplemented("torch.from_numpy. config.trace_numpy is False")
if not np:
unimplemented("torch.from_numpy. NumPy is not available")
assert len(args) == 1, f"Got arguments {args}"
assert not kwargs
t = args[0]
from .tensor import NumpyNdarrayVariable
if isinstance(t, NumpyNdarrayVariable):
# TODO: mark the tensor as non-resizable
return wrap_fx_proxy_cls(
target_cls=TensorVariable,
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
torch.detach,
*proxy_args_kwargs(args, {}),
),
example_value=None,
)
else:
unimplemented(f"torch.from_numpy(<{type(t)}>)")
elif can_dispatch_torch_function(tx, args, kwargs):
return dispatch_torch_function(tx, self, args, kwargs)
elif self.value is torch.autograd._profiler_enabled:
unimplemented("torch.autograd._profiler_enabled not supported yet")
elif self.value is torch.jit.annotate:
assert len(args) == 2
return args[1]
elif self.value is torch.backends.cudnn.is_acceptable:
# is_acceptable(tensor) returns true if
# (a) tensor dtype/device are supported by cudnn
# (b) cudnn is available
# (c) some initialization has completed
# technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
assert (
len(args) == 1 or "tensor" in kwargs
), "Expect 1 input to cudnn.is_acceptable"
tensor_variable = args[0] if len(args) > 0 else kwargs["tensor"]
assert isinstance(
tensor_variable, TensorVariable
), "Expect input to cudnn.is_acceptable to be a tensor"
tensor_inp = torch.tensor(
0, dtype=tensor_variable.dtype, device=tensor_variable.device
)
return ConstantVariable.create(
torch.backends.cudnn.is_acceptable(tensor_inp)
)
elif self.value is torch.nn.Parameter:
# https://github.com/pytorch/pytorch/issues/99569
unimplemented("torch.nn.Parameter not supported")
elif is_rng_state_getter_or_setter(self.value):
# We graph break on RNG state setters or getters like
# `torch.get_rng_state` or `torch.set_rng_state`. These functions
# are not aten operations and therefore they are completely ignored
# by the AOT dispatcher. As a result, the AOT graph does not have
# these setter or getter functions, producing an incorrect graph
# when it comes to rng states.
unimplemented(f"RNG state getter/setter function - {self.value}")
elif self.value is torch.manual_seed:
# https://github.com/pytorch/pytorch/issues/107187
unimplemented("torch.manual_seed not supported")
elif (
self.value == torch.numel
and len(args) == 1
and isinstance(args[0], TensorVariable)
and len(kwargs) == 0
):
# TODO(voz): This is rewritten as a call_method because
# torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_method",
"numel",
*proxy_args_kwargs(args, kwargs),
),
)
# TODO: These special cases shouldn't be necessary; we should
# generically support torch.ops that return int
elif (
self.value in [torch.ops.aten.sym_size, torch.ops.aten.sym_size.int]
and len(args) == 2
and len(kwargs) == 0
and isinstance(args[0], TensorVariable)
):
# we see this when retracing already traced code
return args[0].call_method(tx, "size", [args[1]], {})
elif (
self.value in [torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int]
and len(args) == 2
and len(kwargs) == 0
and isinstance(args[0], TensorVariable)
):
return args[0].call_method(tx, "stride", [args[1]], {})
elif (
self.value == torch.addcdiv
and len(args) == 3
and "value" in kwargs
and len(kwargs) == 1
):
# decompose addcdiv into constituent ops, preventing a graph break due to converting
# value to a scalar
result = TorchVariable(torch.div).call_function(tx, args[1:], {})
result = TorchVariable(torch.mul).call_function(
tx, [result, kwargs["value"]], {}
)
return TorchVariable(torch.add).call_function(tx, [args[0], result], {})
elif is_constant_pg_functions(self.value):
# because the input is a "ProcessGroupVariable", we'll be guarding on its
# ID_MATCH based on how it was constructed.
# We desugar it at trace-time into ranks by directly calling the util and
# baking the result into the trace
assert len(args) == 1, "Expected one arg (pg)"
assert isinstance(args[0], ProcessGroupVariable)
invocation_result = self.value(args[0].as_python_constant())
# Note - while we *could* cook up sources around invocations, like a FunctionSource,
# the space of invoking functions in the middle of the guard chain is very iffy. As such,
# guard propagation via options is the best we can do.
from .builder import SourcelessBuilder
return SourcelessBuilder()(tx, invocation_result)
elif is_from_local(self.value):
# bake non-proxyable (constant) args/kwargs into an on-the-fly wrapper function so that
# only proxyable args remain, then insert a call_function node for the wrapper
args_as_value = [x.as_python_constant() for x in args[1:]]
kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}
def fn_with_prim_types(x):
return self.value(x, *args_as_value, **kwargs_as_value)
# attach the same function name for better debugging
fn_with_prim_types.__name__ = "prim " + self.value.__name__
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_with_prim_types,
*proxy_args_kwargs([args[0]], {}),
),
)
elif self.value == torch.nn.init._calculate_correct_fan:
return UserFunctionVariable(
torch.nn.init._calculate_correct_fan
).call_function(tx, args, {})
elif (
self.value is torch.nested.nested_tensor
and kwargs.get("layout", torch.strided) == torch.strided
) or self.value in (
torch._nested_tensor_from_mask,
torch._nested_from_padded,
):
raise unimplemented("torch.compile does not support strided NestedTensor")
elif self.value is torch.nn.utils.rnn.pack_padded_sequence:
unimplemented("workaround https://github.com/pytorch/pytorch/issues/93501")
elif isinstance(self.value, types.ModuleType):
unimplemented("TypeError(\"'module' object is not callable\")")
else:
any_symints_or_symfloats = any(isinstance(x, SymNodeVariable) for x in args)
all_ints_or_floats = all(
isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
for x in args
)
bin_ops = {"add", "sub", "mul", "div", "sqrt"}
if (
getattr(self.value, "__module__", "") == "torch"
and self.value.__name__ in bin_ops
and any_symints_or_symfloats
and all_ints_or_floats
):
msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
log.warning(msg)
raise unimplemented(msg)
# torch.LongTensor cannot accept a list of FakeTensors.
# So we stack the list of FakeTensors instead.
if (
np
and self.value in tensortype_to_dtype
and len(args) == 1
and isinstance(args[0], ListVariable)
and len(args[0].items) > 1
and all(isinstance(x, variables.TensorVariable) for x in args[0].items)
):
# Stack FakeTensor
stacked = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
torch.stack,
*proxy_args_kwargs(args, kwargs),
),
)
args = [stacked]
# TODO(voz): Replace w/ dynamic shape rewrite table.
# Ideally, we would be able to do this at ctor time, but alas we need a combination
# of value + args to determine this.
fn_ = self.value
if any(isinstance(x, SymNodeVariable) for x in args):
if self.value == math.sqrt:
from torch.fx.experimental.sym_node import sym_sqrt
fn_ = sym_sqrt
if fn_ is torch.tensor:
def check_any_unspec(x):
# NB: This includes UnspecializedPythonVariable
if isinstance(x, (TensorVariable, SymNodeVariable)):
return True
elif isinstance(x, ListVariable):
return any(check_any_unspec(y) for y in x.items)
# TODO: there may be other recursive structures you need to
# check
else:
return False
data_arg = None
if args:
data_arg = args[0]
elif "data" in kwargs:
data_arg = kwargs["data"]
# NB: OK to pass torch.tensor(tensor), this will trace fine
if not isinstance(data_arg, TensorVariable) and check_any_unspec(
data_arg
):
# This is slower and less canonical, so only use it if we
# have to
fn_ = torch._refs.tensor
tensor_variable = wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
fn_,
*proxy_args_kwargs(args, kwargs),
),
)
if "out" in kwargs and not (
isinstance(kwargs["out"], variables.ConstantVariable)
and kwargs["out"].as_python_constant() is None
):
# out variants of torch operators like torch.sort and
# torch.sigmoid mutate the tensors in the out field. Track such
# tensors and rewrite the symbolic locals.
if isinstance(tensor_variable, TupleVariable):
assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
output_tensor_names = [
tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
]
for idx, name in enumerate(output_tensor_names):
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable.items[idx]
elif isinstance(tensor_variable, TensorVariable):
assert isinstance(kwargs["out"], TensorVariable)
if (
kwargs["out"].source
and kwargs["out"] in tx.output.graphargs
and kwargs["out"].size != tensor_variable.size
):
# It's hard to get out variants with resizing on graph inputs to work
# properly across dynamo/aot/inductor, so just fall back.
unimplemented("out variants with resizing on graph inputs")
name = tx.find_symbolic_locals_name(kwargs["out"])
if name in tx.symbolic_locals:
tx.symbolic_locals[name] = tensor_variable
else:
unimplemented(f"out variant of {type(kwargs['out'])}")
return tensor_variable
def _call_cross_entropy_loss(self, tx, args, kwargs):
"""
functional: input, target, weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
label_smoothing=0.0
non-functional ctor: weight=None, size_average=None, ignore_index=-100, reduce=None, reduction='mean',
label_smoothing=0.0
non-functional loss call: input, target, optional_output
"""
from . import ConstantVariable
def normalize_args(
weight=ConstantVariable.create(None),
size_average=ConstantVariable.create(None),
ignore_index=ConstantVariable.create(-100),
reduce=ConstantVariable.create(None),
reduction=ConstantVariable.create("mean"),
label_smoothing=ConstantVariable.create(0.0),
):
return (
weight,
size_average,
ignore_index,
reduce,
reduction,
label_smoothing,
)
(
weight,
size_average,
ignore_index,
reduce_arg,
reduction,
label_smoothing,
) = normalize_args(*args, **kwargs)
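# The loss instance is modeled as a lambda that proxies a functional
# cross_entropy call with the normalized constructor arguments closed over.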
def fake_cross_entropy_loss(input, target):
from .builder import wrap_fx_proxy
return wrap_fx_proxy(
tx=tx,
proxy=tx.output.create_proxy(
"call_function",
torch.nn.functional.cross_entropy,
*proxy_args_kwargs(
[
input,
target,
weight,
size_average,
ignore_index,
reduce_arg,
reduction,
label_smoothing,
],
{},
),
),
)
return variables.LambdaVariable(fake_cross_entropy_loss)
def _call_ntuple(self, tx, args, kwargs):
"""inline behavior of torch.nn.modules.utils._ntuple"""
if self.value is torch.nn.modules.utils._ntuple:
count = args[0].as_python_constant()
else:
count = self.value.__closure__[0].cell_contents
assert isinstance(count, int)
assert not kwargs
def handle_ntuple(value):
if value.has_unpack_var_sequence(tx):
return variables.TupleVariable(
list(value.unpack_var_sequence(tx)),
)
elif value.is_python_constant():
# constant prop through it
return variables.ConstantVariable.create(
torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
)
else:
unimplemented(f"torch.nn.modules.utils._ntuple({value})")
if self.value is torch.nn.modules.utils._ntuple:
return variables.LambdaVariable(handle_ntuple)
else:
return handle_ntuple(args[0])