[Easy] Fix argument name collision in dispatched functions (#129562)

Use a positional-only argument to avoid naming collisions with aten op arguments that are named "self". This way, all aten ops can be called with keyword arguments.

```python
In [1]: def foo(self, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [2]: def bar(self, /, *args, **kwargs):
   ...:     print(self, args, kwargs)
   ...:

In [3]: foo(1, 2, self=3)
TypeError: foo() got multiple values for argument 'self'

In [4]: bar(1, 2, self=3)
1
(2,)
{'self': 3}
```
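
With the dispatcher's `__call__` taking its instance parameter positional-only, every aten overload can be invoked purely by keyword, even though most aten schemas name their primary input "self". A quick illustration of the intended behaviour (not part of this commit's diff):

```python
import torch

x = torch.randn(3)

# aten::abs has the schema `abs(Tensor self) -> Tensor`. The keyword below is
# forwarded through **kwargs to the dispatcher instead of colliding with the
# __call__ method's own (positional-only) first parameter.
y = torch.ops.aten.abs.default(self=x)
assert torch.equal(y, x.abs())
```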

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129562
Approved by: https://github.com/zou3519, https://github.com/fegin
Author: Xuehai Pan, 2024-07-17 14:39:54 +00:00, committed by PyTorch MergeBot
Parent: c0ed38e644
Commit: b29b23137c
3 changed files with 59 additions and 63 deletions

torch/_meta_registrations.py

```diff
@@ -4219,18 +4219,18 @@ def meta_max_pool2d_with_indices(
 @register_meta(aten.fractional_max_pool2d.default)
-def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
+def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
     torch._check(
-        self_.ndim in (3, 4),
-        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}",
+        self.ndim in (3, 4),
+        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
     )
-    ndim = self_.ndim
+    ndim = self.ndim
     for d in range(ndim - 3, ndim):
         torch._check(
-            self_.size(d) > 0,
+            self.size(d) > 0,
             f"fractional_max_pool2d: Expected input to have non-zero "
-            f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty",
+            f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty",
         )

     # the check and message are out of sync, but this matches the structured meta
@@ -4245,16 +4245,16 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         "either be a single int or tuple of Ints",
     )
-    input_channels = self_.size(-3)
-    input_height = self_.size(-2)
-    input_width = self_.size(-1)
+    input_channels = self.size(-3)
+    input_height = self.size(-2)
+    input_width = self.size(-1)
     if ndim == 4:
-        input_batch = self_.size(0)
+        input_batch = self.size(0)
     else:
         input_batch = 1

     torch._check(
-        self_.dtype == random_samples.dtype,
+        self.dtype == random_samples.dtype,
         lambda: "Expect _random_samples to have the same dtype as input",
     )
     torch._check(
@@ -4284,7 +4284,7 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
     )
-    if self_.dim() == 4:
+    if self.dim() == 4:
         size = [input_batch, input_channels, output_size[0], output_size[1]]
     else:
         size = [input_channels, output_size[0], output_size[1]]
@@ -4292,20 +4292,20 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
     return (
         torch.empty(
             size,
-            dtype=self_.dtype,
-            device=self_.device,
+            dtype=self.dtype,
+            device=self.device,
         ),
         torch.empty(
             size,
             dtype=torch.int64,
-            device=self_.device,
+            device=self.device,
         ),
     )


 @register_meta(aten.max_unpool2d)
 @out_wrapper()
-def meta_max_unpool2d(self_, indices, output_size):
+def meta_max_unpool2d(self, indices, output_size):
     utils.alert_not_deterministic("max_unpooling2d_forward_out")

     torch._check(
@@ -4323,33 +4323,33 @@ def meta_max_unpool2d(self_, indices, output_size):
     oheight, owidth = output_size

     torch._check(
-        self_.ndim in (3, 4),
+        self.ndim in (3, 4),
         lambda: (
             f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
-            f"but got a tensor with {self_.ndim} dimensions."
+            f"but got a tensor with {self.ndim} dimensions."
         ),
     )
     torch._check(
-        self_.shape == indices.shape,
+        self.shape == indices.shape,
         lambda: (
-            f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
+            f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
             f"but got indices tensor with shape: {indices.shape}"
         ),
     )
-    for i in range(1, self_.ndim):
+    for i in range(1, self.ndim):
         torch._check(
-            self_.size(i) > 0,
+            self.size(i) > 0,
             lambda: (
                 f"max_unpooling2d(): "
                 f"Expected input to have non-zero size for non-batch dimensions, "
-                f"but got {self_.shape} with dimension {i} being empty."
+                f"but got {self.shape} with dimension {i} being empty."
             ),
         )

-    self = self_.contiguous()
+    self = self.contiguous()

-    if self_.ndim == 3:
+    if self.ndim == 3:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, oheight, owidth))
     else:
@@ -4409,18 +4409,18 @@ def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, f
 @register_meta(aten.max_unpool3d)
 @out_wrapper()
-def meta_max_unpool3d(self_, indices, output_size, stride, padding):
+def meta_max_unpool3d(self, indices, output_size, stride, padding):
     utils.alert_not_deterministic("max_unpooling3d_forward_out")
     _max_unpooling3d_shape_check(
-        self_, indices, output_size, stride, padding, "max_unpooling3d()"
+        self, indices, output_size, stride, padding, "max_unpooling3d()"
     )

-    self = self_.contiguous()
+    self = self.contiguous()

     odepth, oheight, owidth = output_size
-    if self_.ndim == 4:
+    if self.ndim == 4:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, odepth, oheight, owidth))
     else:
```
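
For context on why the meta kernels above can simply rename `self_` to `self`: they are free functions, not methods, so there is no implicit instance parameter to collide with, and naming the first parameter `self` lets it bind by keyword exactly as the aten schema (`Tensor self`) spells it. A minimal sketch with a hypothetical kernel (not from the diff):

```python
import torch

# Hypothetical meta-style kernel; the aten schema it mirrors would name the
# input tensor "self", e.g. `foo(Tensor self, int k) -> Tensor`.
def meta_foo_sketch(self, k):
    return self.new_empty((self.size(0), k))

t = torch.empty(4, device="meta")
print(meta_foo_sketch(t, 3).shape)         # torch.Size([4, 3])
print(meta_foo_sketch(self=t, k=3).shape)  # torch.Size([4, 3])
```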

torch/_ops.py

```diff
@@ -7,14 +7,14 @@ import sys
 import types
 from typing import Any, Callable, Dict, List, Set, Type, Union

-import torch._C
+import torch
 import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode

 # Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
@@ -418,7 +418,6 @@ class HigherOrderOperator(OperatorBase):
     def __call__(self_, *args, **kwargs):  # noqa: B902
         # Dynamo already traces the body of HigherOrderOp beforehand when it
         # so no need to trace into it.
         import torch._dynamo
         from torch._dynamo import disable

         @disable
@@ -721,15 +720,15 @@ class OpOverload(OperatorBase):
             *self._schema.name.split("::"), self._overloadname
         )

-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._op(*args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
+        return self._op(*args, **kwargs)

-    def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._handle.redispatch_boxed(keyset, *args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def redispatch(self, /, keyset, *args, **kwargs):
+        return self._handle.redispatch_boxed(keyset, *args, **kwargs)

     def __hash__(self):
         return hash(self._op)
@@ -945,9 +944,9 @@ class TorchBindOpOverload(OpOverload):
             if self in SIDE_EFFECTS:
                 del SIDE_EFFECTS[self]

-    # use `self_` to avoid naming collide with arguments that
-    # are named "self". This way, they can be called by kwargs.
-    def __call__(self_, *args, **kwargs):  # noqa: B902
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         if _must_dispatch_in_python(args, kwargs):
             # When any inputs are FakeScriptObject, we need to
             # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
@@ -959,11 +958,9 @@ class TorchBindOpOverload(OpOverload):
             # of the eagerly executing the op might change after tracing.
             # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
             # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
-            with self_._register_as_effectful_op_temporarily():
-                return self_._dispatch_in_python(
-                    args, kwargs, self_._fallthrough_keys()
-                )
-        return self_._op(*args, **kwargs)
+            with self._register_as_effectful_op_temporarily():
+                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
+        return self._op(*args, **kwargs)

     def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
         non_fallthrough_keys = torch._C._dispatch_keyset_full()
@@ -1107,10 +1104,9 @@ class OpOverloadPacket:
     def __iter__(self):
         return iter(self._dir)

-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # named "self". This way, all the aten ops can be called by kwargs.
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         # overloading __call__ to ensure torch.ops.foo.bar()
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
@@ -1119,9 +1115,9 @@ class OpOverloadPacket:
         # Directly calling OverloadPacket goes into C++, which will check
         # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
         # intercept it here and call TorchBindOpverload instead.
-        if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
-            return _call_overload_packet_from_python(self_, args, kwargs)
-        return self_._op(*args, **(kwargs or {}))
+        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
+            return _call_overload_packet_from_python(self, args, kwargs)
+        return self._op(*args, **(kwargs or {}))

     # TODO: use this to make a __dir__
     def overloads(self):
```
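
To make the `torch/_ops.py` change above concrete, here is a stripped-down comparison of the two conventions (a generic sketch, not the actual `OpOverload` code): renaming the instance parameter to `self_` frees up the `self` keyword for `**kwargs`, and a positional-only `self, /` achieves the same thing while keeping the conventional name and stating the intent directly in the signature.

```python
class OldStyle:
    # Workaround: rename the instance parameter so a caller-supplied
    # "self" keyword falls into **kwargs instead of clashing with it.
    def __call__(self_, *args, **kwargs):  # noqa: B902
        return args, kwargs


class NewStyle:
    # Positional-only instance parameter: it can never be passed by keyword,
    # so the "self" keyword stays free for the wrapped op's arguments.
    def __call__(self, /, *args, **kwargs):
        return args, kwargs


print(OldStyle()(1, self=2))  # ((1,), {'self': 2})
print(NewStyle()(1, self=2))  # ((1,), {'self': 2})
```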

torch/testing/_internal/distributed/distributed_test.py

```diff
@@ -7866,16 +7866,16 @@ class DistributedTest:
             }

             class ToyModel(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 10, bias=False)
+                    self_.lin = nn.Linear(10, 10, bias=False)

-                def forward(_self, x, expected_type):  # noqa: B902
+                def forward(self_, x, expected_type):  # noqa: B902
                     # Similar to scatter, the recursive to in the single-device
                     # case does not move tensors if they are in a custom type.
                     self.assertTrue(isinstance(x, expected_type))
                     fwd_tensor = validators[expected_type](x)
-                    return _self.lin(fwd_tensor)
+                    return self_.lin(fwd_tensor)

             model = torch.nn.parallel.DistributedDataParallel(
                 ToyModel().to(self.rank), device_ids=[self.rank]
@@ -7929,11 +7929,11 @@ class DistributedTest:
             b = torch.rand(batch, dim, device=self.rank)

             class NamedTupleModule(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 1)
+                    self_.lin = nn.Linear(10, 1)

-                def forward(_self, input, expected_type):  # noqa: B902
+                def forward(self_, input, expected_type):  # noqa: B902
                     # Without NamedTuple support, this would be of type tuple.
                     self.assertTrue(
                         isinstance(input, expected_type),
@@ -7942,7 +7942,7 @@ class DistributedTest:
                     self.assertEqual(input._fields, EXPECTED_FIELDS)
                     self.assertEqual(a, input.a)
                     self.assertEqual(b, input.b)
-                    return _self.lin(torch.mul(input.a, input.b))
+                    return self_.lin(torch.mul(input.a, input.b))

             model = torch.nn.parallel.DistributedDataParallel(
                 NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
```
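
A note for readers puzzled by the `self_` naming in the test diff above: the nested modules are defined inside a test method, so the bare `self` in `self.assertTrue(...)` is the enclosing TestCase captured from the closure, while the module's own instance parameter needs a different name (hence the `# noqa: B902`). A stripped-down sketch of the same pattern, with hypothetical names and no distributed setup:

```python
import unittest

import torch
from torch import nn


class ClosureNamingSketch(unittest.TestCase):
    def test_nested_module_uses_outer_self(self):
        class TinyModel(nn.Module):
            def __init__(self_):  # noqa: B902
                super().__init__()
                self_.lin = nn.Linear(4, 2)

            def forward(self_, x):  # noqa: B902
                # `self` here is the TestCase captured from the enclosing
                # method's closure, not the module instance.
                self.assertIsInstance(x, torch.Tensor)
                return self_.lin(x)

        out = TinyModel()(torch.randn(3, 4))
        self.assertEqual(out.shape, (3, 2))


if __name__ == "__main__":
    unittest.main()
```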