import collections
import inspect
import logging
import math
import re
import types
from typing import Dict, List

from torch._streambase import _StreamBase

from ..guards import install_guard

try:
    import numpy as np
except ModuleNotFoundError:
    np = None

import torch._C
import torch._refs
import torch.fx
import torch.nn
import torch.onnx.operators

from torch._dynamo.variables import UserFunctionVariable

from .. import config, polyfill, variables
from ..allowed_functions import torch_get_name
from ..device_interface import get_registered_device_interfaces
from ..exc import unimplemented
from ..guards import GuardBuilder
from ..utils import (
    check_constant_args,
    check_unspec_python_args,
    has_torch_function,
    is_rng_state_getter_or_setter,
    istype,
    product,
    proxy_args_kwargs,
    tensortype_to_dtype,
)
from .base import VariableTracker
from .ctx_manager import (
    AutocastModeVariable,
    NullContextVariable,
    TorchFunctionDisableVariable,
)
from .distributed import is_constant_pg_functions, is_from_local, ProcessGroupVariable
from .higher_order_ops import TorchHigherOrderOperatorVariable
from .lists import ListVariable, TupleVariable
from .torch_function import can_dispatch_torch_function, dispatch_torch_function

log = logging.getLogger(__name__)

# TODO(voz): Maybe rename these later
tensor_dunder_fns = [
    torch.Tensor.__rmatmul__,
    torch.Tensor.__rmod__,
    torch.Tensor.__rpow__,
    torch.Tensor.__rsub__,
    torch.Tensor.__rdiv__,
    torch._C.TensorBase.__radd__,
    torch._C.TensorBase.__rmul__,
    torch._C.TensorBase.__ror__,
    torch._C.TensorBase.__rxor__,
    torch._C.TensorBase.__rand__,
]

torch_special_class_types = (torch._C.Generator,)

REWRITE_OPS_TO_TENSOR_SIZE_METHOD = [
    torch.onnx.operators.shape_as_tensor,
    torch._shape_as_tensor,
]

constant_fold_functions = [
    torch._assert,
    torch._utils._get_device_index,
    torch.cuda.is_available,
    torch.device,
    torch.distributed.is_available,
    torch.finfo,
    torch.get_autocast_gpu_dtype,
    torch.get_default_dtype,
    torch.iinfo,
    torch.is_autocast_cache_enabled,
    torch.is_autocast_cpu_enabled,
    torch.is_autocast_enabled,
    torch.is_complex,
    torch.is_floating_point,
    torch.nn.functional._Reduction.get_enum,
    torch.promote_types,
    torch._C._get_privateuse1_backend_name,
]

if torch.distributed.is_available():
    constant_fold_functions.extend(
        [
            torch.distributed.is_initialized,
            torch.distributed.get_rank,
            torch.distributed.get_world_size,
        ]
    )


# TODO(voz): perhaps a decorator? This is rather readable for now tho, and not a public API.
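# A rough illustration of what each wrapper below does (hypothetical example,
# not executed here):
#
#     t = torch.ones(3)
#     remap_as_fn___radd__(t, 2)   # same as torch._C.TensorBase.__radd__(t, 2), i.e. 2 + t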
def remap_as_fn___radd__(*args):
    return torch._C.TensorBase.__radd__(*args)


def remap_as_fn___rmul__(*args):
    return torch._C.TensorBase.__rmul__(*args)


def remap_as_fn___ror__(*args):
    return torch._C.TensorBase.__ror__(*args)


def remap_as_fn___rxor__(*args):
    return torch._C.TensorBase.__rxor__(*args)


def remap_as_fn___rand__(*args):
    return torch._C.TensorBase.__rand__(*args)


tensor_dunder_fns_remap = {
    torch._C.TensorBase.__radd__: remap_as_fn___radd__,
    torch._C.TensorBase.__rmul__: remap_as_fn___rmul__,
    torch._C.TensorBase.__ror__: remap_as_fn___ror__,
    torch._C.TensorBase.__rxor__: remap_as_fn___rxor__,
    torch._C.TensorBase.__rand__: remap_as_fn___rand__,
}


class BaseTorchVariable(VariableTracker):
    """Common base for torch.* functions, classes, and modules that dynamo handles"""

    @classmethod
    def create_with_source(cls, value, source):
        install_guard(source.make_guard(GuardBuilder.FUNCTION_MATCH))
        return cls(
            value,
            source=source,
        )

    def __init__(self, value, **kwargs):
        super().__init__(**kwargs)
        self.value = value

    def reconstruct(self, codegen):
        name = torch_get_name(self.value, f"allowed_fn_{id(self.value)}")
        unique_var_name = "__" + re.sub(r"[^a-zA-Z0-9_]+", "_", name)
        return codegen.setup_globally_cached(unique_var_name, self.value, False)

    def as_proxy(self):
        return self.value

    def python_type(self):
        return type(self.value)

    def as_python_constant(self):
        return self.value

    def call_hasattr(self, tx, name):
        result = hasattr(self.value, name)
        return variables.ConstantVariable.create(result)

    def can_constant_fold_through(self):
        if self.value in constant_fold_functions:
            return True
        return getattr(self.value, "__module__", None) == "math"


class TorchCtxManagerClassVariable(BaseTorchVariable):
    """Points to a context manager class in torch.* that dynamo has implementations"""

    def __repr__(self):
        return f"TorchCtxManagerClassVariable({self.value})"

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
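        # Rough overview: each branch below maps a supported torch context
        # manager (grad/inference mode, streams, autocast, profiler stubs,
        # DisableTorchFunctionSubclass) onto dynamo's corresponding
        # ctx-manager VariableTracker.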
        from . import GradModeVariable, InferenceModeVariable, StreamVariable

        if self.value is torch.no_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, False)
                return ctx.call_function(tx, args, kwargs)
            else:
                return GradModeVariable.create(tx, False)
        elif self.value is torch.enable_grad:
            if len(args) == 1 and isinstance(
                args[0], variables.functions.BaseUserFunctionVariable
            ):
                ctx = GradModeVariable.create(tx, True)
                return ctx.call_function(tx, args, kwargs)
            return GradModeVariable.create(tx, True)
        elif self.value is torch.set_grad_enabled and len(args) == 1:
            return GradModeVariable.create(
                tx, args[0].as_python_constant(), initialized=True
            )
        elif self.value is torch.inference_mode:
            return InferenceModeVariable.create(tx, args[0].as_python_constant())
        elif inspect.isclass(self.value) and issubclass(self.value, _StreamBase):
            from torch._dynamo.variables.builder import wrap_fx_proxy_cls

            return wrap_fx_proxy_cls(
                StreamVariable,
                tx,
                tx.output.create_proxy(
                    "call_function",
                    self.value,
                    (),
                    {},
                ),
            )
        elif self.value in [
            torch.amp.autocast_mode.autocast,
            torch.cuda.amp.autocast,
            torch.cpu.amp.autocast,
        ]:
            return AutocastModeVariable.create(self.value, args, kwargs)
        elif self.value in (
            torch.profiler.profile,
            torch.profiler.record_function,
            torch.autograd.profiler.profile,
            torch.autograd.profiler.record_function,
        ):
            log.warning("Profiler function %s will be ignored", self.value)
            return NullContextVariable()
        elif self.value is torch._C.DisableTorchFunctionSubclass:
            assert not (args or kwargs)
            return TorchFunctionDisableVariable.create(tx)


class TorchInGraphFunctionVariable(BaseTorchVariable):
    """Points to a torch function/method that should be put in FX graph"""

    def __repr__(self):
        return f"TorchInGraphFunctionVariable({self.value})"

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import (
            ConstantVariable,
            DeterministicAlgorithmsVariable,
            DisabledSavedTensorsHooksVariable,
            GradModeVariable,
            StreamContextVariable,
            SymNodeVariable,
            TensorVariable,
            UserDefinedObjectVariable,
        )
        from .builder import wrap_fx_proxy, wrap_fx_proxy_cls

        constant_args = check_constant_args(args, kwargs)
        unspec_python_args = check_unspec_python_args(args, kwargs)

        if self.value is torch._functorch.vmap.vmap_impl:
            return TorchHigherOrderOperatorVariable.make(
                self.value,
                source=self.source,
            ).call_function(tx, args, kwargs)

        if self.value is torch.overrides.get_default_nowrap_functions:
            # [Note: __torch_function__] we return empty here because we restrict
            # the set of functions that we trace __torch_function__ on to
            # functions outside of the actual set.
            # Implementing this properly will require implementing
            # some variable types to track and compare tensor getset descriptors
            from .builder import SourcelessBuilder

            return SourcelessBuilder()(
                tx, torch.overrides.get_default_nowrap_functions()
            )
        elif self.value in config.constant_functions:
            assert not args and not kwargs
            # See: https://github.com/pytorch/pytorch/issues/110765
            if self.value in [
                torch._utils.is_compiling,
                torch._dynamo.external_utils.is_compiling,
            ]:
                tx.mark_inconsistent_side_effects()
            return ConstantVariable.create(config.constant_functions[self.value])
        elif self.value is torch._functorch.eager_transforms.grad_impl:
            op = TorchHigherOrderOperatorVariable.make(
                self.value,
                source=self.source,
            ).call_function(tx, args, kwargs)
            return op
        elif self.can_constant_fold_through() and (constant_args or unspec_python_args):
            # constant fold
            return ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )
        elif self.value == math.radians and not (constant_args or unspec_python_args):
            # Use polyfill to convert math.radians(x) into math.pi * x / 180.0
            from .builder import SourcelessBuilder

            return tx.inline_user_function_return(
                SourcelessBuilder()(tx, polyfill.radians), args, kwargs
            )
        elif self.value in (torch.is_tensor, torch.overrides.is_tensor_like):
            assert len(args) == 1
            if isinstance(args[0], TensorVariable) or (
                self.value is torch.overrides.is_tensor_like
                and isinstance(args[0], UserDefinedObjectVariable)
                and hasattr(args[0].value, "__torch_function__")
            ):
                return ConstantVariable.create(True)
            else:
                return ConstantVariable.create(False)
        elif self.value in (
            torch.is_floating_point,
            torch.is_complex,
        ):
            input_arg = None
            if args:
                input_arg = args[0]
            else:
                assert "input" in kwargs
                input_arg = kwargs["input"]
            if isinstance(input_arg, TensorVariable) and input_arg.dtype is not None:
                if self.value is torch.is_floating_point:
                    return ConstantVariable.create(input_arg.dtype.is_floating_point)
                elif self.value is torch.is_complex:
                    return ConstantVariable.create(input_arg.dtype.is_complex)
                else:
                    raise AssertionError(f"calling {self.value}")
        elif (
            self.value is torch.numel
            and isinstance(args[0], TensorVariable)
            and args[0].size is not None
        ):
            return ConstantVariable.create(product(args[0].size))
        elif self.value in REWRITE_OPS_TO_TENSOR_SIZE_METHOD:
            assert len(args) == 1
            assert isinstance(args[0], TensorVariable)
            return args[0].call_method(tx, "size", [], {})
        elif self.value in (
            torch.nn.modules.utils._single,
            torch.nn.modules.utils._pair,
            torch.nn.modules.utils._triple,
            torch.nn.modules.utils._quadruple,
            torch.nn.modules.utils._ntuple,
        ):
            return self._call_ntuple(tx, args, kwargs)
        elif self.value is torch.is_grad_enabled:
            assert not (args or kwargs)
            install_guard(GradModeVariable._guards_singleton)
            return ConstantVariable.create(torch.is_grad_enabled())
        elif self.value is torch.use_deterministic_algorithms and len(args) == 1:
            return DeterministicAlgorithmsVariable.create(
                tx, args[0].as_python_constant()
            )
        elif self.value is torch.are_deterministic_algorithms_enabled:
            assert not (args or kwargs)
            install_guard(DeterministicAlgorithmsVariable._guards_singleton)
            return ConstantVariable.create(torch.are_deterministic_algorithms_enabled())
        elif self.value is torch.autograd.graph.disable_saved_tensors_hooks:
            assert len(args) == 1
            return DisabledSavedTensorsHooksVariable.create(
                tx, args[0].as_python_constant()
            )
        elif self.value is torch._C._is_torch_function_enabled:
            assert not (args or kwargs)
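            # The result only depends on whether torch function dispatch is
            # currently enabled, so guard on that state and bake in a constant.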
            install_guard(TorchFunctionDisableVariable._guards_singleton)
            return ConstantVariable.create(tx.output.torch_function_enabled)
        elif self.value in (
            torch.overrides.has_torch_function_variadic,
            torch.overrides.has_torch_function_unary,
        ):
            assert not kwargs
            return ConstantVariable.create(
                any(has_torch_function(a) for a in args),
            )
        elif any(
            self.value is method
            for method in [
                device_interface.stream
                for _, device_interface in get_registered_device_interfaces()
            ]
        ):
            assert len(args) == 1
            return StreamContextVariable.create(tx, args[0])
        elif self.value is torch.from_numpy:
            if not config.trace_numpy:
                unimplemented("torch.from_numpy. config.trace_numpy is False")
            if not np:
                unimplemented("torch.from_numpy. NumPy is not available")
            return wrap_fx_proxy_cls(
                target_cls=TensorVariable,
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.as_tensor,
                    *proxy_args_kwargs(args, {}),
                ),
                example_value=None,
            )
        elif can_dispatch_torch_function(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
        elif self.value is torch.autograd._profiler_enabled:
            unimplemented("torch.autograd._profiler_enabled not supported yet")
        elif self.value is torch.jit.annotate:
            assert len(args) == 2
            return args[1]
        elif self.value is torch.backends.cudnn.is_acceptable:
            # is_acceptable(tensor) returns true if
            #   (a) tensor dtype/device are supported by cudnn
            #   (b) cudnn is available
            #   (c) some initialization has completed
            # technically, it depends on some global state from (c) (torch.backends.cudnn.__cudnn_version)
            assert (
                len(args) == 1 or "tensor" in kwargs
            ), "Expect 1 input to cudnn.is_acceptable"
            tensor_variable = args[0] if len(args) > 0 else kwargs["tensor"]
            assert isinstance(
                tensor_variable, TensorVariable
            ), "Expect input to cudnn.is_acceptable to be a tensor"
            tensor_inp = torch.tensor(
                0, dtype=tensor_variable.dtype, device=tensor_variable.device
            )
            return ConstantVariable.create(
                torch.backends.cudnn.is_acceptable(tensor_inp)
            )
        elif is_rng_state_getter_or_setter(self.value):
            # We graph break on RNG state setters or getters like
            # `torch.get_rng_state` or `torch.set_rng_state`. These functions
            # are not aten operations and therefore they are completely ignored
            # by the AOT dispatcher. As a result, the AOT graph does not have
            # these setter or getter functions, producing an incorrect graph
            # when it comes to rng states.
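            # Graph breaking here lets the getter/setter run eagerly between
            # compiled regions, keeping the global RNG state consistent.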
            unimplemented(f"RNG state getter/setter function - {self.value}")
        elif self.value is torch.manual_seed:
            # https://github.com/pytorch/pytorch/issues/107187
            unimplemented("torch.manual_seed not supported")
        elif (
            self.value == torch.numel
            and len(args) == 1
            and isinstance(args[0], TensorVariable)
            and len(kwargs) == 0
        ):
            # TODO(voz): This is rewritten as a call_method because
            # torch.numel(x) w/ sym shapes raises a RuntimeError and x.numel() does not
            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_method",
                    "numel",
                    *proxy_args_kwargs(args, kwargs),
                ),
            )
        # TODO: These special cases shouldn't be necessary; we should
        # generically support torch.ops that return int
        elif (
            self.value in [torch.ops.aten.sym_size, torch.ops.aten.sym_size.int]
            and len(args) == 2
            and len(kwargs) == 0
            and isinstance(args[0], TensorVariable)
        ):
            # we see this when retracing already traced code
            return args[0].call_method(tx, "size", [args[1]], {})
        elif (
            self.value in [torch.ops.aten.sym_stride, torch.ops.aten.sym_stride.int]
            and len(args) == 2
            and len(kwargs) == 0
            and isinstance(args[0], TensorVariable)
        ):
            return args[0].call_method(tx, "stride", [args[1]], {})
        elif (
            self.value == torch.addcdiv
            and len(args) == 3
            and "value" in kwargs
            and len(kwargs) == 1
        ):
            # decompose addcdiv into constituent ops, prevents a graph break due to converting
            # value to a scalar
            result = TorchVariable(torch.div).call_function(tx, args[1:], {})
            result = TorchVariable(torch.mul).call_function(
                tx, [result, kwargs["value"]], {}
            )
            return TorchVariable(torch.add).call_function(tx, [args[0], result], {})
        elif (
            self.value is torch._assert
            and len(args) >= 1
            and (
                (args[0].is_python_constant() and args[0].as_python_constant())
                or (
                    isinstance(args[0], variables.SymNodeVariable)
                    and args[0].evaluate_expr()
                )
            )
        ):
            return ConstantVariable(None)
        elif is_constant_pg_functions(self.value):
            # because the input is a "ProcessGroupVariable", we'll be guarding on its
            # ID_MATCH based on how it was constructed.

            # We desugar it at trace-time into ranks by directly calling the util and
            # bake the result into the trace
            assert len(args) == 1, "Expected one arg (pg)"
            assert isinstance(args[0], ProcessGroupVariable)

            invocation_result = self.value(args[0].as_python_constant())
            # Note - while we *could* cook up sources around invocations, like a FunctionSource
            # the space of invoking functions in the middle of the guard chain is very iffy. As such,
            # guard propagation via options is the best we can do.
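            # Illustrative (hypothetical) example: a constant pg function such as
            # torch.distributed.get_world_size(pg) is executed eagerly right here,
            # and its result is embedded in the trace as a constant below.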
            from .builder import SourcelessBuilder

            return SourcelessBuilder()(tx, invocation_result)
        elif is_from_local(self.value):
            # rewrite non-primitive args/kwargs to be included in the on-the-fly prim function
            # and rewrite args to have only proxyable args, then insert call_function
            args_as_value = [x.as_python_constant() for x in args[1:]]
            kwargs_as_value = {k: v.as_python_constant() for k, v in kwargs.items()}

            def fn_with_prim_types(x):
                return self.value(x, *args_as_value, **kwargs_as_value)

            # attach the same function name for better debugging
            fn_with_prim_types.__name__ = "prim " + self.value.__name__

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_with_prim_types,
                    *proxy_args_kwargs([args[0]], {}),
                ),
            )
        elif self.value == torch.nn.init._calculate_correct_fan:
            return UserFunctionVariable(
                torch.nn.init._calculate_correct_fan
            ).call_function(tx, args, {})
        elif (
            self.value is torch.nested.nested_tensor
            and kwargs.get("layout", torch.strided) == torch.strided
        ) or self.value in (
            torch._nested_tensor_from_mask,
            torch._nested_from_padded,
        ):
            unimplemented("torch.compile does not support strided NestedTensor")
        elif self.value is torch.nn.utils.rnn.pack_padded_sequence:
            unimplemented("workaround https://github.com/pytorch/pytorch/issues/93501")
        else:
            any_symints_or_symfloats = any(
                isinstance(x, SymNodeVariable) for x in args
            )
            all_ints_or_floats = all(
                isinstance(x, (variables.ConstantVariable, variables.SymNodeVariable))
                for x in args
            )
            bin_ops = {"add", "sub", "mul", "div", "sqrt"}
            if (
                getattr(self.value, "__module__", "") == "torch"
                and self.value.__name__ in bin_ops
                and any_symints_or_symfloats
                and all_ints_or_floats
            ):
                msg = f"""\
Calling {str(self.value)} on only torch.SymInt arguments is not yet supported.
To support this behavior, we need to allow const-propping tensors that store symint data.
For now, dynamo will explicitly graph break when it encounters user code with this behavior.
"""
                log.warning(msg)
                unimplemented(msg)

            # TODO(voz): Replace w/ dynamic shape rewrite table.
            # Ideally, we would be able to do this at ctor time, but alas we need a combination
            # of value + args to determine this.
            fn_ = self.value
            if any(isinstance(x, SymNodeVariable) for x in args):
                if self.value == math.sqrt:
                    from torch.fx.experimental.sym_node import sym_sqrt

                    fn_ = sym_sqrt

            if fn_ is torch.tensor:

                def check_any_unspec(x):
                    # NB: This includes UnspecializedPythonVariable
                    if isinstance(x, (TensorVariable, SymNodeVariable)):
                        return True
                    elif isinstance(x, ListVariable):
                        return any(check_any_unspec(y) for y in x.items)
                    # TODO: there may be other recursive structures you need to
                    # check
                    else:
                        return False

                data_arg = None
                if args:
                    data_arg = args[0]
                elif "data" in kwargs:
                    data_arg = kwargs["data"]

                # NB: OK to pass torch.tensor(tensor), this will trace fine
                if not isinstance(data_arg, TensorVariable) and check_any_unspec(
                    data_arg
                ):
                    # This is slower and less canonical, so only use it if we
                    # have to
                    fn_ = torch._refs.tensor

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    fn_,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            if (
                isinstance(tensor_variable, TensorVariable)
                and "requires_grad" in kwargs
                and kwargs["requires_grad"].as_python_constant()
            ):
                unimplemented(
                    """factory functions that return tensors that require grad are not supported.
Either create the tensor outside the compiled region, or do not set the tensor to require_grad"""
                )

            if "out" in kwargs and not (
                isinstance(kwargs["out"], variables.ConstantVariable)
                and kwargs["out"].as_python_constant() is None
            ):
                # out variants of torch operators like torch.sort and
                # torch.sigmoid mutate the tensors in the out field. Track such
                # tensors and rewrite the symbolic locals.
                if isinstance(tensor_variable, TupleVariable):
                    assert isinstance(kwargs["out"], (TupleVariable, ListVariable))
                    output_tensor_names = [
                        tx.find_symbolic_locals_name(x) for x in kwargs["out"].items
                    ]
                    for idx, name in enumerate(output_tensor_names):
                        if name in tx.symbolic_locals:
                            tx.symbolic_locals[name] = tensor_variable.items[idx]
                elif isinstance(tensor_variable, TensorVariable):
                    assert isinstance(kwargs["out"], TensorVariable)
                    if (
                        kwargs["out"].source
                        and kwargs["out"] in tx.output.graphargs
                        and kwargs["out"].size != tensor_variable.size
                    ):
                        # It's hard to get out variants with resizing on graph inputs to work
                        # properly across dynamo/aot/inductor, so just fall back.
                        unimplemented("out variants with resizing on graph inputs")
                    assert "example_value" in kwargs["out"].proxy.node.meta
                    if not torch._prims_common.is_contiguous(
                        kwargs["out"].proxy.node.meta["example_value"]
                    ):
                        # It's difficult to handle strides correctly in functionalization
                        # when calling an out= op with a non-contiguous out argument
                        unimplemented(
                            "out= op was called where output tensor was non-contiguous"
                        )
                    name = tx.find_symbolic_locals_name(kwargs["out"])
                    if name in tx.symbolic_locals:
                        tx.symbolic_locals[name] = tensor_variable
                else:
                    unimplemented(f"out variant of {type(kwargs['out'])}")

            return tensor_variable

    def _call_ntuple(self, tx, args, kwargs):
        """inline behavior of torch.nn.modules.utils._ntuple"""
        if self.value is torch.nn.modules.utils._ntuple:
            count = args[0].as_python_constant()
        else:
            count = self.value.__closure__[0].cell_contents
        assert isinstance(count, int)
        assert not kwargs

        def handle_ntuple(value):
            if value.has_unpack_var_sequence(tx):
                return variables.TupleVariable(
                    list(value.unpack_var_sequence(tx)),
                )
            elif value.is_python_constant():
                # constant prop through it
                return variables.ConstantVariable.create(
                    torch.nn.modules.utils._ntuple(count)(value.as_python_constant()),
                )
            else:
                unimplemented(f"torch.nn.modules.utils._ntuple({value})")

        if self.value is torch.nn.modules.utils._ntuple:
            return variables.LambdaVariable(handle_ntuple)
        else:
            return handle_ntuple(args[0])


class TorchVariable(BaseTorchVariable):
    """Points to a module, class, or function in torch.*"""

    def __init__(self, value, **kwargs):
        # TODO: Remove tensor_dunder_fns_remap since it's not used anymore.
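        # If value is one of the C-level reverse dunder methods, swap in the
        # plain Python wrapper defined near the top of this file.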
        if (
            isinstance(value, collections.abc.Hashable)
            and value in tensor_dunder_fns_remap
        ):
            value = tensor_dunder_fns_remap[value]

        assert not isinstance(
            value, (torch.dtype, torch.device)
        ), "should use ConstantVariable"

        super().__init__(value, **kwargs)

        # the remainder of this is just optional debug checks
        try:
            self_should_be_none = getattr(self.value, "__self__", None)
        except RuntimeError as e:
            assert "No such operator" in str(e), str(e)
            self_should_be_none = None
        except AssertionError as e:
            assert "Unknown attribute" in str(e), str(e)
            self_should_be_none = None

        # assert "_ntuple.<locals>.parse" not in str(value)

        if self_should_be_none is None:
            pass
        elif isinstance(self_should_be_none, types.ModuleType):
            # weird ones like torch.nn.functional.avg_pool2d have __self__
            name = self_should_be_none.__name__
            assert re.match(r"^(torch|math)([.]|$)", name), f"__self__ set to {name}"
        elif isinstance(
            self_should_be_none, type(torch._C._get_tracing_state.__self__)
        ):
            # some _C functions have __self__ as a null capsule
            pass
        elif isinstance(self_should_be_none, torch_special_class_types):
            pass
        else:
            raise AssertionError(f"{value} found with __self__ set")

    def __repr__(self):
        return f"TorchVariable({self.value})"

    def python_type(self):
        if isinstance(self.value, (torch.Tensor, torch.nn.Module, torch.device)):
            return type(self.value)
        if isinstance(self.value, type):
            return type
        return super().python_type()

    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from . import ConstantVariable
        from .builder import wrap_fx_proxy

        constant_args = check_constant_args(args, kwargs)
        unspec_python_args = check_unspec_python_args(args, kwargs)

        if self.can_constant_fold_through() and (constant_args or unspec_python_args):
            # constant fold
            return ConstantVariable.create(
                self.as_python_constant()(
                    *[x.as_python_constant() for x in args],
                    **{k: v.as_python_constant() for k, v in kwargs.items()},
                ),
            )
        elif istype(self.value, type) and issubclass(self.value, torch.nn.Module):
            if self.value is torch.nn.CrossEntropyLoss:
                return self._call_cross_entropy_loss(tx, args, kwargs)
            else:
                return variables.UserDefinedClassVariable(
                    self.value, source=self.source
                ).call_function(tx, args, kwargs)
        elif can_dispatch_torch_function(tx, args, kwargs):
            return dispatch_torch_function(tx, self, args, kwargs)
        elif self.value is torch.nn.Parameter:
            # https://github.com/pytorch/pytorch/issues/99569
            unimplemented("torch.nn.Parameter not supported")
        elif isinstance(self.value, types.ModuleType):
            unimplemented("TypeError(\"'module' object is not callable\")")
        else:
            # torch.LongTensor cannot accept a list of FakeTensors.
            # So we stack the list of FakeTensors instead.
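            # Illustrative example: torch.LongTensor([t1, t2]) with two traced
            # tensors is rewritten below to torch.LongTensor(torch.stack([t1, t2])).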
            if (
                np
                and self.value in tensortype_to_dtype
                and len(args) == 1
                and isinstance(args[0], ListVariable)
                and len(args[0].items) > 1
                and all(isinstance(x, variables.TensorVariable) for x in args[0].items)
            ):
                # Stack FakeTensor
                stacked = wrap_fx_proxy(
                    tx=tx,
                    proxy=tx.output.create_proxy(
                        "call_function",
                        torch.stack,
                        *proxy_args_kwargs(args, kwargs),
                    ),
                )
                args = [stacked]

            tensor_variable = wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    self.value,
                    *proxy_args_kwargs(args, kwargs),
                ),
            )

            return tensor_variable

    def _call_cross_entropy_loss(self, tx, args, kwargs):
        """
        functional: input, target, weight=None, size_average=None, ignore_index=-100,
        reduce=None, reduction='mean', label_smoothing=0.0

        non functional ctor: weight=None, size_average=None, ignore_index=-100,
        reduce=None, reduction='mean', label_smoothing=0.0

        non functional loss call: input, target, optional_output
        """
        from . import ConstantVariable

        def normalize_args(
            weight=ConstantVariable.create(None),
            size_average=ConstantVariable.create(None),
            ignore_index=ConstantVariable.create(-100),
            reduce=ConstantVariable.create(None),
            reduction=ConstantVariable.create("mean"),
            label_smoothing=ConstantVariable.create(0.0),
        ):
            return (
                weight,
                size_average,
                ignore_index,
                reduce,
                reduction,
                label_smoothing,
            )

        (
            weight,
            size_average,
            ignore_index,
            reduce_arg,
            reduction,
            label_smoothing,
        ) = normalize_args(*args, **kwargs)

        def fake_cross_entropy_loss(input, target):
            from .builder import wrap_fx_proxy

            return wrap_fx_proxy(
                tx=tx,
                proxy=tx.output.create_proxy(
                    "call_function",
                    torch.nn.functional.cross_entropy,
                    *proxy_args_kwargs(
                        [
                            input,
                            target,
                            weight,
                            size_average,
                            ignore_index,
                            reduce_arg,
                            reduction,
                            label_smoothing,
                        ],
                        {},
                    ),
                ),
            )

        return variables.LambdaVariable(fake_cross_entropy_loss)
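
# A rough sketch of how the CrossEntropyLoss special case above plays out
# (illustrative, not executed here):
#
#     loss_fn = torch.nn.CrossEntropyLoss(reduction="sum")
#     loss = loss_fn(logits, target)
#
# Under dynamo, the constructor call returns a LambdaVariable wrapping
# fake_cross_entropy_loss, and the later loss_fn(logits, target) call is
# traced as
# torch.nn.functional.cross_entropy(
#     logits, target, None, None, -100, None, "sum", 0.0
# ).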