[Easy] Fix argument name collision in dispatched functions (#129562)
Use a positional-only argument to avoid a naming collision with aten op arguments that are named "self".
```python
In [1]: def foo(self, *args, **kwargs):
...: print(self, args, kwargs)
...:
In [2]: def bar(self, /, *args, **kwargs):
...: print(self, args, kwargs)
...:
In [3]: foo(1, 2, self=3)
TypeError: foo() got multiple values for argument 'self'
In [4]: bar(1, 2, self=3)
1
(2,)
{'self': 3}
```
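For illustration, a minimal call-site sketch of what this keeps working for real aten ops (a hypothetical usage example, not part of the commit; it assumes a PyTorch build with this change and uses `aten::abs`, whose schema names its input `self`):

```python
import torch

x = torch.randn(2, 3)

# The schema `aten::abs(Tensor self) -> Tensor` names its input "self".
# With OpOverload.__call__ declared as `def __call__(self, /, *args, **kwargs)`,
# the keyword below lands in **kwargs and is forwarded to the op instead of
# colliding with the method's bound instance.
y = torch.ops.aten.abs.default(self=x)
assert torch.equal(y, x.abs())
```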
Pull Request resolved: https://github.com/pytorch/pytorch/pull/129562
Approved by: https://github.com/zou3519, https://github.com/fegin
This commit is contained in:
parent c0ed38e644
commit b29b23137c
torch/_meta_registrations.py:

```diff
@@ -4219,18 +4219,18 @@ def meta_max_pool2d_with_indices(
 
 
 @register_meta(aten.fractional_max_pool2d.default)
-def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
+def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
     torch._check(
-        self_.ndim in (3, 4),
-        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self_.ndim}",
+        self.ndim in (3, 4),
+        lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
     )
-    ndim = self_.ndim
+    ndim = self.ndim
 
     for d in range(ndim - 3, ndim):
         torch._check(
-            self_.size(d) > 0,
+            self.size(d) > 0,
             f"fractional_max_pool2d: Expected input to have non-zero "
-            f" size for non-batch dimenions, but got {self_.size()} with dimension {d} empty",
+            f" size for non-batch dimenions, but got {self.size()} with dimension {d} empty",
         )
 
     # the check and message are out of sync, but this matches the structured meta
@@ -4245,16 +4245,16 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         "either be a single int or tuple of Ints",
     )
 
-    input_channels = self_.size(-3)
-    input_height = self_.size(-2)
-    input_width = self_.size(-1)
+    input_channels = self.size(-3)
+    input_height = self.size(-2)
+    input_width = self.size(-1)
     if ndim == 4:
-        input_batch = self_.size(0)
+        input_batch = self.size(0)
     else:
         input_batch = 1
 
     torch._check(
-        self_.dtype == random_samples.dtype,
+        self.dtype == random_samples.dtype,
         lambda: "Expect _random_samples to have the same dtype as input",
     )
     torch._check(
@@ -4284,7 +4284,7 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
         lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
     )
 
-    if self_.dim() == 4:
+    if self.dim() == 4:
         size = [input_batch, input_channels, output_size[0], output_size[1]]
     else:
         size = [input_channels, output_size[0], output_size[1]]
@@ -4292,20 +4292,20 @@ def meta_fractional_max_pool2d(self_, kernel_size, output_size, random_samples):
     return (
         torch.empty(
             size,
-            dtype=self_.dtype,
-            device=self_.device,
+            dtype=self.dtype,
+            device=self.device,
         ),
         torch.empty(
             size,
             dtype=torch.int64,
-            device=self_.device,
+            device=self.device,
         ),
     )
 
 
 @register_meta(aten.max_unpool2d)
 @out_wrapper()
-def meta_max_unpool2d(self_, indices, output_size):
+def meta_max_unpool2d(self, indices, output_size):
     utils.alert_not_deterministic("max_unpooling2d_forward_out")
 
     torch._check(
@@ -4323,33 +4323,33 @@ def meta_max_unpool2d(self_, indices, output_size):
     oheight, owidth = output_size
 
     torch._check(
-        self_.ndim in (3, 4),
+        self.ndim in (3, 4),
         lambda: (
             f"Input to max_unpooling2d should be a 3d or 4d Tensor, "
-            f"but got a tensor with {self_.ndim} dimensions."
+            f"but got a tensor with {self.ndim} dimensions."
         ),
     )
     torch._check(
-        self_.shape == indices.shape,
+        self.shape == indices.shape,
         lambda: (
-            f"Expected shape of indices to be same as that of the input tensor ({self_.shape}) "
+            f"Expected shape of indices to be same as that of the input tensor ({self.shape}) "
             f"but got indices tensor with shape: {indices.shape}"
         ),
     )
 
-    for i in range(1, self_.ndim):
+    for i in range(1, self.ndim):
         torch._check(
-            self_.size(i) > 0,
+            self.size(i) > 0,
             lambda: (
                 f"max_unpooling2d(): "
                 f"Expected input to have non-zero size for non-batch dimensions, "
-                f"but got {self_.shape} with dimension {i} being empty."
+                f"but got {self.shape} with dimension {i} being empty."
            ),
         )
 
-    self = self_.contiguous()
+    self = self.contiguous()
 
-    if self_.ndim == 3:
+    if self.ndim == 3:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, oheight, owidth))
     else:
@@ -4409,18 +4409,18 @@ def _max_unpooling3d_shape_check(input, indices, output_size, stride, padding, f
 
 @register_meta(aten.max_unpool3d)
 @out_wrapper()
-def meta_max_unpool3d(self_, indices, output_size, stride, padding):
+def meta_max_unpool3d(self, indices, output_size, stride, padding):
     utils.alert_not_deterministic("max_unpooling3d_forward_out")
 
     _max_unpooling3d_shape_check(
-        self_, indices, output_size, stride, padding, "max_unpooling3d()"
+        self, indices, output_size, stride, padding, "max_unpooling3d()"
     )
 
-    self = self_.contiguous()
+    self = self.contiguous()
 
     odepth, oheight, owidth = output_size
 
-    if self_.ndim == 4:
+    if self.ndim == 4:
         nchannels = self.size(0)
         result = self.new_empty((nchannels, odepth, oheight, owidth))
     else:
```
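In the meta registrations above, the renamed first parameter is just the input tensor that the aten schema calls `self`, not a class instance, so the `self_` to `self` rename is a local cleanup. A rough, hypothetical sketch of exercising one of these meta kernels on the `meta` device (shapes chosen arbitrarily, not taken from the commit):

```python
import torch

# Tensors on the "meta" device carry no data; only shapes and dtypes are computed,
# which is exactly what the meta kernels above implement.
x = torch.empty(1, 3, 16, 16, device="meta")
samples = torch.empty(1, 3, 2, device="meta")  # (N, C, 2), same dtype as the input

out, indices = torch.ops.aten.fractional_max_pool2d(x, (2, 2), (4, 4), samples)
print(out.shape, out.device)         # torch.Size([1, 3, 4, 4]) meta
print(indices.shape, indices.dtype)  # torch.Size([1, 3, 4, 4]) torch.int64
```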
torch/_ops.py:

```diff
@@ -7,14 +7,14 @@ import sys
 import types
 from typing import Any, Callable, Dict, List, Set, Type, Union
 
-import torch._C
+import torch
 import torch.utils._pytree as pytree
 from torch import _utils_internal
 from torch._functorch.pyfunctorch import dispatch_functorch
 from torch.utils._python_dispatch import TorchDispatchMode
 
-# Query `hasattr` only once.
 
+# Query `hasattr` only once.
 _SET_GLOBAL_FLAGS = hasattr(sys, "getdlopenflags") and hasattr(sys, "setdlopenflags")
 
 
@@ -418,7 +418,6 @@ class HigherOrderOperator(OperatorBase):
     def __call__(self_, *args, **kwargs):  # noqa: B902
         # Dynamo already traces the body of HigherOrderOp beforehand when it
         # so no need to trace into it.
-        import torch._dynamo
         from torch._dynamo import disable
 
         @disable
@@ -721,15 +720,15 @@ class OpOverload(OperatorBase):
             *self._schema.name.split("::"), self._overloadname
         )
 
-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._op(*args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
+        return self._op(*args, **kwargs)
 
-    def redispatch(self_, keyset, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # are named "self". This way, all the aten ops can be called by kwargs.
-        return self_._handle.redispatch_boxed(keyset, *args, **kwargs)
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def redispatch(self, /, keyset, *args, **kwargs):
+        return self._handle.redispatch_boxed(keyset, *args, **kwargs)
 
     def __hash__(self):
         return hash(self._op)
@@ -945,9 +944,9 @@ class TorchBindOpOverload(OpOverload):
             if self in SIDE_EFFECTS:
                 del SIDE_EFFECTS[self]
 
-    # use `self_` to avoid naming collide with arguments that
-    # are named "self". This way, they can be called by kwargs.
-    def __call__(self_, *args, **kwargs):  # noqa: B902
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         if _must_dispatch_in_python(args, kwargs):
             # When any inputs are FakeScriptObject, we need to
             # skip c++ dispatcher and dispatch in python through _get_dispatch of python_dispatcher
@@ -959,11 +958,9 @@ class TorchBindOpOverload(OpOverload):
             # of the eagerly executing the op might change after tracing.
             # 2. We don't want to register the op as effectful for all torchbind ops in ctor because this might
             # cause unexpected behavior for some autograd.profiler ops e.g. profiler._record_function_exit._RecordFunction.
-            with self_._register_as_effectful_op_temporarily():
-                return self_._dispatch_in_python(
-                    args, kwargs, self_._fallthrough_keys()
-                )
-        return self_._op(*args, **kwargs)
+            with self._register_as_effectful_op_temporarily():
+                return self._dispatch_in_python(args, kwargs, self._fallthrough_keys())
+        return self._op(*args, **kwargs)
 
     def _dispatch_in_python(self, args, kwargs, fallthrough_keys):
         non_fallthrough_keys = torch._C._dispatch_keyset_full()
@@ -1107,10 +1104,9 @@ class OpOverloadPacket:
     def __iter__(self):
         return iter(self._dir)
 
-    def __call__(self_, *args, **kwargs):  # noqa: B902
-        # use `self_` to avoid naming collide with aten ops arguments that
-        # named "self". This way, all the aten ops can be called by kwargs.
-
+    # Use positional-only argument to avoid naming collision with aten ops arguments
+    # that are named "self". This way, all the aten ops can be called by kwargs.
+    def __call__(self, /, *args, **kwargs):
         # overloading __call__ to ensure torch.ops.foo.bar()
         # is still callable from JIT
         # We save the function ptr as the `op` attribute on
@@ -1119,9 +1115,9 @@ class OpOverloadPacket:
         # Directly calling OverloadPacket goes into C++, which will check
         # the schema and cause an error for torchbind op when inputs consist of FakeScriptObject so we
         # intercept it here and call TorchBindOpverload instead.
-        if self_._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
-            return _call_overload_packet_from_python(self_, args, kwargs)
-        return self_._op(*args, **(kwargs or {}))
+        if self._has_torchbind_op_overload and _must_dispatch_in_python(args, kwargs):
+            return _call_overload_packet_from_python(self, args, kwargs)
+        return self._op(*args, **(kwargs or {}))
 
     # TODO: use this to make a __dir__
     def overloads(self):
```
Distributed tests (likely torch/testing/_internal/distributed/distributed_test.py):

```diff
@@ -7866,16 +7866,16 @@ class DistributedTest:
             }
 
             class ToyModel(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 10, bias=False)
+                    self_.lin = nn.Linear(10, 10, bias=False)
 
-                def forward(_self, x, expected_type):  # noqa: B902
+                def forward(self_, x, expected_type):  # noqa: B902
                     # Similar to scatter, the recursive to in the single-device
                     # case does not move tensors if they are in a custom type.
                     self.assertTrue(isinstance(x, expected_type))
                     fwd_tensor = validators[expected_type](x)
-                    return _self.lin(fwd_tensor)
+                    return self_.lin(fwd_tensor)
 
             model = torch.nn.parallel.DistributedDataParallel(
                 ToyModel().to(self.rank), device_ids=[self.rank]
@@ -7929,11 +7929,11 @@ class DistributedTest:
             b = torch.rand(batch, dim, device=self.rank)
 
             class NamedTupleModule(torch.nn.Module):
-                def __init__(_self):  # noqa: B902
+                def __init__(self_):  # noqa: B902
                     super().__init__()
-                    _self.lin = nn.Linear(10, 1)
+                    self_.lin = nn.Linear(10, 1)
 
-                def forward(_self, input, expected_type):  # noqa: B902
+                def forward(self_, input, expected_type):  # noqa: B902
                     # Without NamedTuple support, this would be of type tuple.
                     self.assertTrue(
                         isinstance(input, expected_type),
@@ -7942,7 +7942,7 @@ class DistributedTest:
                     self.assertEqual(input._fields, EXPECTED_FIELDS)
                     self.assertEqual(a, input.a)
                     self.assertEqual(b, input.b)
-                    return _self.lin(torch.mul(input.a, input.b))
+                    return self_.lin(torch.mul(input.a, input.b))
 
             model = torch.nn.parallel.DistributedDataParallel(
                 NamedTupleModule().cuda(self.rank), device_ids=[self.rank]
```