diff --git a/test/jit/test_custom_operators.py b/test/jit/test_custom_operators.py index feb3b8eb8fb..cdb973590cb 100644 --- a/test/jit/test_custom_operators.py +++ b/test/jit/test_custom_operators.py @@ -50,6 +50,10 @@ class TestCustomOperators(JitTestCase): output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0])) self.assertEqual(output, torch.tensor([-0.01, 1])) + def test_only_kwargs(self): + output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0)) + self.assertEqual(output, torch.tensor(-0.01)) + def test_passing_too_many_args(self): with self.assertRaisesRegexWithHighlight( RuntimeError, @@ -74,6 +78,14 @@ class TestCustomOperators(JitTestCase): ): torch.ops.aten.type_as(torch.ones(5, 5)) + def test_passing_an_argument_both_as_positional_and_kwarg(self): + with self.assertRaisesRegexWithHighlight( + RuntimeError, + "Argument 'self' specified both as positional and keyword argument", + "" + ): + torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5)) + def test_passing_unknown_kwargs(self): with self.assertRaisesRegexWithHighlight( RuntimeError, diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index ca69f0fb030..8f2c5c78c7c 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -724,32 +724,31 @@ class TestOperators(TestCase): x = torch.randn(2, 3, 4, requires_grad=True) self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11) -# Github Issue: https://github.com/pytorch/pytorch/issues/71095 -# def test_c2_op(self): -# class MyModel(torch.nn.Module): -# def __init__(self): -# super(MyModel, self).__init__() -# -# def forward(self, scores, bbox_deltas, im_info, anchors): -# a, b = torch.ops._caffe2.GenerateProposals( -# (scores), (bbox_deltas), (im_info), (anchors), -# 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, -# ) -# return a, b -# -# model = MyModel() -# A = 4 -# H = 10 -# W = 8 -# img_count = 3 -# scores = torch.ones(img_count, A, H, W, dtype=torch.float32) -# bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, -# dtype=torch.float32) -# bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) -# im_info = torch.ones(img_count, 3, dtype=torch.float32) -# anchors = torch.ones(A, 4, dtype=torch.float32) -# inputs = (scores, bbox_deltas, im_info, anchors) -# self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0}) + def test_c2_op(self): + class MyModel(torch.nn.Module): + def __init__(self): + super(MyModel, self).__init__() + + def forward(self, scores, bbox_deltas, im_info, anchors): + a, b = torch.ops._caffe2.GenerateProposals( + (scores), (bbox_deltas), (im_info), (anchors), + 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True, + ) + return a, b + + model = MyModel() + A = 4 + H = 10 + W = 8 + img_count = 3 + scores = torch.ones(img_count, A, H, W, dtype=torch.float32) + bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W, + dtype=torch.float32) + bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W) + im_info = torch.ones(img_count, 3, dtype=torch.float32) + anchors = torch.ones(A, 4, dtype=torch.float32) + inputs = (scores, bbox_deltas, im_info, anchors) + self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0}) def test_dict(self): class MyModel(torch.nn.Module): diff --git a/test/test_per_overload_api.py b/test/test_per_overload_api.py deleted file mode 100644 index e237cc915aa..00000000000 --- a/test/test_per_overload_api.py +++ /dev/null @@ -1,63 +0,0 @@ -import torch -from torch.testing._internal.common_utils import TestCase, run_tests -import copy - -class TestPerOverloadAPI(TestCase): - def test_basics_opoverloadpacket(self): - # add is ony used as an example here. It is ok to update the test - # if the semantics of add are modified in the future. - add_packet = torch.ops.aten.add - - # class attributes - self.assertEqual(add_packet.op_name, 'add') - self.assertEqual(add_packet.qualified_op_name, 'aten.add') - - # callable - self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5)) - - # correct module - self.assertEqual(add_packet.__module__, add_packet.op.__module__) - - # caching - another_add_packet = torch.ops.aten.add - self.assertEqual(id(add_packet), id(another_add_packet)) - - # deepcopy is a no-op - self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet))) - - # pretty print - self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')") - - self.assertRaises(AttributeError, lambda: add_packet.foo) - - def test_basics_opoverload(self): - add_packet = torch.ops.aten.add - add_tensoroverload = add_packet.Tensor - - # class attributes - self.assertEqual(add_tensoroverload.name, 'aten.add') - self.assertEqual(add_tensoroverload.overload_name, 'Tensor') - self.assertEqual(add_tensoroverload.overload_packet, add_packet) - - # deepcopy is a no-op - self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload))) - - # caching - another_add_tensoroverload = torch.ops.aten.add.Tensor - self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload)) - - # pretty print - self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')") - - # callable - self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5)) - - a = torch.tensor(2) - b = torch.tensor(0) - torch.ops.aten.add.out(a, a, out=b) - self.assertEqual(b, torch.tensor(4)) - - self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b)) - -if __name__ == '__main__': - run_tests() diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in index efc9323f62c..ccf487bd68d 100644 --- a/torch/_C/__init__.pyi.in +++ b/torch/_C/__init__.pyi.in @@ -197,8 +197,6 @@ def _jit_init() -> _bool: ... def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ... def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ... def _jit_get_operation(op_name: str) -> Callable: ... -def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ... -def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ... def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule', optimization_blocklist: Set[MobileOptimizerType], preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ... diff --git a/torch/_ops.py b/torch/_ops.py index a7c44b33f3f..555984d8554 100644 --- a/torch/_ops.py +++ b/torch/_ops.py @@ -25,109 +25,6 @@ def dl_open_guard(): if _SET_GLOBAL_FLAGS: sys.setdlopenflags(old_flags) -# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object. -# You can obtain an OpOverload object through attribute query on OpOverloadPacket. -class OpOverload: - def __init__(self, overloadpacket, op, schema): - self._op = op - self._schema = schema - self._overloadpacket = overloadpacket - - # it's a no-op since OpOverload object is immutable and must be unique for a given op overload. - def __deepcopy__(self, memo=None): - return self - - def __str__(self): - return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name) - - def __call__(self, *args, **kwargs): - return self._op(*args, **kwargs or {}) - - def __getattr__(self, key): - return getattr(self._op, key) - - # `my_namespace::my_op` - @property - def name(self): - return "{}.{}".format(*self._schema.name.split("::")) - - @property - def overload_name(self): - return self._schema.overload_name - - @property - def overload_packet(self): - return self._overloadpacket - - @property - def op(self): - return self._op - - # TODO: add more methods to expose information about input and output arguments - -# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator -# You can obtain an OpOverload object through attribute query. -class OpOverloadPacket: - def __init__(self, qualified_op_name, op_name, op): - # These attributes are accessible on the object through the properties - # defined below but are immutable - self._qualified_op_name = qualified_op_name - self._op_name = op_name - self._op = op - - # it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op. - def __deepcopy__(self, memo=None): - return self - - def __str__(self): - return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::")) - - @property - def qualified_op_name(self): - return "{}.{}".format(*self._qualified_op_name.split("::")) - - @property - def op_name(self): - return self._op_name - - @property - def op(self): - return self._op - - def __getattr__(self, key): - # It is not a valid op_name when __file__ is passed in - if key == '__file__': - return 'torch.ops' - - try: - use_key = '' if key == 'default' else key - # TODO: disallow access to overloads registered by JIT - op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key) - schema = torch._C._get_schema(self._qualified_op_name, use_key) - overload = OpOverload(self, op_, schema) - # cache the overload object - setattr(self, key, overload) - return overload - except RuntimeError: - try: - # This is added to maintain bc in case the user queries an attribute that exists on `self._op` - # which used to be returned before instead of the OpOverloadPacket - out = getattr(self._op, key) - return out - except AttributeError: - raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None - - def __call__(self, *args, **kwargs): - # overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT - # We save the function ptr as the `op` attribute on OpOverloadPacket to access it here. - return self._op(*args, **kwargs or {}) - -# Resolution of torch.fn is different from torch.ops.aten.fn -# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method -# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and -# then tries to match the correct one at runtime and always calls into the boxed version of the method -# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings -# Aten codegen generates tensor methods for the the tensor class # _OpNamespace is a subclass of ModuleType because the torch script # allows attribute lookups on modules only. Since we want torch.ops.foo.bar() @@ -162,20 +59,14 @@ class _OpNamespace(types.ModuleType): return 'torch.ops' # Get the op `my_namespace::my_op` if available. This will also check # for overloads and raise an exception if there are more than one. - namespace_name = self.name - qualified_op_name = '{}::{}'.format(namespace_name, op_name) + qualified_op_name = '{}::{}'.format(self.name, op_name) op = torch._C._jit_get_operation(qualified_op_name) - # let the script frontend know that op is identical to the builtin op # with qualified_op_name torch.jit._builtins._register_builtin(op, qualified_op_name) - op.__module__ = self.__module__ + "." + namespace_name - opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op) - opoverloadpacket.__module__ = self.__module__ + "." + namespace_name - # cache the opoverloadpacket to ensure that each op corresponds to - # a unique OpOverloadPacket object - setattr(self, op_name, opoverloadpacket) - return opoverloadpacket + setattr(self, op_name, op) + op.__module__ = self.__module__ + "." + self.name + return op class _Ops(types.ModuleType): __file__ = '_ops.py' diff --git a/torch/ao/quantization/quantization_mappings.py b/torch/ao/quantization/quantization_mappings.py index 27d8a20e8a8..8bffe0948bd 100644 --- a/torch/ao/quantization/quantization_mappings.py +++ b/torch/ao/quantization/quantization_mappings.py @@ -133,11 +133,11 @@ _INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = { # Default mapping from floating point function or torch ops to quantized ops # TODO: merge with default static mapping DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = { - F.elu: torch.ops.quantized.elu, - F.hardswish: torch.ops.quantized.hardswish, - F.instance_norm: torch.ops.quantized.instance_norm, - F.layer_norm: torch.ops.quantized.layer_norm, - F.leaky_relu: torch.ops.quantized.leaky_relu, + F.elu: torch._ops.ops.quantized.elu, + F.hardswish: torch._ops.ops.quantized.hardswish, + F.instance_norm: torch._ops.ops.quantized.instance_norm, + F.layer_norm: torch._ops.ops.quantized.layer_norm, + F.leaky_relu: torch._ops.ops.quantized.leaky_relu, } # mapping from module to output activation post process class diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp index 9386dab20f8..925ab21c390 100644 --- a/torch/csrc/jit/python/init.cpp +++ b/torch/csrc/jit/python/init.cpp @@ -1293,50 +1293,6 @@ void initJITBindings(PyObject* module) { }) .def("has_storage", &DeserializationStorageContext::hasStorage); - m.def( - "_get_schema", - [](const std::string& op_name, const std::string& overload_name) { - try { - auto symbol = Symbol::fromQualString(op_name); - auto operations = getAllOperatorsFor(symbol); - for (const auto& op : operations) { - if (op->schema().overload_name() == overload_name) { - return op->schema(); - } - } - throw std::runtime_error("Found no matching schema"); - } catch (const c10::Error& e) { - auto msg = torch::get_cpp_stacktraces_enabled() - ? e.what() - : e.what_without_backtrace(); - throw std::runtime_error(msg); - } - }); - - m.def( - "_get_operation_overload", - [](const std::string& op_name, const std::string& overload_name) { - try { - auto symbol = Symbol::fromQualString(op_name); - auto operations = getAllOperatorsFor(symbol); - for (const auto& op : operations) { - if (op->schema().overload_name() == overload_name) { - auto func = - py::cpp_function([op](py::args args, py::kwargs kwargs) { - return invokeOperatorFromPython({op}, args, kwargs); - }); - return func; - } - } - throw std::runtime_error("Found no matching operator overload"); - } catch (const c10::Error& e) { - auto msg = torch::get_cpp_stacktraces_enabled() - ? e.what() - : e.what_without_backtrace(); - throw std::runtime_error(msg); - } - }); - m.def( "_jit_get_operation", [](const std::string& op_name) { @@ -1412,11 +1368,8 @@ void initJITBindings(PyObject* module) { py::name(symbol.toUnqualString()), py::doc(docstring.str().c_str())); return func; - } catch (const c10::Error& e) { - auto msg = torch::get_cpp_stacktraces_enabled() - ? e.what() - : e.what_without_backtrace(); - throw std::runtime_error(msg); + } catch (const c10::Error& error) { + throw std::runtime_error(error.what_without_backtrace()); } }, py::arg("qualified_name")); diff --git a/torch/csrc/jit/python/python_sugared_value.cpp b/torch/csrc/jit/python/python_sugared_value.cpp index 87ab27a5552..1a4ac0370ac 100644 --- a/torch/csrc/jit/python/python_sugared_value.cpp +++ b/torch/csrc/jit/python/python_sugared_value.cpp @@ -1104,13 +1104,6 @@ std::shared_ptr toSugaredValue( } } - auto opoverloadpacket_type = - py::module::import("torch").attr("_ops").attr("OpOverloadPacket"); - py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type); - if (is_overloadpacket) { - obj = py::getattr(obj, "op"); - } - bool isRpcAvailable = py::cast( py::module::import("torch.distributed.rpc").attr("is_available")()); diff --git a/torch/fx/operator_schemas.py b/torch/fx/operator_schemas.py index 1e3e02ed7cf..d7ddc3e0360 100644 --- a/torch/fx/operator_schemas.py +++ b/torch/fx/operator_schemas.py @@ -7,7 +7,6 @@ import warnings from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING from torch._jit_internal import boolean_dispatched from ._compatibility import compatibility -from torch._ops import OpOverloadPacket if TYPE_CHECKING: from .node import Argument @@ -138,9 +137,6 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False): if override: return (override, None) if return_schemas else None - if isinstance(op, OpOverloadPacket): - op = op._op - aten_fn = torch.jit._builtins._find_builtin(op) if aten_fn is None: diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index a7efa0832ee..789c10d9dff 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -265,9 +265,6 @@ def infer_concrete_type_builder(nn_module, share_types=True): # Don't re-add anything we already added continue - isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket) - if isoverloadpacket: - value = value.op # Handle Python function attributes if inspect.isfunction(value): try: diff --git a/torch/jit/_script.py b/torch/jit/_script.py index bc79c416902..9ad8934d55b 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -1270,6 +1270,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None, return create_script_dict(obj) if isinstance(obj, list): return create_script_list(obj) + if inspect.isclass(obj): qualified_name = _qualified_name(obj) # If this type is a `nn.Module` subclass, they probably meant to pass diff --git a/torch/jit/annotations.py b/torch/jit/annotations.py index 45d708c6e68..6e6317a0cfe 100644 --- a/torch/jit/annotations.py +++ b/torch/jit/annotations.py @@ -23,7 +23,6 @@ if torch.distributed.rpc.is_available(): from .._jit_internal import RRef, is_rref from torch._C import RRefType -from torch._ops import OpOverloadPacket class Module(object): def __init__(self, name, members): @@ -63,10 +62,7 @@ class EvalEnv(object): return getattr(builtins, name, None) def get_signature(fn, rcb, loc, is_method): - if isinstance(fn, OpOverloadPacket): - signature = try_real_annotations(fn.op, loc) - else: - signature = try_real_annotations(fn, loc) + signature = try_real_annotations(fn, loc) if signature is not None and is_method: # If this is a method, then the signature will include a type for # `self`, but type comments do not contain a `self`. So strip it @@ -110,9 +106,6 @@ def is_vararg(the_callable): def get_param_names(fn, n_args): - if isinstance(fn, OpOverloadPacket): - fn = fn.op - if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004 # De-sugar calls to classes fn = fn.__call__