mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
backout D33469839 (#71443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71443
cogwheel test inline_cvr_infer_canary_pyper_model_publish is timing out.
The convert_fx call takes > 20 mins for local and local_ro sub modules, which used to take ~ 2 mins.
Test Plan:
Fblearn flow run
* the following cmd took 1113 seconds before the diff and 5002 seconds after.
flow-cli clone-locally 320014219 --run-as-secure-group pytorch_at_scale --operators pyper_model_publish_workflow.pyper_model_publish_workflow.process_torch_package_model_files.process_non_sparse_parameters[0]
Cogwheel test
* Cogwheel test with packages in B3588 (the last good run) took 4694.48s
* Cogwheel test with packages in B3590 (the first timeout) took 13975.83s
* Cogwheel test with the following packages took 4535.04s
* all packages in B3588 except the model publish
* the model publish built with D33469839 (043e84b3d2) reversed (created D33633570)
Reviewed By: albanD, jerryzh168
Differential Revision: D33633570
fbshipit-source-id: dc5e777c48a90c551641a3f79126461f6a60449e
This commit is contained in:
parent
32472884ec
commit
03ab65023a
|
|
@ -50,6 +50,10 @@ class TestCustomOperators(JitTestCase):
|
||||||
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
|
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
|
||||||
self.assertEqual(output, torch.tensor([-0.01, 1]))
|
self.assertEqual(output, torch.tensor([-0.01, 1]))
|
||||||
|
|
||||||
|
def test_only_kwargs(self):
|
||||||
|
output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
|
||||||
|
self.assertEqual(output, torch.tensor(-0.01))
|
||||||
|
|
||||||
def test_passing_too_many_args(self):
|
def test_passing_too_many_args(self):
|
||||||
with self.assertRaisesRegexWithHighlight(
|
with self.assertRaisesRegexWithHighlight(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
|
|
@ -74,6 +78,14 @@ class TestCustomOperators(JitTestCase):
|
||||||
):
|
):
|
||||||
torch.ops.aten.type_as(torch.ones(5, 5))
|
torch.ops.aten.type_as(torch.ones(5, 5))
|
||||||
|
|
||||||
|
def test_passing_an_argument_both_as_positional_and_kwarg(self):
|
||||||
|
with self.assertRaisesRegexWithHighlight(
|
||||||
|
RuntimeError,
|
||||||
|
"Argument 'self' specified both as positional and keyword argument",
|
||||||
|
""
|
||||||
|
):
|
||||||
|
torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
|
||||||
|
|
||||||
def test_passing_unknown_kwargs(self):
|
def test_passing_unknown_kwargs(self):
|
||||||
with self.assertRaisesRegexWithHighlight(
|
with self.assertRaisesRegexWithHighlight(
|
||||||
RuntimeError,
|
RuntimeError,
|
||||||
|
|
|
||||||
|
|
@ -724,32 +724,31 @@ class TestOperators(TestCase):
|
||||||
x = torch.randn(2, 3, 4, requires_grad=True)
|
x = torch.randn(2, 3, 4, requires_grad=True)
|
||||||
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
|
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
|
||||||
|
|
||||||
# Github Issue: https://github.com/pytorch/pytorch/issues/71095
|
def test_c2_op(self):
|
||||||
# def test_c2_op(self):
|
class MyModel(torch.nn.Module):
|
||||||
# class MyModel(torch.nn.Module):
|
def __init__(self):
|
||||||
# def __init__(self):
|
super(MyModel, self).__init__()
|
||||||
# super(MyModel, self).__init__()
|
|
||||||
#
|
def forward(self, scores, bbox_deltas, im_info, anchors):
|
||||||
# def forward(self, scores, bbox_deltas, im_info, anchors):
|
a, b = torch.ops._caffe2.GenerateProposals(
|
||||||
# a, b = torch.ops._caffe2.GenerateProposals(
|
(scores), (bbox_deltas), (im_info), (anchors),
|
||||||
# (scores), (bbox_deltas), (im_info), (anchors),
|
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
|
||||||
# 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
|
)
|
||||||
# )
|
return a, b
|
||||||
# return a, b
|
|
||||||
#
|
model = MyModel()
|
||||||
# model = MyModel()
|
A = 4
|
||||||
# A = 4
|
H = 10
|
||||||
# H = 10
|
W = 8
|
||||||
# W = 8
|
img_count = 3
|
||||||
# img_count = 3
|
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
|
||||||
# scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
|
bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
|
||||||
# bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
|
dtype=torch.float32)
|
||||||
# dtype=torch.float32)
|
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
|
||||||
# bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
|
im_info = torch.ones(img_count, 3, dtype=torch.float32)
|
||||||
# im_info = torch.ones(img_count, 3, dtype=torch.float32)
|
anchors = torch.ones(A, 4, dtype=torch.float32)
|
||||||
# anchors = torch.ones(A, 4, dtype=torch.float32)
|
inputs = (scores, bbox_deltas, im_info, anchors)
|
||||||
# inputs = (scores, bbox_deltas, im_info, anchors)
|
self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
|
||||||
# self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
|
|
||||||
|
|
||||||
def test_dict(self):
|
def test_dict(self):
|
||||||
class MyModel(torch.nn.Module):
|
class MyModel(torch.nn.Module):
|
||||||
|
|
|
||||||
|
|
@ -1,63 +0,0 @@
|
||||||
import torch
|
|
||||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
|
||||||
import copy
|
|
||||||
|
|
||||||
class TestPerOverloadAPI(TestCase):
|
|
||||||
def test_basics_opoverloadpacket(self):
|
|
||||||
# add is ony used as an example here. It is ok to update the test
|
|
||||||
# if the semantics of add are modified in the future.
|
|
||||||
add_packet = torch.ops.aten.add
|
|
||||||
|
|
||||||
# class attributes
|
|
||||||
self.assertEqual(add_packet.op_name, 'add')
|
|
||||||
self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
|
||||||
|
|
||||||
# callable
|
|
||||||
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
|
||||||
|
|
||||||
# correct module
|
|
||||||
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
|
||||||
|
|
||||||
# caching
|
|
||||||
another_add_packet = torch.ops.aten.add
|
|
||||||
self.assertEqual(id(add_packet), id(another_add_packet))
|
|
||||||
|
|
||||||
# deepcopy is a no-op
|
|
||||||
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
|
||||||
|
|
||||||
# pretty print
|
|
||||||
self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
|
||||||
|
|
||||||
self.assertRaises(AttributeError, lambda: add_packet.foo)
|
|
||||||
|
|
||||||
def test_basics_opoverload(self):
|
|
||||||
add_packet = torch.ops.aten.add
|
|
||||||
add_tensoroverload = add_packet.Tensor
|
|
||||||
|
|
||||||
# class attributes
|
|
||||||
self.assertEqual(add_tensoroverload.name, 'aten.add')
|
|
||||||
self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
|
||||||
self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
|
||||||
|
|
||||||
# deepcopy is a no-op
|
|
||||||
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
|
||||||
|
|
||||||
# caching
|
|
||||||
another_add_tensoroverload = torch.ops.aten.add.Tensor
|
|
||||||
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
|
||||||
|
|
||||||
# pretty print
|
|
||||||
self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
|
||||||
|
|
||||||
# callable
|
|
||||||
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
|
||||||
|
|
||||||
a = torch.tensor(2)
|
|
||||||
b = torch.tensor(0)
|
|
||||||
torch.ops.aten.add.out(a, a, out=b)
|
|
||||||
self.assertEqual(b, torch.tensor(4))
|
|
||||||
|
|
||||||
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
run_tests()
|
|
||||||
|
|
@ -197,8 +197,6 @@ def _jit_init() -> _bool: ...
|
||||||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||||
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
|
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
|
||||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||||
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
|
|
||||||
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
|
|
||||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||||
optimization_blocklist: Set[MobileOptimizerType],
|
optimization_blocklist: Set[MobileOptimizerType],
|
||||||
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||||
|
|
|
||||||
117
torch/_ops.py
117
torch/_ops.py
|
|
@ -25,109 +25,6 @@ def dl_open_guard():
|
||||||
if _SET_GLOBAL_FLAGS:
|
if _SET_GLOBAL_FLAGS:
|
||||||
sys.setdlopenflags(old_flags)
|
sys.setdlopenflags(old_flags)
|
||||||
|
|
||||||
# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
|
|
||||||
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
|
|
||||||
class OpOverload:
|
|
||||||
def __init__(self, overloadpacket, op, schema):
|
|
||||||
self._op = op
|
|
||||||
self._schema = schema
|
|
||||||
self._overloadpacket = overloadpacket
|
|
||||||
|
|
||||||
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
|
|
||||||
def __deepcopy__(self, memo=None):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
return self._op(*args, **kwargs or {})
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
return getattr(self._op, key)
|
|
||||||
|
|
||||||
# `my_namespace::my_op`
|
|
||||||
@property
|
|
||||||
def name(self):
|
|
||||||
return "{}.{}".format(*self._schema.name.split("::"))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def overload_name(self):
|
|
||||||
return self._schema.overload_name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def overload_packet(self):
|
|
||||||
return self._overloadpacket
|
|
||||||
|
|
||||||
@property
|
|
||||||
def op(self):
|
|
||||||
return self._op
|
|
||||||
|
|
||||||
# TODO: add more methods to expose information about input and output arguments
|
|
||||||
|
|
||||||
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
|
|
||||||
# You can obtain an OpOverload object through attribute query.
|
|
||||||
class OpOverloadPacket:
|
|
||||||
def __init__(self, qualified_op_name, op_name, op):
|
|
||||||
# These attributes are accessible on the object through the properties
|
|
||||||
# defined below but are immutable
|
|
||||||
self._qualified_op_name = qualified_op_name
|
|
||||||
self._op_name = op_name
|
|
||||||
self._op = op
|
|
||||||
|
|
||||||
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
|
|
||||||
def __deepcopy__(self, memo=None):
|
|
||||||
return self
|
|
||||||
|
|
||||||
def __str__(self):
|
|
||||||
return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def qualified_op_name(self):
|
|
||||||
return "{}.{}".format(*self._qualified_op_name.split("::"))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def op_name(self):
|
|
||||||
return self._op_name
|
|
||||||
|
|
||||||
@property
|
|
||||||
def op(self):
|
|
||||||
return self._op
|
|
||||||
|
|
||||||
def __getattr__(self, key):
|
|
||||||
# It is not a valid op_name when __file__ is passed in
|
|
||||||
if key == '__file__':
|
|
||||||
return 'torch.ops'
|
|
||||||
|
|
||||||
try:
|
|
||||||
use_key = '' if key == 'default' else key
|
|
||||||
# TODO: disallow access to overloads registered by JIT
|
|
||||||
op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key)
|
|
||||||
schema = torch._C._get_schema(self._qualified_op_name, use_key)
|
|
||||||
overload = OpOverload(self, op_, schema)
|
|
||||||
# cache the overload object
|
|
||||||
setattr(self, key, overload)
|
|
||||||
return overload
|
|
||||||
except RuntimeError:
|
|
||||||
try:
|
|
||||||
# This is added to maintain bc in case the user queries an attribute that exists on `self._op`
|
|
||||||
# which used to be returned before instead of the OpOverloadPacket
|
|
||||||
out = getattr(self._op, key)
|
|
||||||
return out
|
|
||||||
except AttributeError:
|
|
||||||
raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None
|
|
||||||
|
|
||||||
def __call__(self, *args, **kwargs):
|
|
||||||
# overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT
|
|
||||||
# We save the function ptr as the `op` attribute on OpOverloadPacket to access it here.
|
|
||||||
return self._op(*args, **kwargs or {})
|
|
||||||
|
|
||||||
# Resolution of torch.fn is different from torch.ops.aten.fn
|
|
||||||
# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method
|
|
||||||
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and
|
|
||||||
# then tries to match the correct one at runtime and always calls into the boxed version of the method
|
|
||||||
# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings
|
|
||||||
# Aten codegen generates tensor methods for the the tensor class
|
|
||||||
|
|
||||||
# _OpNamespace is a subclass of ModuleType because the torch script
|
# _OpNamespace is a subclass of ModuleType because the torch script
|
||||||
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
|
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
|
||||||
|
|
@ -162,20 +59,14 @@ class _OpNamespace(types.ModuleType):
|
||||||
return 'torch.ops'
|
return 'torch.ops'
|
||||||
# Get the op `my_namespace::my_op` if available. This will also check
|
# Get the op `my_namespace::my_op` if available. This will also check
|
||||||
# for overloads and raise an exception if there are more than one.
|
# for overloads and raise an exception if there are more than one.
|
||||||
namespace_name = self.name
|
qualified_op_name = '{}::{}'.format(self.name, op_name)
|
||||||
qualified_op_name = '{}::{}'.format(namespace_name, op_name)
|
|
||||||
op = torch._C._jit_get_operation(qualified_op_name)
|
op = torch._C._jit_get_operation(qualified_op_name)
|
||||||
|
|
||||||
# let the script frontend know that op is identical to the builtin op
|
# let the script frontend know that op is identical to the builtin op
|
||||||
# with qualified_op_name
|
# with qualified_op_name
|
||||||
torch.jit._builtins._register_builtin(op, qualified_op_name)
|
torch.jit._builtins._register_builtin(op, qualified_op_name)
|
||||||
op.__module__ = self.__module__ + "." + namespace_name
|
setattr(self, op_name, op)
|
||||||
opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
|
op.__module__ = self.__module__ + "." + self.name
|
||||||
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
|
return op
|
||||||
# cache the opoverloadpacket to ensure that each op corresponds to
|
|
||||||
# a unique OpOverloadPacket object
|
|
||||||
setattr(self, op_name, opoverloadpacket)
|
|
||||||
return opoverloadpacket
|
|
||||||
|
|
||||||
class _Ops(types.ModuleType):
|
class _Ops(types.ModuleType):
|
||||||
__file__ = '_ops.py'
|
__file__ = '_ops.py'
|
||||||
|
|
|
||||||
|
|
@ -133,11 +133,11 @@ _INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
|
||||||
# Default mapping from floating point function or torch ops to quantized ops
|
# Default mapping from floating point function or torch ops to quantized ops
|
||||||
# TODO: merge with default static mapping
|
# TODO: merge with default static mapping
|
||||||
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
|
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
|
||||||
F.elu: torch.ops.quantized.elu,
|
F.elu: torch._ops.ops.quantized.elu,
|
||||||
F.hardswish: torch.ops.quantized.hardswish,
|
F.hardswish: torch._ops.ops.quantized.hardswish,
|
||||||
F.instance_norm: torch.ops.quantized.instance_norm,
|
F.instance_norm: torch._ops.ops.quantized.instance_norm,
|
||||||
F.layer_norm: torch.ops.quantized.layer_norm,
|
F.layer_norm: torch._ops.ops.quantized.layer_norm,
|
||||||
F.leaky_relu: torch.ops.quantized.leaky_relu,
|
F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
|
||||||
}
|
}
|
||||||
|
|
||||||
# mapping from module to output activation post process class
|
# mapping from module to output activation post process class
|
||||||
|
|
|
||||||
|
|
@ -1293,50 +1293,6 @@ void initJITBindings(PyObject* module) {
|
||||||
})
|
})
|
||||||
.def("has_storage", &DeserializationStorageContext::hasStorage);
|
.def("has_storage", &DeserializationStorageContext::hasStorage);
|
||||||
|
|
||||||
m.def(
|
|
||||||
"_get_schema",
|
|
||||||
[](const std::string& op_name, const std::string& overload_name) {
|
|
||||||
try {
|
|
||||||
auto symbol = Symbol::fromQualString(op_name);
|
|
||||||
auto operations = getAllOperatorsFor(symbol);
|
|
||||||
for (const auto& op : operations) {
|
|
||||||
if (op->schema().overload_name() == overload_name) {
|
|
||||||
return op->schema();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw std::runtime_error("Found no matching schema");
|
|
||||||
} catch (const c10::Error& e) {
|
|
||||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
|
||||||
? e.what()
|
|
||||||
: e.what_without_backtrace();
|
|
||||||
throw std::runtime_error(msg);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
m.def(
|
|
||||||
"_get_operation_overload",
|
|
||||||
[](const std::string& op_name, const std::string& overload_name) {
|
|
||||||
try {
|
|
||||||
auto symbol = Symbol::fromQualString(op_name);
|
|
||||||
auto operations = getAllOperatorsFor(symbol);
|
|
||||||
for (const auto& op : operations) {
|
|
||||||
if (op->schema().overload_name() == overload_name) {
|
|
||||||
auto func =
|
|
||||||
py::cpp_function([op](py::args args, py::kwargs kwargs) {
|
|
||||||
return invokeOperatorFromPython({op}, args, kwargs);
|
|
||||||
});
|
|
||||||
return func;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
throw std::runtime_error("Found no matching operator overload");
|
|
||||||
} catch (const c10::Error& e) {
|
|
||||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
|
||||||
? e.what()
|
|
||||||
: e.what_without_backtrace();
|
|
||||||
throw std::runtime_error(msg);
|
|
||||||
}
|
|
||||||
});
|
|
||||||
|
|
||||||
m.def(
|
m.def(
|
||||||
"_jit_get_operation",
|
"_jit_get_operation",
|
||||||
[](const std::string& op_name) {
|
[](const std::string& op_name) {
|
||||||
|
|
@ -1412,11 +1368,8 @@ void initJITBindings(PyObject* module) {
|
||||||
py::name(symbol.toUnqualString()),
|
py::name(symbol.toUnqualString()),
|
||||||
py::doc(docstring.str().c_str()));
|
py::doc(docstring.str().c_str()));
|
||||||
return func;
|
return func;
|
||||||
} catch (const c10::Error& e) {
|
} catch (const c10::Error& error) {
|
||||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
throw std::runtime_error(error.what_without_backtrace());
|
||||||
? e.what()
|
|
||||||
: e.what_without_backtrace();
|
|
||||||
throw std::runtime_error(msg);
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
py::arg("qualified_name"));
|
py::arg("qualified_name"));
|
||||||
|
|
|
||||||
|
|
@ -1104,13 +1104,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
auto opoverloadpacket_type =
|
|
||||||
py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
|
|
||||||
py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
|
|
||||||
if (is_overloadpacket) {
|
|
||||||
obj = py::getattr(obj, "op");
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isRpcAvailable = py::cast<bool>(
|
bool isRpcAvailable = py::cast<bool>(
|
||||||
py::module::import("torch.distributed.rpc").attr("is_available")());
|
py::module::import("torch.distributed.rpc").attr("is_available")());
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,6 @@ import warnings
|
||||||
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
|
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
|
||||||
from torch._jit_internal import boolean_dispatched
|
from torch._jit_internal import boolean_dispatched
|
||||||
from ._compatibility import compatibility
|
from ._compatibility import compatibility
|
||||||
from torch._ops import OpOverloadPacket
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .node import Argument
|
from .node import Argument
|
||||||
|
|
@ -138,9 +137,6 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
|
||||||
if override:
|
if override:
|
||||||
return (override, None) if return_schemas else None
|
return (override, None) if return_schemas else None
|
||||||
|
|
||||||
if isinstance(op, OpOverloadPacket):
|
|
||||||
op = op._op
|
|
||||||
|
|
||||||
aten_fn = torch.jit._builtins._find_builtin(op)
|
aten_fn = torch.jit._builtins._find_builtin(op)
|
||||||
|
|
||||||
if aten_fn is None:
|
if aten_fn is None:
|
||||||
|
|
|
||||||
|
|
@ -265,9 +265,6 @@ def infer_concrete_type_builder(nn_module, share_types=True):
|
||||||
# Don't re-add anything we already added
|
# Don't re-add anything we already added
|
||||||
continue
|
continue
|
||||||
|
|
||||||
isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
|
|
||||||
if isoverloadpacket:
|
|
||||||
value = value.op
|
|
||||||
# Handle Python function attributes
|
# Handle Python function attributes
|
||||||
if inspect.isfunction(value):
|
if inspect.isfunction(value):
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -1270,6 +1270,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
|
||||||
return create_script_dict(obj)
|
return create_script_dict(obj)
|
||||||
if isinstance(obj, list):
|
if isinstance(obj, list):
|
||||||
return create_script_list(obj)
|
return create_script_list(obj)
|
||||||
|
|
||||||
if inspect.isclass(obj):
|
if inspect.isclass(obj):
|
||||||
qualified_name = _qualified_name(obj)
|
qualified_name = _qualified_name(obj)
|
||||||
# If this type is a `nn.Module` subclass, they probably meant to pass
|
# If this type is a `nn.Module` subclass, they probably meant to pass
|
||||||
|
|
|
||||||
|
|
@ -23,7 +23,6 @@ if torch.distributed.rpc.is_available():
|
||||||
from .._jit_internal import RRef, is_rref
|
from .._jit_internal import RRef, is_rref
|
||||||
from torch._C import RRefType
|
from torch._C import RRefType
|
||||||
|
|
||||||
from torch._ops import OpOverloadPacket
|
|
||||||
|
|
||||||
class Module(object):
|
class Module(object):
|
||||||
def __init__(self, name, members):
|
def __init__(self, name, members):
|
||||||
|
|
@ -63,10 +62,7 @@ class EvalEnv(object):
|
||||||
return getattr(builtins, name, None)
|
return getattr(builtins, name, None)
|
||||||
|
|
||||||
def get_signature(fn, rcb, loc, is_method):
|
def get_signature(fn, rcb, loc, is_method):
|
||||||
if isinstance(fn, OpOverloadPacket):
|
signature = try_real_annotations(fn, loc)
|
||||||
signature = try_real_annotations(fn.op, loc)
|
|
||||||
else:
|
|
||||||
signature = try_real_annotations(fn, loc)
|
|
||||||
if signature is not None and is_method:
|
if signature is not None and is_method:
|
||||||
# If this is a method, then the signature will include a type for
|
# If this is a method, then the signature will include a type for
|
||||||
# `self`, but type comments do not contain a `self`. So strip it
|
# `self`, but type comments do not contain a `self`. So strip it
|
||||||
|
|
@ -110,9 +106,6 @@ def is_vararg(the_callable):
|
||||||
|
|
||||||
|
|
||||||
def get_param_names(fn, n_args):
|
def get_param_names(fn, n_args):
|
||||||
if isinstance(fn, OpOverloadPacket):
|
|
||||||
fn = fn.op
|
|
||||||
|
|
||||||
if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
|
if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
|
||||||
# De-sugar calls to classes
|
# De-sugar calls to classes
|
||||||
fn = fn.__call__
|
fn = fn.__call__
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user