backout D33469839 (#71443)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71443

The Cogwheel test inline_cvr_infer_canary_pyper_model_publish is timing out.

The convert_fx call now takes more than 20 minutes for the local and local_ro submodules; it previously took about 2 minutes.
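
A rough way to reproduce the regression outside the publish flow is to time convert_fx directly on a small FX-quantized module. This is only a minimal sketch: the module, sizes, and qconfig below are hypothetical stand-ins for the local / local_ro submodules, and it assumes the FX graph-mode quantization API of this release (prepare_fx taking a qconfig_dict).

    import time
    import torch
    import torch.nn as nn
    from torch.ao.quantization import get_default_qconfig
    from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx

    class LocalSubmodule(nn.Module):  # hypothetical stand-in for local / local_ro
        def __init__(self):
            super().__init__()
            self.linear = nn.Linear(128, 128)

        def forward(self, x):
            return torch.relu(self.linear(x))

    model = LocalSubmodule().eval()
    prepared = prepare_fx(model, {"": get_default_qconfig("fbgemm")})
    prepared(torch.randn(8, 128))  # one calibration pass

    start = time.time()
    quantized = convert_fx(prepared)  # the call that regressed from ~2 min to >20 min
    print(f"convert_fx took {time.time() - start:.2f}s")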

Test Plan:
FBLearner Flow run
* The following command took 1113 seconds before the diff and 5002 seconds after:
    flow-cli clone-locally 320014219  --run-as-secure-group pytorch_at_scale  --operators pyper_model_publish_workflow.pyper_model_publish_workflow.process_torch_package_model_files.process_non_sparse_parameters[0]

Cogwheel test
* Cogwheel test with packages in B3588 (the last good run) took 4694.48s
* Cogwheel test with packages in B3590 (the first timeout) took 13975.83s
* Cogwheel test with the following packages took 4535.04s
  * all packages in B3588 except the model publish
  * the model publish package built with D33469839 (043e84b3d2) reverted (this created D33633570)

Reviewed By: albanD, jerryzh168

Differential Revision: D33633570

fbshipit-source-id: dc5e777c48a90c551641a3f79126461f6a60449e
Yan Li 2022-01-18 15:47:27 -08:00 committed by Facebook GitHub Bot
parent 32472884ec
commit 03ab65023a
12 changed files with 50 additions and 280 deletions

View File

@@ -50,6 +50,10 @@ class TestCustomOperators(JitTestCase):
output = torch.ops._test.leaky_relu(torch.tensor([-1.0, 1.0]))
self.assertEqual(output, torch.tensor([-0.01, 1]))
def test_only_kwargs(self):
output = torch.ops._test.leaky_relu(self=torch.tensor(-1.0))
self.assertEqual(output, torch.tensor(-0.01))
def test_passing_too_many_args(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
@@ -74,6 +78,14 @@ class TestCustomOperators(JitTestCase):
):
torch.ops.aten.type_as(torch.ones(5, 5))
def test_passing_an_argument_both_as_positional_and_kwarg(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,
"Argument 'self' specified both as positional and keyword argument",
""
):
torch.ops._test.leaky_relu(torch.ones(5), self=torch.ones(5))
def test_passing_unknown_kwargs(self):
with self.assertRaisesRegexWithHighlight(
RuntimeError,

View File

@@ -724,32 +724,31 @@ class TestOperators(TestCase):
x = torch.randn(2, 3, 4, requires_grad=True)
self.assertONNX(lambda x: torch.cumsum(x, dim=1), x, opset_version=11)
# Github Issue: https://github.com/pytorch/pytorch/issues/71095
# def test_c2_op(self):
# class MyModel(torch.nn.Module):
# def __init__(self):
# super(MyModel, self).__init__()
#
# def forward(self, scores, bbox_deltas, im_info, anchors):
# a, b = torch.ops._caffe2.GenerateProposals(
# (scores), (bbox_deltas), (im_info), (anchors),
# 2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
# )
# return a, b
#
# model = MyModel()
# A = 4
# H = 10
# W = 8
# img_count = 3
# scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
# bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
# dtype=torch.float32)
# bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
# im_info = torch.ones(img_count, 3, dtype=torch.float32)
# anchors = torch.ones(A, 4, dtype=torch.float32)
# inputs = (scores, bbox_deltas, im_info, anchors)
# self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
def test_c2_op(self):
class MyModel(torch.nn.Module):
def __init__(self):
super(MyModel, self).__init__()
def forward(self, scores, bbox_deltas, im_info, anchors):
a, b = torch.ops._caffe2.GenerateProposals(
(scores), (bbox_deltas), (im_info), (anchors),
2.0, 6000, 300, 0.7, 16, True, -90, 90, 1.0, True,
)
return a, b
model = MyModel()
A = 4
H = 10
W = 8
img_count = 3
scores = torch.ones(img_count, A, H, W, dtype=torch.float32)
bbox_deltas = torch.linspace(0, 10, steps=img_count * 4 * A * H * W,
dtype=torch.float32)
bbox_deltas = bbox_deltas.view(img_count, 4 * A, H, W)
im_info = torch.ones(img_count, 3, dtype=torch.float32)
anchors = torch.ones(A, 4, dtype=torch.float32)
inputs = (scores, bbox_deltas, im_info, anchors)
self.assertONNX(model, inputs, custom_opsets={"org.pytorch._caffe2": 0})
def test_dict(self):
class MyModel(torch.nn.Module):

View File

@@ -1,63 +0,0 @@
import torch
from torch.testing._internal.common_utils import TestCase, run_tests
import copy
class TestPerOverloadAPI(TestCase):
def test_basics_opoverloadpacket(self):
# add is only used as an example here. It is ok to update the test
# if the semantics of add are modified in the future.
add_packet = torch.ops.aten.add
# class attributes
self.assertEqual(add_packet.op_name, 'add')
self.assertEqual(add_packet.qualified_op_name, 'aten.add')
# callable
self.assertEqual(add_packet(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
# correct module
self.assertEqual(add_packet.__module__, add_packet.op.__module__)
# caching
another_add_packet = torch.ops.aten.add
self.assertEqual(id(add_packet), id(another_add_packet))
# deepcopy is a no-op
self.assertEqual(id(add_packet), id(copy.deepcopy(add_packet)))
# pretty print
self.assertEqual(str(add_packet), "OpOverloadPacket(op='aten.add')")
self.assertRaises(AttributeError, lambda: add_packet.foo)
def test_basics_opoverload(self):
add_packet = torch.ops.aten.add
add_tensoroverload = add_packet.Tensor
# class attributes
self.assertEqual(add_tensoroverload.name, 'aten.add')
self.assertEqual(add_tensoroverload.overload_name, 'Tensor')
self.assertEqual(add_tensoroverload.overload_packet, add_packet)
# deepcopy is a no-op
self.assertEqual(id(add_tensoroverload), id(copy.deepcopy(add_tensoroverload)))
# caching
another_add_tensoroverload = torch.ops.aten.add.Tensor
self.assertEqual(id(add_tensoroverload), id(another_add_tensoroverload))
# pretty print
self.assertEqual(str(add_tensoroverload), "OpOverload(op='aten.add', overload='Tensor')")
# callable
self.assertEqual(add_tensoroverload(torch.tensor(2), torch.tensor(3)), torch.tensor(5))
a = torch.tensor(2)
b = torch.tensor(0)
torch.ops.aten.add.out(a, a, out=b)
self.assertEqual(b, torch.tensor(4))
self.assertRaises(RuntimeError, lambda: add_tensoroverload(a, a, out=b))
if __name__ == '__main__':
run_tests()

View File

@@ -197,8 +197,6 @@ def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
def _jit_unflatten(vars: List[Tensor], desc: IODescriptor) -> Any: ...
def _jit_get_operation(op_name: str) -> Callable: ...
def _get_operation_overload(op_name: str, op_overload_name: str) -> Callable: ...
def _get_schema(op_name: str, overload_name: str) -> FunctionSchema: ...
def _jit_pass_optimize_for_mobile(module: 'torch.jit.ScriptModule',
optimization_blocklist: Set[MobileOptimizerType],
preserved_methods: List[AnyStr]) -> 'torch.jit.ScriptModule': ...

View File

@@ -25,109 +25,6 @@ def dl_open_guard():
if _SET_GLOBAL_FLAGS:
sys.setdlopenflags(old_flags)
# Each OpOverload object contains a pointer to a specific operator overload and a pointer to the parent `OpOverloadPacket` object.
# You can obtain an OpOverload object through attribute query on OpOverloadPacket.
class OpOverload:
def __init__(self, overloadpacket, op, schema):
self._op = op
self._schema = schema
self._overloadpacket = overloadpacket
# it's a no-op since OpOverload object is immutable and must be unique for a given op overload.
def __deepcopy__(self, memo=None):
return self
def __str__(self):
return "OpOverload(op='{}.{}', overload='{}')".format(*self._schema.name.split("::"), self.overload_name)
def __call__(self, *args, **kwargs):
return self._op(*args, **kwargs or {})
def __getattr__(self, key):
return getattr(self._op, key)
# `my_namespace::my_op`
@property
def name(self):
return "{}.{}".format(*self._schema.name.split("::"))
@property
def overload_name(self):
return self._schema.overload_name
@property
def overload_packet(self):
return self._overloadpacket
@property
def op(self):
return self._op
# TODO: add more methods to expose information about input and output arguments
# The OpOverloadPacket class contains a pointer to a base unresolved operator that doesn't correspond to a specific operator overload.
# You can obtain an OpOverload object through attribute query.
class OpOverloadPacket:
def __init__(self, qualified_op_name, op_name, op):
# These attributes are accessible on the object through the properties
# defined below but are immutable
self._qualified_op_name = qualified_op_name
self._op_name = op_name
self._op = op
# it's a no-op since OpOverloadPacket object is immutable and must be unique for a given op.
def __deepcopy__(self, memo=None):
return self
def __str__(self):
return "OpOverloadPacket(op='{}.{}')".format(*self._qualified_op_name.split("::"))
@property
def qualified_op_name(self):
return "{}.{}".format(*self._qualified_op_name.split("::"))
@property
def op_name(self):
return self._op_name
@property
def op(self):
return self._op
def __getattr__(self, key):
# It is not a valid op_name when __file__ is passed in
if key == '__file__':
return 'torch.ops'
try:
use_key = '' if key == 'default' else key
# TODO: disallow access to overloads registered by JIT
op_ = torch._C._get_operation_overload(self._qualified_op_name, use_key)
schema = torch._C._get_schema(self._qualified_op_name, use_key)
overload = OpOverload(self, op_, schema)
# cache the overload object
setattr(self, key, overload)
return overload
except RuntimeError:
try:
# This is added to maintain bc in case the user queries an attribute that exists on `self._op`
# which used to be returned before instead of the OpOverloadPacket
out = getattr(self._op, key)
return out
except AttributeError:
raise AttributeError("'{}' object has no attribute '{}'".format(str(self), key)) from None
def __call__(self, *args, **kwargs):
# overloading __call__ to ensure torch.ops.foo.bar() is still callable from JIT
# We save the function ptr as the `op` attribute on OpOverloadPacket to access it here.
return self._op(*args, **kwargs or {})
# Resolution of torch.fn is different from torch.ops.aten.fn
# torch.fn uses the Python argparser, matches with the appropriate schema, and calls into the unboxed version of the method
# torch.ops.aten.fn resolution is done via the mechanism defined in JIT. JIT creates a stack of all the overloads and
# then tries to match the correct one at runtime and always calls into the boxed version of the method
# Autograd codegen creates VariableType, TracerType, inplace or view type and python bindings
# Aten codegen generates tensor methods for the tensor class
# _OpNamespace is a subclass of ModuleType because the torch script
# allows attribute lookups on modules only. Since we want torch.ops.foo.bar()
@@ -162,20 +59,14 @@ class _OpNamespace(types.ModuleType):
return 'torch.ops'
# Get the op `my_namespace::my_op` if available. This will also check
# for overloads and raise an exception if there are more than one.
namespace_name = self.name
qualified_op_name = '{}::{}'.format(namespace_name, op_name)
qualified_op_name = '{}::{}'.format(self.name, op_name)
op = torch._C._jit_get_operation(qualified_op_name)
# let the script frontend know that op is identical to the builtin op
# with qualified_op_name
torch.jit._builtins._register_builtin(op, qualified_op_name)
op.__module__ = self.__module__ + "." + namespace_name
opoverloadpacket = OpOverloadPacket(qualified_op_name, op_name, op)
opoverloadpacket.__module__ = self.__module__ + "." + namespace_name
# cache the opoverloadpacket to ensure that each op corresponds to
# a unique OpOverloadPacket object
setattr(self, op_name, opoverloadpacket)
return opoverloadpacket
setattr(self, op_name, op)
op.__module__ = self.__module__ + "." + self.name
return op
class _Ops(types.ModuleType):
__file__ = '_ops.py'
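
The comments in this file contrast torch.fn with torch.ops.aten.fn resolution and describe the OpOverload / OpOverloadPacket attribute-query API that this backout removes. Below is a minimal sketch of that API, as exercised by the deleted test_per_overload_api.py above; it assumes a build that still contains D33469839, since after this backout torch.ops.aten.add is once again a plain builtin op with no .Tensor attribute.

    import torch

    # Attribute query on the torch.ops namespace returns an OpOverloadPacket...
    add_packet = torch.ops.aten.add
    # ...and attribute query on the packet returns a specific OpOverload.
    add_tensor = add_packet.Tensor

    # Both resolve through the JIT-registered schemas and call the boxed op,
    # unlike torch.add, which goes through the Python argument parser.
    print(add_packet(torch.tensor(2), torch.tensor(3)))  # tensor(5)
    print(add_tensor(torch.tensor(2), torch.tensor(3)))  # tensor(5)
    print(add_tensor.overload_name)                      # 'Tensor'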

View File

@@ -133,11 +133,11 @@ _INCLUDE_QCONFIG_PROPAGATE_LIST : Set[Callable] = {
# Default mapping from floating point function or torch ops to quantized ops
# TODO: merge with default static mapping
DEFAULT_FLOAT_TO_QUANTIZED_OPERATOR_MAPPINGS : Dict[Union[Callable, str], Callable] = {
F.elu: torch.ops.quantized.elu,
F.hardswish: torch.ops.quantized.hardswish,
F.instance_norm: torch.ops.quantized.instance_norm,
F.layer_norm: torch.ops.quantized.layer_norm,
F.leaky_relu: torch.ops.quantized.leaky_relu,
F.elu: torch._ops.ops.quantized.elu,
F.hardswish: torch._ops.ops.quantized.hardswish,
F.instance_norm: torch._ops.ops.quantized.instance_norm,
F.layer_norm: torch._ops.ops.quantized.layer_norm,
F.leaky_relu: torch._ops.ops.quantized.leaky_relu,
}
# mapping from module to output activation post process class

View File

@@ -1293,50 +1293,6 @@ void initJITBindings(PyObject* module) {
})
.def("has_storage", &DeserializationStorageContext::hasStorage);
m.def(
"_get_schema",
[](const std::string& op_name, const std::string& overload_name) {
try {
auto symbol = Symbol::fromQualString(op_name);
auto operations = getAllOperatorsFor(symbol);
for (const auto& op : operations) {
if (op->schema().overload_name() == overload_name) {
return op->schema();
}
}
throw std::runtime_error("Found no matching schema");
} catch (const c10::Error& e) {
auto msg = torch::get_cpp_stacktraces_enabled()
? e.what()
: e.what_without_backtrace();
throw std::runtime_error(msg);
}
});
m.def(
"_get_operation_overload",
[](const std::string& op_name, const std::string& overload_name) {
try {
auto symbol = Symbol::fromQualString(op_name);
auto operations = getAllOperatorsFor(symbol);
for (const auto& op : operations) {
if (op->schema().overload_name() == overload_name) {
auto func =
py::cpp_function([op](py::args args, py::kwargs kwargs) {
return invokeOperatorFromPython({op}, args, kwargs);
});
return func;
}
}
throw std::runtime_error("Found no matching operator overload");
} catch (const c10::Error& e) {
auto msg = torch::get_cpp_stacktraces_enabled()
? e.what()
: e.what_without_backtrace();
throw std::runtime_error(msg);
}
});
m.def(
"_jit_get_operation",
[](const std::string& op_name) {
@@ -1412,11 +1368,8 @@ void initJITBindings(PyObject* module) {
py::name(symbol.toUnqualString()),
py::doc(docstring.str().c_str()));
return func;
} catch (const c10::Error& e) {
auto msg = torch::get_cpp_stacktraces_enabled()
? e.what()
: e.what_without_backtrace();
throw std::runtime_error(msg);
} catch (const c10::Error& error) {
throw std::runtime_error(error.what_without_backtrace());
}
},
py::arg("qualified_name"));

View File

@@ -1104,13 +1104,6 @@ std::shared_ptr<SugaredValue> toSugaredValue(
}
}
auto opoverloadpacket_type =
py::module::import("torch").attr("_ops").attr("OpOverloadPacket");
py::bool_ is_overloadpacket = py::isinstance(obj, opoverloadpacket_type);
if (is_overloadpacket) {
obj = py::getattr(obj, "op");
}
bool isRpcAvailable = py::cast<bool>(
py::module::import("torch.distributed.rpc").attr("is_available")());

View File

@@ -7,7 +7,6 @@ import warnings
from typing import Any, Callable, Dict, List, Optional, Tuple, NamedTuple, cast, TYPE_CHECKING
from torch._jit_internal import boolean_dispatched
from ._compatibility import compatibility
from torch._ops import OpOverloadPacket
if TYPE_CHECKING:
from .node import Argument
@@ -138,9 +137,6 @@ def get_signature_for_torch_op(op : Callable, return_schemas : bool = False):
if override:
return (override, None) if return_schemas else None
if isinstance(op, OpOverloadPacket):
op = op._op
aten_fn = torch.jit._builtins._find_builtin(op)
if aten_fn is None:

View File

@@ -265,9 +265,6 @@ def infer_concrete_type_builder(nn_module, share_types=True):
# Don't re-add anything we already added
continue
isoverloadpacket = isinstance(value, torch._ops.OpOverloadPacket)
if isoverloadpacket:
value = value.op
# Handle Python function attributes
if inspect.isfunction(value):
try:

View File

@@ -1270,6 +1270,7 @@ def script(obj, optimize=None, _frames_up=0, _rcb=None,
return create_script_dict(obj)
if isinstance(obj, list):
return create_script_list(obj)
if inspect.isclass(obj):
qualified_name = _qualified_name(obj)
# If this type is a `nn.Module` subclass, they probably meant to pass

View File

@@ -23,7 +23,6 @@ if torch.distributed.rpc.is_available():
from .._jit_internal import RRef, is_rref
from torch._C import RRefType
from torch._ops import OpOverloadPacket
class Module(object):
def __init__(self, name, members):
@@ -63,9 +62,6 @@ class EvalEnv(object):
return getattr(builtins, name, None)
def get_signature(fn, rcb, loc, is_method):
if isinstance(fn, OpOverloadPacket):
signature = try_real_annotations(fn.op, loc)
else:
signature = try_real_annotations(fn, loc)
if signature is not None and is_method:
# If this is a method, then the signature will include a type for
@@ -110,9 +106,6 @@ def is_vararg(the_callable):
def get_param_names(fn, n_args):
if isinstance(fn, OpOverloadPacket):
fn = fn.op
if not is_function_or_method(fn) and hasattr(fn, '__call__') and is_function_or_method(fn.__call__): # noqa: B004
# De-sugar calls to classes
fn = fn.__call__