mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
backout D33469839 (#71443)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71443
cogwheel test inline_cvr_infer_canary_pyper_model_publish is timing out.
The convert_fx call takes > 20 mins for local and local_ro sub modules, which used to take ~ 2 mins.
Test Plan:
Fblearn flow run
* the following cmd took 1113 seconds before the diff and 5002 seconds after.
flow-cli clone-locally 320014219 --run-as-secure-group pytorch_at_scale --operators pyper_model_publish_workflow.pyper_model_publish_workflow.process_torch_package_model_files.process_non_sparse_parameters[0]
Cogwheel test
* Cogwheel test with packages in B3588 (the last good run) took 4694.48s
* Cogwheel test with packages in B3590 (the first timeout) took 13975.83s
* Cogwheel test with the following packages took 4535.04s
* all packages in B3588 except the model publish
* the model publish built with D33469839 (043e84b3d2) reversed (created D33633570)
Reviewed By: albanD, jerryzh168
Differential Revision: D33633570
fbshipit-source-id: dc5e777c48a90c551641a3f79126461f6a60449e
This commit is contained in:
parent
32472884ec
commit
03ab65023a
|
|
@ -50,6 +50,10 @@ class TestCustomOperators(JitTestCase):
|
|||
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
|
||||
self.assertEqual(output, torch.tensor([-0.01, 1]))
|
||||
|
||||
def test_only_kwargs(self):
|
||||
output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
|
||||
self.assertEqual(output, torch.tensor(-0.01))
|
||||
|
||||
def test_passing_too_many_args(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
|
|
@ -74,6 +78,14 @@ class TestCustomOperators(JitTestCase):
|
|||
):
|
||||
torch.ops.aten.type_as(torch.ones(5, 5))
|
||||
|
||||
def test_passing_an_argument_both_as_positional_and_kwarg(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
"Argument 'self' specified both as positional and keyword argument",
|
||||
""
|
||||
):
|
||||
torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
|
||||
|
||||
def test_passing_unknown_kwargs(self):
|
||||
with self.assertRaisesRegexWithHighlight(
|
||||
RuntimeError,
|
||||
|
|
|
|||
|
|
@ -724,32 +724,31 @@ class TestOperators(TestCase):
|
|||
x = torch.randn(2, 3, 4, requires_grad=True)
|
||||
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
|
||||
|
||||
# Github Issue: https://github.com/pytorch/pytorch/issues/71095
|
||||
# def test_c2_op(self):
|
||||
# class MyModel(torch.nn.Module):
|
||||
# def __init__(self):
|
||||
# super(MyModel, self).__init__()
|
||||
#
|
||||
# def forward(self, scores, bbox_deltas, im_info, anchors):
|
||||
# a, b = torch.ops._caffe2.GenerateProposals(
|
||||
# (scores), (bbox_deltas), (im_info), (anchors),
|
||||
# 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
|
||||
# )
|
||||
# return a, b
|
||||
#
|
||||
# model = MyModel()
|
||||
# A = 4
|
||||
# H = 10
|
||||
# W = 8
|
||||
# img_count = 3
|
||||
# scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
|
||||
# bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
|
||||
# dtype=torch.float32)
|
||||
# bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
|
||||
# im_info = torch.ones(img_count, 3, dtype=torch.float32)
|
||||
# anchors = torch.ones(A, 4, dtype=torch.float32)
|
||||
# inputs = (scores, bbox_deltas, im_info, anchors)
|
||||
# self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
|
||||
def test_c2_op(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super(MyModel, self).__init__()
|
||||
|
||||
def forward(self, scores, bbox_deltas, im_info, anchors):
|
||||
a, b = torch.ops._caffe2.GenerateProposals(
|
||||
(scores), (bbox_deltas), (im_info), (anchors),
|
||||
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
|
||||
)
|
||||
return a, b
|
||||
|
||||
model = MyModel()
|
||||
A = 4
|
||||
H = 10
|
||||
W = 8
|
||||
img_count = 3
|
||||
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
|
||||
bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
|
||||
dtype=torch.float32)
|
||||
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
|
||||
im_info = torch.ones(img_count, 3, dtype=torch.float32)
|
||||
anchors = torch.ones(A, 4, dtype=torch.float32)
|
||||
inputs = (scores, bbox_deltas, im_info, anchors)
|
||||
self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
|
||||
|
||||
def test_dict(self):
|
||||
class MyModel(torch.nn.Module):
|
||||
|
|
|
|||
|
|
@ -1,63 +0,0 @@
|
|||
import torch
|
||||
from torch.testing._internal.common_utils import TestCase, run_tests
|
||||
import copy
|
||||
|
||||
class TestPerOverloadAPI(TestCase):
|
||||
def test_basics_opoverloadpacket(self):
|
||||
# add is ony used as an example here. It is ok to update the test
|
||||
# if the semantics of add are modified in the future.
|
||||
add_packet = torch.ops.aten.add
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(add_packet.op_name, 'add')
|
||||
self.assertEqual(add_packet.qualified_op_name, 'aten.add')
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
# correct module
|
||||
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
|
||||
|
||||
# caching
|
||||
another_add_packet = torch.ops.aten.add
|
||||
self.assertEqual(id(add_packet), id(another_add_packet))
|
||||
|
||||
# deepcopy is a no-op
|
||||
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
|
||||
|
||||
# pretty print
|
||||
self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
|
||||
|
||||
self.assertRaises(AttributeError, lambda: add_packet.foo)
|
||||
|
||||
def test_basics_opoverload(self):
|
||||
add_packet = torch.ops.aten.add
|
||||
add_tensoroverload = add_packet.Tensor
|
||||
|
||||
# class attributes
|
||||
self.assertEqual(add_tensoroverload.name, 'aten.add')
|
||||
self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
|
||||
self.assertEqual(add_tensoroverload.overload_packet, add_packet)
|
||||
|
||||
# deepcopy is a no-op
|
||||
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
|
||||
|
||||
# caching
|
||||
another_add_tensoroverload = torch.ops.aten.add.Tensor
|
||||
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
|
||||
|
||||
# pretty print
|
||||
self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
|
||||
|
||||
# callable
|
||||
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
|
||||
|
||||
a = torch.tensor(2)
|
||||
b = torch.tensor(0)
|
||||
torch.ops.aten.add.out(a, a, out=b)
|
||||
self.assertEqual(b, torch.tensor(4))
|
||||
|
||||
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
@ -197,8 +197,6 @@ def _jit_init() -> _bool: ...
|
|||
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
|
||||
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
|
||||
def _jit_get_operation(op_name: str) -> Callable: ...
|
||||
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
|
||||
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
|
||||
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
|
||||
optimization_blocklist: Set[MobileOptimizerType],
|
||||
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...
|
||||
|
|
|
|||
117
torch/_ops.py
117
torch/_ops.py
|
|
@ -25,109 +25,6 @@ def dl_open_guard():
|
|||
if _SET_GLOBAL_FLAGS:
|
||||
sys.setdlopenflags(old_flags)
|
||||
|
||||
# Each OpOverload object contains pointer to a a specific operator overload, a pointer to the parent `OpOverloadPacket` object.
|
||||
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
|
||||
class OpOverload:
|
||||
def __init__(self, overloadpacket, op, schema):
|
||||
self._op = op
|
||||
self._schema = schema
|
||||
self._overloadpacket = overloadpacket
|
||||
|
||||
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
|
||||
def __deepcopy__(self, memo=None):
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
return self._op(*args, **kwargs or {})
|
||||
|
||||
def __getattr__(self, key):
|
||||
return getattr(self._op, key)
|
||||
|
||||
# `my_namespace::my_op`
|
||||
@property
|
||||
def name(self):
|
||||
return "{}.{}".format(*self._schema.name.split("::"))
|
||||
|
||||
@property
|
||||
def overload_name(self):
|
||||
return self._schema.overload_name
|
||||
|
||||
@property
|
||||
def overload_packet(self):
|
||||
return self._overloadpacket
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
return self._op
|
||||
|
||||
# TODO: add more methods to expose information about input and output arguments
|
||||
|
||||
# OpOverloadPacket class contains pointer to a base unresolved operator that doesn't correspond to a specific operator
|
||||
# You can obtain an OpOverload object through attribute query.
|
||||
class OpOverloadPacket:
|
||||
def __init__(self, qualified_op_name, op_name, op):
|
||||
# These attributes are accessible on the object through the properties
|
||||
# defined below but are immutable
|
||||
self._qualified_op_name = qualified_op_name
|
||||
self._op_name = op_name
|
||||
self._op = op
|
||||
|
||||
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
|
||||
def __deepcopy__(self, memo=None):
|
||||
return self
|
||||
|
||||
def __str__(self):
|
||||
return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
|
||||
|
||||
@property
|
||||
def qualified_op_name(self):
|
||||
return "{}.{}".format(*self._qualified_op_name.split("::"))
|
||||
|
||||
@property
|
||||
def op_name(self):
|
||||
return self._op_name
|
||||
|
||||
@property
|
||||
def op(self):
|
||||
return self._op
|
||||
|
||||
def __getattr__(self, key):
|
||||
# It is not a valid op_name when __file__ is passed in
|
||||
if key == '__file__':
|
||||
return 'torch.ops'
|
||||
|
||||
try:
|
||||
use_key = '' if key == 'default' else key
|
||||
# TODO: disallow access to overloads registered by JIT
|
||||
op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key)
|
||||
schema = torch._C._get_schema(self._qualified_op_name, use_key)
|
||||
overload = OpOverload(self, op_, schema)
|
||||
# cache the overload object
|
||||
setattr(self, key, overload)
|
||||
return overload
|
||||
except RuntimeError:
|
||||
try:
|
||||
# This is added to maintain bc in case the user queries an attribute that exists on `self._op`
|
||||
# which used to be returned before instead of the OpOverloadPacket
|
||||
out = getattr(self._op, key)
|
||||
return out
|
||||
except AttributeError:
|
||||
raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
# overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT
|
||||
# We save the function ptr as the `op` attribute on OpOverloadPacket to access it here.
|
||||
return self._op(*args, **kwargs or {})
|
||||
|
||||
# Resolution of torch.fn is different from torch.ops.aten.fn
|
||||
# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method
|
||||
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and
|
||||
# then tries to match the correct one at runtime and always calls into the boxed version of the method
|
||||
# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings
|
||||
# Aten codegen generates tensor methods for the the tensor class
|
||||
|
||||
# _OpNamespace is a subclass of ModuleType because the torch script
|
||||
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
|
||||
|
|
@ -162,20 +59,14 @@ class _OpNamespace(types.ModuleType):
|
|||
return 'torch.ops'
|
||||
# Get the op `my_namespace::my_op` if available. This will also check
|
||||
# for overloads and raise an exception if there are more than one.
|
||||
namespace_name = self.name
|
||||
qualified_op_name = '{}::{}'.format(namespace_name, op_name)
|
||||
qualified_op_name = '{}::{}'.format(self.name, op_name)
|
||||
op = torch._C._jit_get_operation(qualified_op_name)
|
||||
|
||||
# let the script frontend know that op is identical to the builtin op
|
||||
# with qualified_op_name
|
||||
torch.jit._builtins._register_builtin(op, qualified_op_name)
|
||||
op.__module__ = self.__module__ + "." + namespace_name
|
||||
opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
|
||||
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
|
||||
# cache the opoverloadpacket to ensure that each op corresponds to
|
||||
# a unique OpOverloadPacket object
|
||||
setattr(self, op_name, opoverloadpacket)
|
||||
return opoverloadpacket
|
||||
setattr(self, op_name, op)
|
||||
op.__module__ = self.__module__ + "." + self.name
|
||||
return op
|
||||
|
||||
class _Ops(types.ModuleType):
|
||||
__file__ = '_ops.py'
|
||||
|
|
|
|||
|
|
@ -133,11 +133,11 @@ _INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
|
|||
# Default mapping from floating point function or torch ops to quantized ops
|
||||
# TODO: merge with default static mapping
|
||||
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
|
||||
F.elu: torch.ops.quantized.elu,
|
||||
F.hardswish: torch.ops.quantized.hardswish,
|
||||
F.instance_norm: torch.ops.quantized.instance_norm,
|
||||
F.layer_norm: torch.ops.quantized.layer_norm,
|
||||
F.leaky_relu: torch.ops.quantized.leaky_relu,
|
||||
F.elu: torch._ops.ops.quantized.elu,
|
||||
F.hardswish: torch._ops.ops.quantized.hardswish,
|
||||
F.instance_norm: torch._ops.ops.quantized.instance_norm,
|
||||
F.layer_norm: torch._ops.ops.quantized.layer_norm,
|
||||
F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
|
||||
}
|
||||
|
||||
# mapping from module to output activation post process class
|
||||
|
|
|
|||
|
|
@ -1293,50 +1293,6 @@ void initJITBindings(PyObject* module) {
|
|||
})
|
||||
.def("has_storage", &DeserializationStorageContext::hasStorage);
|
||||
|
||||
m.def(
|
||||
"_get_schema",
|
||||
[](const std::string& op_name, const std::string& overload_name) {
|
||||
try {
|
||||
auto symbol = Symbol::fromQualString(op_name);
|
||||
auto operations = getAllOperatorsFor(symbol);
|
||||
for (const auto& op : operations) {
|
||||
if (op->schema().overload_name() == overload_name) {
|
||||
return op->schema();
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Found no matching schema");
|
||||
} catch (const c10::Error& e) {
|
||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
||||
? e.what()
|
||||
: e.what_without_backtrace();
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_get_operation_overload",
|
||||
[](const std::string& op_name, const std::string& overload_name) {
|
||||
try {
|
||||
auto symbol = Symbol::fromQualString(op_name);
|
||||
auto operations = getAllOperatorsFor(symbol);
|
||||
for (const auto& op : operations) {
|
||||
if (op->schema().overload_name() == overload_name) {
|
||||
auto func =
|
||||
py::cpp_function([op](py::args args, py::kwargs kwargs) {
|
||||
return invokeOperatorFromPython({op}, args, kwargs);
|
||||
});
|
||||
return func;
|
||||
}
|
||||
}
|
||||
throw std::runtime_error("Found no matching operator overload");
|
||||
} catch (const c10::Error& e) {
|
||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
||||
? e.what()
|
||||
: e.what_without_backtrace();
|
||||
throw std::runtime_error(msg);
|
||||
}
|
||||
});
|
||||
|
||||
m.def(
|
||||
"_jit_get_operation",
|
||||
[](const std::string& op_name) {
|
||||
|
|
@ -1412,11 +1368,8 @@ void initJITBindings(PyObject* module) {
|
|||
py::name(symbol.toUnqualString()),
|
||||
py::doc(docstring.str().c_str()));
|
||||
return func;
|
||||
} catch (const c10::Error& e) {
|
||||
auto msg = torch::get_cpp_stacktraces_enabled()
|
||||
? e.what()
|
||||
: e.what_without_backtrace();
|
||||
throw std::runtime_error(msg);
|
||||
} catch (const c10::Error& error) {
|
||||
throw std::runtime_error(error.what_without_backtrace());
|
||||
}
|
||||
},
|
||||
py::arg("qualified_name"));
|
||||
|
|
|
|||
|
|
@ -1104,13 +1104,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
|
|||
}
|
||||
}
|
||||
|
||||
auto opoverloadpacket_type =
|
||||
py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
|
||||
py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
|
||||
if (is_overloadpacket) {
|
||||
obj = py::getattr(obj, "op");
|
||||
}
|
||||
|
||||
bool isRpcAvailable = py::cast<bool>(
|
||||
py::module::import("torch.distributed.rpc").attr("is_available")());
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import warnings
|
|||
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
|
||||
from torch._jit_internal import boolean_dispatched
|
||||
from ._compatibility import compatibility
|
||||
from torch._ops import OpOverloadPacket
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .node import Argument
|
||||
|
|
@ -138,9 +137,6 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
|
|||
if override:
|
||||
return (override, None) if return_schemas else None
|
||||
|
||||
if isinstance(op, OpOverloadPacket):
|
||||
op = op._op
|
||||
|
||||
aten_fn = torch.jit._builtins._find_builtin(op)
|
||||
|
||||
if aten_fn is None:
|
||||
|
|
|
|||
|
|
@ -265,9 +265,6 @@ def infer_concrete_type_builder(nn_module, share_types=True):
|
|||
# Don't re-add anything we already added
|
||||
continue
|
||||
|
||||
isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
|
||||
if isoverloadpacket:
|
||||
value = value.op
|
||||
# Handle Python function attributes
|
||||
if inspect.isfunction(value):
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -1270,6 +1270,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
|
|||
return create_script_dict(obj)
|
||||
if isinstance(obj, list):
|
||||
return create_script_list(obj)
|
||||
|
||||
if inspect.isclass(obj):
|
||||
qualified_name = _qualified_name(obj)
|
||||
# If this type is a `nn.Module` subclass, they probably meant to pass
|
||||
|
|
|
|||
|
|
@ -23,7 +23,6 @@ if torch.distributed.rpc.is_available():
|
|||
from .._jit_internal import RRef, is_rref
|
||||
from torch._C import RRefType
|
||||
|
||||
from torch._ops import OpOverloadPacket
|
||||
|
||||
class Module(object):
|
||||
def __init__(self, name, members):
|
||||
|
|
@ -63,9 +62,6 @@ class EvalEnv(object):
|
|||
return getattr(builtins, name, None)
|
||||
|
||||
def get_signature(fn, rcb, loc, is_method):
|
||||
if isinstance(fn, OpOverloadPacket):
|
||||
signature = try_real_annotations(fn.op, loc)
|
||||
else:
|
||||
signature = try_real_annotations(fn, loc)
|
||||
if signature is not None and is_method:
|
||||
# If this is a method, then the signature will include a type for
|
||||
|
|
@ -110,9 +106,6 @@ def is_vararg(the_callable):
|
|||
|
||||
|
||||
def get_param_names(fn, n_args):
|
||||
if isinstance(fn, OpOverloadPacket):
|
||||
fn = fn.op
|
||||
|
||||
if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
|
||||
# De-sugar calls to classes
|
||||
fn = fn.__call__
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user