[Easy] Fix argument name collision in dispatched functions (#129562)
Use a positional-only argument to avoid a naming collision with aten op arguments that are named "self".
```python
In [1]: def foo(self, *args, **kwargs):
...: print(self, args, kwargs)
...:
In [2]: def bar(self, /, *args, **kwargs):
...: print(self, args, kwargs)
...:
In [3]: foo(1, 2, self=3)
TypeError: foo() got multiple values for argument 'self'
In [4]: bar(1, 2, self=3)
1
(2,)
{'self': 3}
```
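The same collision is what previously blocked calling an aten overload entirely by keyword when its schema names the first argument `self` (e.g. `aten::abs(Tensor self) -> Tensor`). A minimal sketch of the intended effect, assuming a build with this change applied; `aten.abs` is just an arbitrary example of such an op:
```python
import torch

x = torch.tensor([-1.0, 2.0])

# The schema's first argument is literally named "self". With `self` made
# positional-only in `OpOverload.__call__`, the keyword no longer collides with
# the bound wrapper argument and is simply forwarded along in **kwargs.
print(torch.ops.aten.abs.default(self=x))  # expected: tensor([1., 2.])
```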
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129562
Approved by: https://github.com/zou3519, https://github.com/fegin
This commit is contained in:
parent c0ed38e644
commit b29b23137c
torch/_meta_registrations.py

```diff
@@ -4219,18 +4219,18 @@ def meta_max_pool2d_with_indices(


 @register_meta(aten.fractional_max_pool2d.default)
-def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
+def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
     torch._check(
-        self_.ndim in (3, 4),
-        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}",
+        self.ndim in (3, 4),
+        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
     )
-    ndim = self_.ndim
+    ndim = self.ndim

     for d in range(ndim - 3, ndim):
         torch._check(
-            self_.size(d) > 0,
+            self.size(d) > 0,
             f"fractional_max_pool2d: Expected input to have non-zero "
-            f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty",
+            f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty",
         )

     # the check and message are out of sync, but this matches the structured meta
```
```diff
@@ -4245,16 +4245,16 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         "either be a single int or tuple of Ints",
     )

-    input_channels = self_.size(-3)
-    input_height = self_.size(-2)
-    input_width = self_.size(-1)
+    input_channels = self.size(-3)
+    input_height = self.size(-2)
+    input_width = self.size(-1)
     if ndim == 4:
-        input_batch = self_.size(0)
+        input_batch = self.size(0)
     else:
         input_batch = 1

     torch._check(
-        self_.dtype == random_samples.dtype,
+        self.dtype == random_samples.dtype,
         lambda: "Expect _random_samples to have the same dtype as input",
     )
     torch._check(
```
```diff
@@ -4284,7 +4284,7 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
     )

-    if self_.dim() == 4:
+    if self.dim() == 4:
         size = [input_batch, input_channels, output_size[0], output_size[1]]
     else:
         size = [input_channels, output_size[0], output_size[1]]
```
```diff
@@ -4292,20 +4292,20 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
     return (
         torch.empty(
             size,
-            dtype=self_.dtype,
-            device=self_.device,
+            dtype=self.dtype,
+            device=self.device,
         ),
         torch.empty(
             size,
             dtype=torch.int64,
-            device=self_.device,
+            device=self.device,
         ),
     )


 @register_meta(aten.max_unpool2d)
 @out_wrapper()
-def meta_max_unpool2d(self_, indices, output_size):
+def meta_max_unpool2d(self, indices, output_size):
     utils.alert_not_deterministic("max_unpooling2d_forward_out")

     torch._check(
```
```diff
@@ -4323,33 +4323,33 @@ def meta_max_unpool2d(self_, indices, output_size):
     oheight, owidth = output_size

     torch._check(
-        self_.ndim in (3, 4),
+        self.ndim in (3, 4),
         lambda: (
             f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
-            f"but got a tensor with {self_.ndim} dimensions."
+            f"but got a tensor with {self.ndim} dimensions."
         ),
     )
     torch._check(
-        self_.shape == indices.shape,
+        self.shape == indices.shape,
         lambda: (
-            f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
+            f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
             f"but got indices tensor with shape: {indices.shape}"
         ),
     )

-    for i in range(1, self_.ndim):
+    for i in range(1, self.ndim):
         torch._check(
-            self_.size(i) > 0,
+            self.size(i) > 0,
             lambda: (
                 f"max_unpooling2d(): "
                 f"Expected input to have non-zero size for non-batch dimensions, "
-                f"but got {self_.shape} with dimension {i} being empty."
+                f"but got {self.shape} with dimension {i} being empty."
             ),
         )

-    self = self_.contiguous()
+    self = self.contiguous()

-    if self_.ndim == 3:
+    if self.ndim == 3:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, oheight, owidth))
     else:
```
```diff
@@ -4409,18 +4409,18 @@ def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, f

 @register_meta(aten.max_unpool3d)
 @out_wrapper()
-def meta_max_unpool3d(self_, indices, output_size, stride, padding):
+def meta_max_unpool3d(self, indices, output_size, stride, padding):
     utils.alert_not_deterministic("max_unpooling3d_forward_out")

     _max_unpooling3d_shape_check(
-        self_, indices, output_size, stride, padding, "max_unpooling3d()"
+        self, indices, output_size, stride, padding, "max_unpooling3d()"
     )

-    self = self_.contiguous()
+    self = self.contiguous()

     odepth, oheight, owidth = output_size

-    if self_.ndim == 4:
+    if self.ndim == 4:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, odepth, oheight, owidth))
     else:
```
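As an aside, the meta kernels touched above only produce empty outputs with the right shape, dtype, and device, so the renamed code can be sanity-checked without real data. A small illustrative call, assuming `device="meta"` tensors route to a meta kernel with the shape logic shown above (the sizes here are arbitrary):
```python
import torch

# "Inputs" on the meta device carry only shape/dtype/device metadata, no storage.
x = torch.empty(1, 3, 16, 16, device="meta")
# fractional_max_pool2d expects _random_samples of shape (N, C, 2) with the input's dtype.
samples = torch.empty(1, 3, 2, device="meta")

out, indices = torch.ops.aten.fractional_max_pool2d(x, (2, 2), (8, 8), samples)
print(out.shape, out.dtype)          # torch.Size([1, 3, 8, 8]) torch.float32
print(indices.shape, indices.dtype)  # torch.Size([1, 3, 8, 8]) torch.int64
```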
torch/_ops.py

```diff
@@ -7,14 +7,14 @@ import sys
 import types
 from typing import Any, Callable, Dict, List, Set, Type, Union

-import torch._C
+import torch
 import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode

-# Query `hasattr` only once.

+# Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")

```
```diff
@@ -418,7 +418,6 @@ class HigherOrderOperator(OperatorBase):
     def __call__(self_, *args, **kwargs):  # noqa: B902
         # Dynamo already traces the body of HigherOrderOp beforehand when it
         # so no need to trace into it.
-        import torch._dynamo
         from torch._dynamo import disable

         @disable
```
```diff
@@ -721,15 +720,15 @@ class OpOverload(OperatorBase):
             *self._schema.name.split("::"), self._overloadname
         )

-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._op(*args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
+        return self._op(*args, **kwargs)

-    def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._handle.redispatch_boxed(keyset, *args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def redispatch(self, /, keyset, *args, **kwargs):
+        return self._handle.redispatch_boxed(keyset, *args, **kwargs)

     def __hash__(self):
         return hash(self._op)
```
```diff
@@ -945,9 +944,9 @@ class TorchBindOpOverload(OpOverload):
             if self in SIDE_EFFECTS:
                 del SIDE_EFFECTS[self]

-    # use `self_` to avoid naming collide with arguments that
-    # are named "self". This way, they can be called by kwargs.
-    def __call__(self_, *args, **kwargs):  # noqa: B902
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         if _must_dispatch_in_python(args, kwargs):
             # When any inputs are FakeScriptObject, we need to
             # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
```
```diff
@@ -959,11 +958,9 @@ class TorchBindOpOverload(OpOverload):
            #    of the eagerly executing the op might change after tracing.
            # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
            #    cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
-            with self_._register_as_effectful_op_temporarily():
-                return self_._dispatch_in_python(
-                    args, kwargs, self_._fallthrough_keys()
-                )
-            return self_._op(*args, **kwargs)
+            with self._register_as_effectful_op_temporarily():
+                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
+        return self._op(*args, **kwargs)

     def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
         non_fallthrough_keys = torch._C._dispatch_keyset_full()
```
```diff
@@ -1107,10 +1104,9 @@ class OpOverloadPacket:
     def __iter__(self):
         return iter(self._dir)

-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # named "self". This way, all the aten ops can be called by kwargs.
-
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         # overloading __call__ to ensure torch.ops.foo.bar()
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
```
```diff
@@ -1119,9 +1115,9 @@ class OpOverloadPacket:
         # Directly calling OverloadPacket goes into C++, which will check
         # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
         # intercept it here and call TorchBindOpverload instead.
-        if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
-            return _call_overload_packet_from_python(self_, args, kwargs)
-        return self_._op(*args, **(kwargs or {}))
+        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
+            return _call_overload_packet_from_python(self, args, kwargs)
+        return self._op(*args, **(kwargs or {}))

     # TODO: use this to make a __dir__
     def overloads(self):
```
torch/testing/_internal/distributed/distributed_test.py

```diff
@@ -7866,16 +7866,16 @@ class DistributedTest:
             }

             class ToyModel(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 10, bias=False)
+                    self_.lin = nn.Linear(10, 10, bias=False)

-                def forward(_self, x, expected_type):  # noqa: B902
+                def forward(self_, x, expected_type):  # noqa: B902
                     # Similar to scatter, the recursive to in the single-device
                     # case does not move tensors if they are in a custom type.
                     self.assertTrue(isinstance(x, expected_type))
                     fwd_tensor = validators[expected_type](x)
-                    return _self.lin(fwd_tensor)
+                    return self_.lin(fwd_tensor)

             model = torch.nn.parallel.DistributedDataParallel(
                 ToyModel().to(self.rank), device_ids=[self.rank]
```
```diff
@@ -7929,11 +7929,11 @@ class DistributedTest:
             b = torch.rand(batch, dim, device=self.rank)

             class NamedTupleModule(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 1)
+                    self_.lin = nn.Linear(10, 1)

-                def forward(_self, input, expected_type):  # noqa: B902
+                def forward(self_, input, expected_type):  # noqa: B902
                     # Without NamedTuple support, this would be of type tuple.
                     self.assertTrue(
                         isinstance(input, expected_type),
```
```diff
@@ -7942,7 +7942,7 @@ class DistributedTest:
                     self.assertEqual(input._fields, EXPECTED_FIELDS)
                     self.assertEqual(a, input.a)
                     self.assertEqual(b, input.b)
-                    return _self.lin(torch.mul(input.a, input.b))
+                    return self_.lin(torch.mul(input.a, input.b))

             model = torch.nn.parallel.DistributedDataParallel(
                 NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
```