Add torch.ops per overload API (#72206)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72206

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D33955635

Pulled By: anjali411

fbshipit-source-id: 4fbd0c0c4d032bcbd9d9f362b4b6f84eec9ad047
(cherry picked from commit 85076245c9)
Authored by anjali411 on 2022-02-11 09:14:13 -08:00, committed by PyTorch MergeBot
parent 2b0f334443
commit 50770b9e19
5 changed files with 115 additions and 77 deletions


@@ -670,6 +670,17 @@ TEST(OperatorRegistrationTest, whenRegisterWithLazyKernelAndCatchAll_AutogradLaz
   whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::Lazy);
 }

+TEST(OperatorRegistrationTest, whenRegisteringWithInvalidOverloadName_thenFails) {
+  expectThrows<c10::Error>([] {
+    auto registrar = c10::RegisterOperators().op("_test::dummy.default", c10::RegisterOperators::options()
+      .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
+  }, "default is not a legal overload name for aten operators");
+  expectThrows<c10::Error>([] {
+    auto registrar = c10::RegisterOperators().op("_test::dummy.__name__", c10::RegisterOperators::options()
+      .kernel(DispatchKey::CPU, [] (const int64_t&) {}));
+  }, "__name__ is not a legal overload name for aten operators");
+}
+
 TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingCppSignatures_thenFails) {
   expectThrows<c10::Error>([] {
     auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
@@ -1243,6 +1254,16 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
     "(Dict(str, Dict(int, str)?[])[] a) -> Dict(str, Dict(int, str)?[])[]");
 }

+TEST(NewOperatorRegistrationTest, errorOutWithInvalidOverloadName) {
+  auto m = MAKE_TORCH_LIBRARY(_test);
+  expectThrows<c10::Error>([&] {
+    m.def("dummy.default(Tensor self) -> Tensor");
+  }, "default is not a legal overload name for aten operators");
+  expectThrows<c10::Error>([&] {
+    m.def("dummy.__name__(Tensor self) -> Tensor");
+  }, "__name__ is not a legal overload name for aten operators");
+}
+
 TEST(NewOperatorRegistrationTest, testBasics) {
   auto m = MAKE_TORCH_LIBRARY(_test);
   m.def("dummy(Tensor self) -> Tensor");


@@ -50,10 +50,6 @@ class TestCustomOperators(JitTestCase):
         output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
         self.assertEqual(output, torch.tensor([-0.01, 1]))

-    def test_only_kwargs(self):
-        output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
-        self.assertEqual(output, torch.tensor(-0.01))
-
     def test_passing_too_many_args(self):
         with self.assertRaisesRegexWithHighlight(
             RuntimeError,
@@ -78,14 +74,6 @@ class TestCustomOperators(JitTestCase):
         ):
             torch.ops.aten.type_as(torch.ones(5, 5))

-    def test_passing_an_argument_both_as_positional_and_kwarg(self):
-        with self.assertRaisesRegexWithHighlight(
-            RuntimeError,
-            "Argument 'self' specified both as positional and keyword argument",
-            ""
-        ):
-            torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
-
     def test_passing_unknown_kwargs(self):
         with self.assertRaisesRegexWithHighlight(
             RuntimeError,
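
The two deleted tests pinned down keyword handling of the schema's self argument under the old binding and are dropped along with it. Keyword passing of other schema-named arguments still resolves through the new call path; a quick sanity check, as a hedged sketch rather than part of the PR:

    import torch

    x, y = torch.tensor(2), torch.tensor(3)
    # keyword arguments are still matched against the operator schema by the JIT path
    print(torch.ops.aten.add(x, other=y))  # tensor(5)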


@@ -1,67 +1,64 @@
 # Owner(s): ["module: unknown"]

-# import torch
-# import copy
+import torch
+import copy

 from torch.testing._internal.common_utils import TestCase, run_tests

 class TestPerOverloadAPI(TestCase):
-    # def test_basics_opoverloadpacket(self):
-    #     # add is ony used as an example here. It is ok to update the test
-    #     # if the semantics of add are modified in the future.
-    #     add_packet = torch.ops.aten.add
+    def test_basics_opoverloadpacket(self):
+        # add is only used as an example here. It is ok to update the test
+        # if the semantics of add are modified in the future.
+        add_packet = torch.ops.aten.add

-    #     # class attributes
-    #     self.assertEqual(add_packet.op_name, 'add')
-    #     self.assertEqual(add_packet.qualified_op_name, 'aten.add')
+        # class attributes
+        self.assertEqual(add_packet.op_name, 'add')
+        self.assertEqual(add_packet.qualified_op_name, 'aten.add')

-    #     # callable
-    #     self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
+        # callable
+        self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))

-    #     # correct module
-    #     self.assertEqual(add_packet.__module__, add_packet.op.__module__)
+        # correct module
+        self.assertEqual(add_packet.__module__, add_packet.op.__module__)

-    #     # caching
-    #     another_add_packet = torch.ops.aten.add
-    #     self.assertEqual(id(add_packet), id(another_add_packet))
+        # caching
+        another_add_packet = torch.ops.aten.add
+        self.assertEqual(id(add_packet), id(another_add_packet))

-    #     # deepcopy is a no-op
-    #     self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
+        # deepcopy is a no-op
+        self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))

-    #     # pretty print
-    #     self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
+        # pretty print
+        self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")

-    #     self.assertRaises(AttributeError, lambda: add_packet.foo)
+        self.assertRaises(AttributeError, lambda: add_packet.foo)

-    # def test_basics_opoverload(self):
-    #     add_packet = torch.ops.aten.add
-    #     add_tensoroverload = add_packet.Tensor
+    def test_basics_opoverload(self):
+        add_packet = torch.ops.aten.add
+        add_tensoroverload = add_packet.Tensor

-    #     # class attributes
-    #     self.assertEqual(add_tensoroverload.name, 'aten.add')
-    #     self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
-    #     self.assertEqual(add_tensoroverload.overload_packet, add_packet)
+        # class attributes
+        self.assertEqual(add_tensoroverload.name, 'aten.add')
+        self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
+        self.assertEqual(add_tensoroverload.overload_packet, add_packet)

-    #     # deepcopy is a no-op
-    #     self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
+        # deepcopy is a no-op
+        self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))

-    #     # caching
-    #     another_add_tensoroverload = torch.ops.aten.add.Tensor
-    #     self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
+        # caching
+        another_add_tensoroverload = torch.ops.aten.add.Tensor
+        self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))

-    #     # pretty print
-    #     self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
+        # pretty print
+        self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")

-    #     # callable
-    #     self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
+        # callable
+        self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))

-    #     a = torch.tensor(2)
-    #     b = torch.tensor(0)
-    #     torch.ops.aten.add.out(a, a, out=b)
-    #     self.assertEqual(b, torch.tensor(4))
+        a = torch.tensor(2)
+        b = torch.tensor(0)
+        torch.ops.aten.add.out(a, a, out=b)
+        self.assertEqual(b, torch.tensor(4))

-    #     self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
-    def do_nothing(self):
-        return
+        self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))

 if __name__ == '__main__':
     run_tests()
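
The tests above never touch the reserved default attribute. Per the __getattr__ change in _ops.py below, it should resolve to the overload registered with an empty overload name. A hedged sketch using aten::relu, whose only registered overload name is the empty string:

    import torch

    relu_default = torch.ops.aten.relu.default  # "default" maps to overload name ""
    print(relu_default(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])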


@@ -99,39 +99,63 @@ class OpOverloadPacket:
         if key == '__file__':
             return 'torch.ops'

+        # ensure that a query for dunder attributes that does not exist on
+        # opoverloadpacket but instead exists on the self._op object does not unnecessarily call
+        # `_get_operation_overload` (which is an expensive operation).
+        # This is done to prevent any potential slowdown. This list can be extended
+        # if there exist other attributes like `__name__` that only exist on self._op and not on the
+        # opoverloadpacket.
+        # This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
+        try:
+            if key.startswith('__'):
+                return getattr(self._op, key)
+        except AttributeError:
+            # for consistency because it seems weird to
+            # throw an attribute error with a message containing
+            # an object name different from the one the attribute
+            # query was performed on.
+            raise AttributeError("'{}' can't have an overload name beginning with '__' and the "
+                                 "underlying op {} has no attribute {} either."
+                                 .format(str(self), str(self._op), key)) from None
+
         try:
+            # This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
             use_key = '' if key == 'default' else key
             # TODO: disallow access to overloads registered by JIT
-            op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key)
+            op_ = torch._C._get_operation_overload(
+                self._qualified_op_name, use_key)
             schema = torch._C._get_schema(self._qualified_op_name, use_key)
             overload = OpOverload(self, op_, schema)
             # cache the overload object
             setattr(self, key, overload)
             return overload
         except RuntimeError:
-            try:
-                # This is added to maintain bc in case the user queries an attribute that exists on `self._op`
-                # which used to be returned before instead of the OpOverloadPacket
-                out = getattr(self._op, key)
-                return out
-            except AttributeError:
-                raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None
+            raise AttributeError(
+                "The underlying op of '{}' has no overload name '{}'".format(str(self), key)
+            ) from None

     def __call__(self, *args, **kwargs):
-        # overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT
-        # We save the function ptr as the `op` attribute on OpOverloadPacket to access it here.
+        # overloading __call__ to ensure torch.ops.foo.bar()
+        # is still callable from JIT
+        # We save the function ptr as the `op` attribute on
+        # OpOverloadPacket to access it here.
         return self._op(*args, **kwargs or {})

 # Resolution of torch.fn is different from torch.ops.aten.fn
-# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method
-# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and
-# then tries to match the correct one at runtime and always calls into the boxed version of the method
-# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings
-# Aten codegen generates tensor methods for the the tensor class
+# torch.fn uses the Python argparser, matches with the
+# appropriate schema, and calls into the unboxed version of the method
+# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
+# JIT creates a stack of all the overloads and then tries to match the
+# correct one at runtime and always calls into the boxed version of the method
+# Autograd codegen creates VariableType, TracerType,
+# inplace or view type and python bindings.
+# Aten codegen generates tensor methods for the tensor class.

 # _OpNamespace is a subclass of ModuleType because the torch script
 # allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
 # to work from script, we need to ensure ops and foo are modules
 class _OpNamespace(types.ModuleType):
     """
     An op namespace to dynamically bind Operators into Python.
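
The caching in __getattr__ above is ordinary attribute shadowing: __getattr__ only runs after normal lookup fails, so the setattr guarantees that every later access returns the same cached object. A self-contained model of the pattern (not PyTorch code; all names here are invented for illustration):

    class PacketLike:
        """Stand-in showing the __getattr__-plus-setattr caching pattern."""
        def __init__(self, qualified_name):
            self._qualified_name = qualified_name

        def __getattr__(self, key):
            # Only runs when normal attribute lookup fails, i.e. on first access.
            if key.startswith('__'):
                raise AttributeError(key)  # overload names never start with '__'
            use_key = '' if key == 'default' else key
            overload = "{}.{}".format(self._qualified_name, use_key)
            setattr(self, key, overload)  # cache: later lookups skip __getattr__
            return overload

    p = PacketLike("aten.add")
    assert p.Tensor is p.Tensor  # second access hits the cached attribute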
@@ -170,13 +194,13 @@ class _OpNamespace(types.ModuleType):
         # with qualified_op_name
         torch.jit._builtins._register_builtin(op, qualified_op_name)
         op.__module__ = self.__module__ + "." + namespace_name
-        # opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
-        # opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
+        opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
+        opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
         # cache the opoverloadpacket to ensure that each op corresponds to
         # a unique OpOverloadPacket object
-        # setattr(self, op_name, opoverloadpacket)
-        setattr(self, op_name, op)
-        return op
+        setattr(self, op_name, opoverloadpacket)
+        return opoverloadpacket

 class _Ops(types.ModuleType):
     __file__ = '_ops.py'
@@ -220,5 +244,6 @@ class _Ops(types.ModuleType):
             ctypes.CDLL(path)
             self.loaded_libraries.add(path)

 # The ops "namespace"
 ops = _Ops()
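
To make the resolution comment above concrete: both entry points below add the same tensors, but torch.add goes through the Python argument parser into an unboxed call, while torch.ops.aten.add resolves overloads through the JIT mechanism and makes a boxed call. A usage sketch:

    import torch

    x, y = torch.tensor(2), torch.tensor(3)
    print(torch.add(x, y))           # Python argparser -> unboxed kernel call
    print(torch.ops.aten.add(x, y))  # JIT overload resolution -> boxed call

    # _OpNamespace caches one OpOverloadPacket per op, so lookups are stable:
    assert torch.ops.aten.add is torch.ops.aten.add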


@@ -111,6 +111,13 @@ struct SchemaParser {
     if (L.nextIf('.')) {
       overload_name = L.expect(TK_IDENT).text();
     }
+    // "default" is used as an attribute on the `OpOverloadPacket`
+    // (obtained using `torch.ops.aten.foo`) to get the operator
+    // overload with the overload name equal to the empty string,
+    // and so shouldn't be used as an overload name.
+    // Also disallow dunder attribute names as overload names.
+    bool is_a_valid_overload_name = !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
+    TORCH_CHECK(is_a_valid_overload_name, overload_name, " is not a legal overload name for aten operators");
     return {name, overload_name};
   }
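
The overload name being validated here is the dotted suffix of a schema string. An illustration via torch._C.parse_schema, an internal binding whose exact surface may change:

    import torch

    schema = torch._C.parse_schema(
        "aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor")
    print(schema.name)           # aten::add
    print(schema.overload_name)  # Tensor

    try:
        # reserved names are now rejected at parse time
        torch._C.parse_schema("_test::dummy.default(Tensor self) -> Tensor")
    except RuntimeError as err:
        print(err)  # default is not a legal overload name for aten operators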