Add torch.ops.out_dtype (#103333)

https://docs.google.com/document/d/10DYFG2sU3TSvguFP5kYwYLlo45KHFg3BhBOkUk0NKsU/edit#bookmark=id.hgfzmhlzkamk

Renamed mixed_dtype --> out_dtype because "mixed_dtype is not very descriptive in the context of regular pytorch where we support type promotion on most ops"
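For reference, a minimal usage sketch distilled from the tests added in this PR (the shapes and values below are illustrative only, not part of the change):

    import torch
    from torch._higher_order_ops.out_dtype import out_dtype

    # int8 matmul whose result is produced in int32 (mirrors test_out_dtype_mm_numerical)
    x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    w = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    res = out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)

    # Reference computation: promote the inputs to int32, then mm.
    assert torch.equal(res, torch.mm(x.to(torch.int32), w.to(torch.int32)))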

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103333
Approved by: https://github.com/zou3519
angelayi 2023-07-18 16:25:42 +00:00 committed by PyTorch MergeBot
parent 1b78f23a1a
commit 133c5ec997
4 changed files with 394 additions and 0 deletions

test/test_out_dtype_op.py Normal file

@@ -0,0 +1,171 @@
# Owner(s): ["module: functorch"]
import unittest

import torch
import torch._dynamo
import torch._export
from torch._higher_order_ops.out_dtype import out_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestOutDtypeOp(TestCase):
    def test_out_dtype_make_fx(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        gm = make_fx(m)(x)
        self.assertTrue(torch.allclose(m(x), gm(x)))

        gm = make_fx(torch.func.functionalize(M(weight)))(x)
        self.assertTrue(torch.allclose(m(x), gm(x)))

        FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(gm.code)
        self.assertTrue(torch.allclose(m(x), gm(x)))

        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is out_dtype:
                # Result of this node should be int32
                self.assertEqual(node.meta["val"].dtype, torch.int32)
                # Argument of this node should be int8
                self.assertEqual(node.args[2].meta["val"].dtype, torch.int8)

    def test_out_dtype_op_functional(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        ep = torch._export.export(
            m,
            (x,),
            _add_runtime_assertions=False,
        )
        FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
        self.assertTrue(torch.allclose(m(x), ep(x)))

        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target is out_dtype:
                # Result of this node should be int32
                self.assertEqual(node.meta["val"].dtype, torch.int32)
                # Argument of this node should be int8
                self.assertEqual(node.args[2].meta["val"].dtype, torch.int8)

    def test_out_dtype_mm_numerical(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        gm = make_fx(m)(x)

        x_casted = x.to(torch.int32)
        weight_casted = weight.to(torch.int32)
        numerical_res = torch.ops.aten.mm.default(x_casted, weight_casted)
        self.assertTrue(torch.allclose(numerical_res, gm(x)))

    def test_out_dtype_dynamo(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mul.Scalar, torch.int32, x, y
            )

        inp = (torch.randint(-128, 127, (5, 5), dtype=torch.int8), 3.0)

        compiled = torch.compile(f, backend="eager", fullgraph=True)
        self.assertTrue(torch.allclose(f(*inp), compiled(*inp)))

    def test_out_dtype_mul_scalar_numerical(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mul.Scalar, torch.int32, x, y
            )

        inp = (torch.randint(-128, 127, (5, 5), dtype=torch.int8), 3.0)

        gm = make_fx(f)(*inp)
        numerical_res = torch.ops.aten.mul.Scalar(inp[0].to(dtype=torch.int32), 3)
        self.assertTrue(torch.allclose(numerical_res, gm(*inp)))

    def test_out_dtype_non_functional(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.add_.Tensor, torch.int32, x, y
            )

        with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
            _ = torch._export.export(
                f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
            )

    def test_out_dtype_non_op_overload(self):
        def f(x, y):
            return out_dtype(
                torch.add, torch.int32, x, y
            )

        with self.assertRaisesRegex(ValueError, "out_dtype's first argument must be an OpOverload"):
            f(torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8))

    def test_out_dtype_no_autograd(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mm.default, torch.int32, x, y
            )

        inp = (torch.randn(5, 5, requires_grad=True), torch.randn(5, 5, requires_grad=True))
        with self.assertRaisesRegex(RuntimeError, "Autograd is not supported for out_dtype"):
            f(*inp)

        with torch.no_grad():
            f(*inp)

    def test_out_dtype_wrong_output(self) -> None:
        def multiple_out(x):
            return out_dtype(
                torch.ops.aten.topk.default, torch.int32, x, 5
            )

        inp = (torch.randn(10),)
        with self.assertRaisesRegex(ValueError, "out_dtype can only apply to ops that return a single tensor"):
            multiple_out(*inp)

        def singleton_list_out(x):
            return out_dtype(
                torch.ops.aten.split_copy.Tensor, torch.int32, x, 10
            )

        with self.assertRaisesRegex(ValueError, "out_dtype can only apply to ops that return a single tensor"):
            singleton_list_out(*inp)


if __name__ == '__main__':
    run_tests()


@@ -235,6 +235,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            return MapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "executorch_call_delegate":
            return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "out_dtype":
            return OutDtypeHigherOrderVariable(value, source, **kwargs)
        elif value is torch._functorch.eager_transforms.grad_impl:
            return FunctorchGradHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in (

@@ -852,6 +854,38 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        if len(kwargs) != 0:
            unimplemented("out_dtype does not handle kwargs")

        p_args = tuple(arg.as_proxy() for arg in args)
        op = p_args[0]
        output_dtype = p_args[1]
        fake_sub_args = pytree.tree_map_only(
            torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args[2:]
        )
        # This is a simplified implementation of this operator just for tracing.
        # Actual implementation may also first promote the arguments.
        example_value = op(*fake_sub_args).to(dtype=output_dtype)

        # Store the invocation as a call in the graph.
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
            ),
            example_value=example_value,
        )


class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
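As an aside (not part of the diff): the handler above is what lets torch.compile capture out_dtype without graph breaks. A minimal sketch mirroring test_out_dtype_dynamo:

    import torch
    from torch._higher_order_ops.out_dtype import out_dtype

    def f(x, y):
        return out_dtype(torch.ops.aten.mul.Scalar, torch.int32, x, y)

    compiled = torch.compile(f, backend="eager", fullgraph=True)
    x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    # Eager and compiled results agree.
    assert torch.allclose(f(x, 3.0), compiled(x, 3.0))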


@@ -0,0 +1,186 @@
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch._C import DispatchKey, _ExcludeDispatchKeyGuard, DispatchKeySet
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND


# TODO: figure out a more generic approach
ALLOWABLE_OPS = [
    torch.ops.aten.mm.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.mul.Scalar,
]


class OutDtypeOperator(HigherOrderOperator):
    """
    The out_dtype operator takes an existing ATen functional operator, an
    `out_dtype` argument, and arguments to the original operator, and executes
    the original operator and returns a Tensor with the `out_dtype` precision.
    This operator does not mandate a compute precision, so the representation
    does not have to be opinionated about the exact implementation.

    The general implementation for all operators will be the following:
        1. Promote input dtypes based on default PyTorch dtype promotion rules,
           using the dtypes of all input Tensors/Scalars and the `out_dtype`
           argument.
        2. Execute the operator.
        3. Cast the output to `out_dtype`.
    """

    def __init__(self):
        super().__init__("out_dtype")
        # TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
        # become different (torch._higher_order_ops.out_dtype), which would
        # make torch.fx record the op incorrectly in the graph.
        self.__module__ = "torch.ops.higher_order"

    def __call__(self, op, output_dtype, *args):
        if not isinstance(op, torch._ops.OpOverload):
            raise ValueError("out_dtype's first argument must be an OpOverload")
        if op._schema.is_mutable:
            raise ValueError("out_dtype's first argument needs to be a functional operator")
        if not (
            len(op._schema.returns) == 1 and
            isinstance(op._schema.returns[0].type, torch.TensorType)
        ):
            raise ValueError(
                "out_dtype can only apply to ops that return a single tensor. "
                f"Instead got {[r.type for r in op._schema.returns]}"
            )

        if op not in ALLOWABLE_OPS:
            raise ValueError(
                f"out_dtype only allows the following operators: {ALLOWABLE_OPS}."
            )

        res = super().__call__(op, output_dtype, *args)
        return res


out_dtype = OutDtypeOperator()

out_dtype.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.ADInplaceOrView)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.BackendSelect)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]


def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
    with disable_proxy_modes_tracing():
        # This is a simplified implementation of this operator just for tracing.
        # Actual implementation may also first promote the arguments.
        out = op(*args).to(dtype=output_dtype)

    node_args = (op, output_dtype, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="out_dtype"
    )
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)


@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
def out_dtype_dense(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    flat_inputs = pytree.tree_flatten(args)[0] + [torch.ones(1, dtype=output_dtype)]
    promote_dtype: torch.dtype = elementwise_dtypes(
        *flat_inputs,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )[0]

    casted_args = pytree.tree_map_only(
        torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args
    )
    res = op(*casted_args).to(dtype=output_dtype)
    return res


@out_dtype.py_impl(DispatchKey.Autograd)
def out_dtype_autograd(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    # TODO: maybe support autograd
    flat_operands, _ = pytree.tree_flatten(args)
    if torch.is_grad_enabled() and any(
        f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
    ):
        raise RuntimeError("Autograd is not supported for out_dtype")

    with torch._C._AutoDispatchBelowAutograd():
        return out_dtype(op, output_dtype, *args)


@out_dtype.py_impl(ProxyTorchDispatchMode)
def out_dtype_proxy(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    mode = _get_current_dispatch_mode()
    assert (mode is not None), "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
        else:
            return out_dtype(op, output_dtype, *args)


@out_dtype.py_impl(FakeTensorMode)
def out_dtype_fake_tensor_mode(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    return out_dtype_dense(op, output_dtype, *args)


@out_dtype.py_impl(DispatchKey.Functionalize)
def out_dtype_func1(op, output_dtype, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_args = tuple(
        _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
        for arg in args
    )

    # pyre-ignore
    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        res = out_dtype(op, output_dtype, *unwrapped_args)
    return _wrap_all_tensors_to_functional(res, level=0)


@out_dtype.py_impl(torch._C._functorch.TransformType.Functionalize)
def out_dtype_func2(interpreter, op, output_dtype, *args):
    reapply_views = interpreter.functionalize_add_back_views()
    # At this point, we will see functionalized tensors, so need to unwrap them first
    unwrapped_args = tuple(
        _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
        for arg in args
    )

    with interpreter.lower():
        res = out_dtype(op, output_dtype, *unwrapped_args)
    return _wrap_all_tensors_to_functional(res, level=interpreter.level())
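As a reader's aid (not part of the diff): the promote → execute → cast contract described in the OutDtypeOperator docstring, written out by hand for the int8 matmul case that out_dtype_dense handles. The tensor values are illustrative only.

    import torch
    from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND

    x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    w = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    out_dt = torch.int32

    # 1. Promote: the inputs plus a dummy tensor of the requested out_dtype all
    #    participate in default elementwise type promotion.
    promote_dtype = elementwise_dtypes(
        x, w, torch.ones(1, dtype=out_dt),
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )[0]  # torch.int32 in this case

    # 2. Execute the original op on the promoted inputs.
    res = torch.ops.aten.mm.default(x.to(promote_dtype), w.to(promote_dtype))

    # 3. Cast the output to out_dtype (a no-op here, since it is already int32).
    res = res.to(out_dt)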


@@ -258,6 +258,9 @@ class TracerBase:
        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))
        elif isinstance(a, torch._ops.OpOverload):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node
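For context (not part of the diff): this create_arg case is what allows an OpOverload to be stored directly as an argument of the traced node when out_dtype is captured, e.g. (mirroring test_out_dtype_make_fx above):

    import torch
    from torch.fx.experimental.proxy_tensor import make_fx
    from torch._higher_order_ops.out_dtype import out_dtype

    def f(x, w):
        return out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)

    x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    w = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
    gm = make_fx(f)(x, w)
    # The generated code contains a call to torch.ops.higher_order.out_dtype
    # whose first argument is the aten.mm.default OpOverload itself.
    print(gm.code)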