mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 00:20:18 +01:00
Add torch.ops per overload API (#72206)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/72206
Test Plan: Imported from OSS
Reviewed By: albanD
Differential Revision: D33955635
Pulled By: anjali411
fbshipit-source-id: 4fbd0c0c4d032bcbd9d9f362b4b6f84eec9ad047
(cherry picked from commit 85076245c9)
This commit is contained in:
parent
2b0f334443
commit
50770b9e19
|
|
@ -670,6 +670,17 @@ TEST(OperatorRegistrationTest, whenRegisterWithLazyKernelAndCatchAll_AutogradLaz
|
|||
whenRegisterWithLazyBackendsAndCatchAll_AutogradLazyBackendsIsNotFilled(DispatchKey::Lazy);
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest, whenregisteringwithinvalidoverloadname) {
|
||||
expectThrows<c10::Error>([] {
|
||||
auto registrar = c10::RegisterOperators().op("_test::dummy.default", c10::RegisterOperators::options()
|
||||
.kernel(DispatchKey::CPU, [] (const int64_t&) {}));
|
||||
}, "default is not a legal overload name for aten operators");
|
||||
expectThrows<c10::Error>([] {
|
||||
auto registrar = c10::RegisterOperators().op("_test::dummy.__name__", c10::RegisterOperators::options()
|
||||
.kernel(DispatchKey::CPU, [] (const int64_t&) {}));
|
||||
}, "__name__ is not a legal overload name for aten operators");
|
||||
}
|
||||
|
||||
TEST(OperatorRegistrationTest, givenLambdaKernel_whenRegisteringWithMismatchingCppSignatures_thenFails) {
|
||||
expectThrows<c10::Error>([] {
|
||||
auto registrar = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options()
|
||||
|
|
@ -1243,6 +1254,16 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) {
|
|||
"(Dict(str, Dict(int, str)?[])[] a) -> Dict(str, Dict(int, str)?[])[]");
|
||||
}
|
||||
|
||||
TEST(NewOperatorRegistrationTest, erroroutwithinvalidoverloadname) {
|
||||
auto m = MAKE_TORCH_LIBRARY(_test);
|
||||
expectThrows<c10::Error>([&] {
|
||||
m.def("dummy.default(Tensor self) -> Tensor");
|
||||
}, "default is not a legal overload name for aten operators");
|
||||
expectThrows<c10::Error>([&] {
|
||||
m.def("dummy.__name__(Tensor self) -> Tensor");
|
||||
}, "__name__ is not a legal overload name for aten operators");
|
||||
}
|
||||
|
||||
TEST(NewOperatorRegistrationTest, testBasics) {
|
||||
auto m = MAKE_TORCH_LIBRARY(_test);
|
||||
m.def("dummy(Tensor self) -> Tensor");
|
||||
|
|
|
|||
|
|
@ -50,10 +50,6 @@ class TestCustomOperators(JitTestCase):
|
|||
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
|
||||
self.assertEqual(output, torch.tensor([-0.01, 1]))
|
||||
|
||||
def test_only_kwargs(self):
|
||||
output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
|
||||
self.assertEqual(output, torch.tensor(-0.01))
|
||||
|
||||
def test_passing_too_many_args(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
|
|
@ -78,14 +74,6 @@ class TestCustomOperators(JitTestCase):
|
|||
):
|
||||
torch.ops.aten.type_as(torch.ones(5, 5))
|
||||
|
||||
def test_passing_an_argument_both_as_positional_and_kwarg(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Argument 'self' specified both as positional and keyword argument",
|
||||
""
|
||||
):
|
||||
torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
|
||||
|
||||
def test_passing_unknown_kwargs(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
|
|
|
|||
|
|
@ -1,67 +1,64 @@
|
|||
# Owner(s): ["module: unknown"]
|
||||
# import torch
|
||||
# import copy
|
||||
import torch
|
||||
import copy
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
|
||||
class TestPerOverloadAPI(TestCase):
|
||||
# def test_basics_opoverloadpacket(self):
|
||||
# # add is ony used as an example here. It is ok to update the test
|
||||
# # if the semantics of add are modified in the future.
|
||||
# add_packet = torch.ops.aten.add
|
||||
def test_basics_opoverloadpacket(self):
|
||||
# add is ony used as an example here. It is ok to update the test
|
||||
# if the semantics of add are modified in the future.
|
||||
add_packet = torch.ops.aten.add
|
||||
|
||||
# # class attributes
|
||||
# self.assertEqual(add_packet.op_name, 'add')
|
||||
# self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
||||
# class attributes
|
||||
self.assertEqual(add_packet.op_name, 'add')
|
||||
self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
||||
|
||||
# # callable
|
||||
# self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
# callable
|
||||
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
# # correct module
|
||||
# self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
||||
# correct module
|
||||
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
||||
|
||||
# # caching
|
||||
# another_add_packet = torch.ops.aten.add
|
||||
# self.assertEqual(id(add_packet), id(another_add_packet))
|
||||
# caching
|
||||
another_add_packet = torch.ops.aten.add
|
||||
self.assertEqual(id(add_packet), id(another_add_packet))
|
||||
|
||||
# # deepcopy is a no-op
|
||||
# self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
||||
# deepcopy is a no-op
|
||||
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
||||
|
||||
# # pretty print
|
||||
# self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
||||
# pretty print
|
||||
self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
||||
|
||||
# self.assertRaises(AttributeError, lambda: add_packet.foo)
|
||||
self.assertRaises(AttributeError, lambda: add_packet.foo)
|
||||
|
||||
# def test_basics_opoverload(self):
|
||||
# add_packet = torch.ops.aten.add
|
||||
# add_tensoroverload = add_packet.Tensor
|
||||
def test_basics_opoverload(self):
|
||||
add_packet = torch.ops.aten.add
|
||||
add_tensoroverload = add_packet.Tensor
|
||||
|
||||
# # class attributes
|
||||
# self.assertEqual(add_tensoroverload.name, 'aten.add')
|
||||
# self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
||||
# self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
||||
# class attributes
|
||||
self.assertEqual(add_tensoroverload.name, 'aten.add')
|
||||
self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
||||
self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
||||
|
||||
# # deepcopy is a no-op
|
||||
# self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
||||
# deepcopy is a no-op
|
||||
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
||||
|
||||
# # caching
|
||||
# another_add_tensoroverload = torch.ops.aten.add.Tensor
|
||||
# self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
# caching
|
||||
another_add_tensoroverload = torch.ops.aten.add.Tensor
|
||||
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
|
||||
# # pretty print
|
||||
# self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
||||
# pretty print
|
||||
self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
||||
|
||||
# # callable
|
||||
# self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
# callable
|
||||
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
# a = torch.tensor(2)
|
||||
# b = torch.tensor(0)
|
||||
# torch.ops.aten.add.out(a, a, out=b)
|
||||
# self.assertEqual(b, torch.tensor(4))
|
||||
a = torch.tensor(2)
|
||||
b = torch.tensor(0)
|
||||
torch.ops.aten.add.out(a, a, out=b)
|
||||
self.assertEqual(b, torch.tensor(4))
|
||||
|
||||
# self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
||||
|
||||
def do_nothing(self):
|
||||
return
|
||||
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -99,39 +99,63 @@ class OpOverloadPacket:
|
|||
if key == '__file__':
|
||||
return 'torch.ops'
|
||||
|
||||
# ensure that query for dunder attributes that does not exist on
|
||||
# opoverloadpacket but instead exists on the self._op object does not unnecessarily call
|
||||
# `_get_operation_overload` (which is an expensive operation).
|
||||
# This is done to prevent any potential slowdown. This list can be extended
|
||||
# if there exists other attributes like `__name__` that only exist on self._op and not on the
|
||||
# opoverloadpacket.
|
||||
# This is ok since we are guaranteed that an overload name for an aten op can't start with '__'
|
||||
try:
|
||||
if key.startswith('__'):
|
||||
return getattr(self._op, key)
|
||||
except AttributeError:
|
||||
# for consistency because it seems weird to
|
||||
# throw an attribute error with a message containing
|
||||
# an object name different from the one the attribute
|
||||
# query was performed on.
|
||||
raise AttributeError("'{}' can't have an overload name beginning with '__' and the "
|
||||
"underlying op {} has no attribute {} either."
|
||||
.format(str(self), str(self._op), key)) from None
|
||||
|
||||
try:
|
||||
# This is ok since we are guaranteed that an overload name for an aten op can't be 'default'
|
||||
use_key = '' if key == 'default' else key
|
||||
# TODO: disallow access to overloads registered by JIT
|
||||
op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key)
|
||||
op_ = torch._C._get_operation_overload(
|
||||
self._qualified_op_name, use_key)
|
||||
schema = torch._C._get_schema(self._qualified_op_name, use_key)
|
||||
overload = OpOverload(self, op_, schema)
|
||||
# cache the overload object
|
||||
setattr(self, key, overload)
|
||||
return overload
|
||||
except RuntimeError:
|
||||
try:
|
||||
# This is added to maintain bc in case the user queries an attribute that exists on `self._op`
|
||||
# which used to be returned before instead of the OpOverloadPacket
|
||||
out = getattr(self._op, key)
|
||||
return out
|
||||
except AttributeError:
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None
|
||||
raise AttributeError(
|
||||
"The underlying op of '{}' has no overload name '{}'".format(str(self), key)
|
||||
) from None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT
|
||||
# We save the function ptr as the `op` attribute on OpOverloadPacket to access it here.
|
||||
# overloading __call__ to ensure torch.ops.foo.bar()
|
||||
# is still callable from JIT
|
||||
# We save the function ptr as the `op` attribute on
|
||||
# OpOverloadPacket to access it here.
|
||||
return self._op(*args, **kwargs or {})
|
||||
|
||||
# Resolution of torch.fn is different from torch.ops.aten.fn
|
||||
# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method
|
||||
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and
|
||||
# then tries to match the correct one at runtime and always calls into the boxed version of the method
|
||||
# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings
|
||||
# Aten codegen generates tensor methods for the the tensor class
|
||||
# torch.fn uses the Python argparser, matches with the
|
||||
# appropriate schema, and calls into the unboxed version of the method
|
||||
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT.
|
||||
# JIT creates a stack of all the overloads and then tries to match the
|
||||
# correct one at runtime and always calls into the boxed version of the method
|
||||
# Autograd codegen creates VariableType, TracerType,
|
||||
# inplace or view type and python bindings.
|
||||
# Aten codegen generates tensor methods for the the tensor class.
|
||||
|
||||
# _OpNamespace is a subclass of ModuleType because the torch script
|
||||
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
|
||||
# to work from script, we need to ensure ops and foo are modules
|
||||
|
||||
|
||||
class _OpNamespace(types.ModuleType):
|
||||
"""
|
||||
An op namespace to dynamically bind Operators into Python.
|
||||
|
|
@ -170,13 +194,13 @@ class _OpNamespace(types.ModuleType):
|
|||
# with qualified_op_name
|
||||
torch.jit._builtins._register_builtin(op, qualified_op_name)
|
||||
op.__module__ = self.__module__ + "." + namespace_name
|
||||
# opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
|
||||
# opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
|
||||
opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
|
||||
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
|
||||
# cache the opoverloadpacket to ensure that each op corresponds to
|
||||
# a unique OpOverloadPacket object
|
||||
# setattr(self, op_name, opoverloadpacket)
|
||||
setattr(self, op_name, op)
|
||||
return op
|
||||
setattr(self, op_name, opoverloadpacket)
|
||||
return opoverloadpacket
|
||||
|
||||
|
||||
class _Ops(types.ModuleType):
|
||||
__file__ = '_ops.py'
|
||||
|
|
@ -220,5 +244,6 @@ class _Ops(types.ModuleType):
|
|||
ctypes.CDLL(path)
|
||||
self.loaded_libraries.add(path)
|
||||
|
||||
|
||||
# The ops "namespace"
|
||||
ops = _Ops()
|
||||
|
|
|
|||
|
|
@ -111,6 +111,13 @@ struct SchemaParser {
|
|||
if (L.nextIf('.')) {
|
||||
overload_name = L.expect(TK_IDENT).text();
|
||||
}
|
||||
// default is used as an attribute on the `OpOverloadPacket`
|
||||
// (obtained using `torch.ops.aten.foo`) to get the operator
|
||||
// overload with overload name as an empty string
|
||||
// and so shouldn't be used as an overload name
|
||||
// also disallow dunder attribute names to be overload names
|
||||
bool is_a_valid_overload_name = !((overload_name == "default") || (overload_name.rfind("__", 0) == 0));
|
||||
TORCH_CHECK(is_a_valid_overload_name, overload_name, " is not a legal overload name for aten operators");
|
||||
return {name, overload_name};
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user