Add torch.ops.out_dtype (#103333)
https://docs.google.com/document/d/10DYFG2sU3TSvguFP5kYwYLlo45KHFg3BhBOkUk0NKsU/edit#bookmark=id.hgfzmhlzkamk

Renamed mixed_dtype --> out_dtype because "mixed_dtype is not very descriptive in the context of regular pytorch where we support type promotion on most ops"

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103333
Approved by: https://github.com/zou3519
This commit is contained in:
parent 1b78f23a1a
commit 133c5ec997
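For context, a minimal usage sketch of the new operator, modeled on the tests added in this commit (the int8 operands and the int32 accumulation dtype are illustrative, not the only supported combination):

import torch
from torch._higher_order_ops.out_dtype import out_dtype

# int8 operands, but ask for the matmul result in int32 so the accumulation
# does not overflow the int8 range.
x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
w = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
res = out_dtype(torch.ops.aten.mm.default, torch.int32, x, w)

# Reference computation: promote first, then matmul (mirrors test_out_dtype_mm_numerical).
ref = torch.ops.aten.mm.default(x.to(torch.int32), w.to(torch.int32))
assert torch.allclose(res, ref)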
171 test/test_out_dtype_op.py Normal file
@@ -0,0 +1,171 @@
# Owner(s): ["module: functorch"]
import unittest

import torch
import torch._dynamo
import torch._export
from torch._higher_order_ops.out_dtype import out_dtype
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing import FileCheck


@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestOutDtypeOp(TestCase):
    def test_out_dtype_make_fx(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)

        gm = make_fx(m)(x)
        self.assertTrue(torch.allclose(m(x), gm(x)))

        gm = make_fx(torch.func.functionalize(M(weight)))(x)
        self.assertTrue(torch.allclose(m(x), gm(x)))

        FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(gm.code)
        self.assertTrue(torch.allclose(m(x), gm(x)))
        for node in gm.graph.nodes:
            if node.op == "call_function" and node.target is out_dtype:
                # Result of this node should be int32
                self.assertEqual(node.meta["val"].dtype, torch.int32)
                # Argument of this node should be int8
                self.assertEqual(node.args[2].meta["val"].dtype, torch.int8)

    def test_out_dtype_op_functional(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        ep = torch._export.export(
            m,
            (x,),
            _add_runtime_assertions=False,
        )
        FileCheck().check("torch.ops.higher_order.out_dtype").check("aten.mm.default").run(ep.graph_module.code)
        self.assertTrue(torch.allclose(m(x), ep(x)))
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target is out_dtype:
                # Result of this node should be int32
                self.assertEqual(node.meta["val"].dtype, torch.int32)
                # Argument of this node should be int8
                self.assertEqual(node.args[2].meta["val"].dtype, torch.int8)

    def test_out_dtype_mm_numerical(self):
        class M(torch.nn.Module):
            def __init__(self, weight):
                super().__init__()
                self.weight = weight

            def forward(self, x):
                return out_dtype(
                    torch.ops.aten.mm.default, torch.int32, x, self.weight
                )

        weight = torch.randint(-128, 127, (5, 5), dtype=torch.int8)
        m = M(weight)
        x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)

        gm = make_fx(m)(x)

        x_casted = x.to(torch.int32)
        weight_casted = weight.to(torch.int32)
        numerical_res = torch.ops.aten.mm.default(x_casted, weight_casted)
        self.assertTrue(torch.allclose(numerical_res, gm(x)))

    def test_out_dtype_dynamo(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mul.Scalar, torch.int32, x, y
            )

        inp = (torch.randint(-128, 127, (5, 5), dtype=torch.int8), 3.0)

        compiled = torch.compile(f, backend="eager", fullgraph=True)
        self.assertTrue(torch.allclose(f(*inp), compiled(*inp)))

    def test_out_dtype_mul_scalar_numerical(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mul.Scalar, torch.int32, x, y
            )

        inp = (torch.randint(-128, 127, (5, 5), dtype=torch.int8), 3.0)

        gm = make_fx(f)(*inp)
        numerical_res = torch.ops.aten.mul.Scalar(inp[0].to(dtype=torch.int32), 3)
        self.assertTrue(torch.allclose(numerical_res, gm(*inp)))

    def test_out_dtype_non_functional(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.add_.Tensor, torch.int32, x, y
            )

        with self.assertRaisesRegex(ValueError, "out_dtype's first argument needs to be a functional operator"):
            _ = torch._export.export(
                f, (torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8)),
            )

    def test_out_dtype_non_op_overload(self):
        def f(x, y):
            return out_dtype(
                torch.add, torch.int32, x, y
            )

        with self.assertRaisesRegex(ValueError, "out_dtype's first argument must be an OpOverload"):
            f(torch.randint(-128, 127, (5, 5), dtype=torch.int8), torch.randint(-128, 127, (5, 5), dtype=torch.int8))

    def test_out_dtype_no_autograd(self):
        def f(x, y):
            return out_dtype(
                torch.ops.aten.mm.default, torch.int32, x, y
            )

        inp = (torch.randn(5, 5, requires_grad=True), torch.randn(5, 5, requires_grad=True))
        with self.assertRaisesRegex(RuntimeError, "Autograd is not supported for out_dtype"):
            f(*inp)

        with torch.no_grad():
            f(*inp)

    def test_out_dtype_wrong_output(self) -> None:
        def multiple_out(x):
            return out_dtype(
                torch.ops.aten.topk.default, torch.int32, x, 5
            )

        inp = (torch.randn(10),)

        with self.assertRaisesRegex(ValueError, "out_dtype's can only apply to ops that return a single tensor"):
            multiple_out(*inp)

        def singleton_list_out(x):
            return out_dtype(
                torch.ops.aten.split_copy.Tensor, torch.int32, x, 10
            )

        with self.assertRaisesRegex(ValueError, "out_dtype's can only apply to ops that return a single tensor"):
            singleton_list_out(*inp)

if __name__ == '__main__':
    run_tests()
@@ -235,6 +235,8 @@ class TorchHigherOrderOperatorVariable(VariableTracker):
            return MapHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "executorch_call_delegate":
            return ExecutorchCallDelegateHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ == "out_dtype":
            return OutDtypeHigherOrderVariable(value, source, **kwargs)
        elif value is torch._functorch.eager_transforms.grad_impl:
            return FunctorchGradHigherOrderVariable(value, source, **kwargs)
        elif value.__name__ in (
@@ -852,6 +854,38 @@ class WrapHigherOrderVariable(TorchHigherOrderOperatorVariable):
        )


class OutDtypeHigherOrderVariable(TorchHigherOrderOperatorVariable):
    def call_function(
        self, tx, args: "List[VariableTracker]", kwargs: "Dict[str, VariableTracker]"
    ) -> "VariableTracker":
        from .builder import wrap_fx_proxy

        if len(kwargs) != 0:
            unimplemented("out_dtype does not handle kwargs")

        p_args = tuple(arg.as_proxy() for arg in args)
        op = p_args[0]
        output_dtype = p_args[1]
        fake_sub_args = pytree.tree_map_only(
            torch.fx.Proxy, lambda a: get_fake_value(a.node, tx), p_args[2:]
        )
        # This is a simplified implementation of this operator just for tracing.
        # The actual implementation may also first promote the arguments.
        example_value = op(*fake_sub_args).to(dtype=output_dtype)

        # Store the invocation as a call_function node in the graph.
        return wrap_fx_proxy(
            tx=tx,
            proxy=tx.output.create_proxy(
                "call_function",
                self.value,
                args=tuple(p_args),
                kwargs={},
            ),
            example_value=example_value,
        )


class CheckpointHigherOrderVariable(WrapHigherOrderVariable):
    def call_function(
        self, tx, args: List[VariableTracker], kwargs: Dict[str, VariableTracker]
186 torch/_higher_order_ops/out_dtype.py Normal file
@@ -0,0 +1,186 @@
import torch
import torch.utils._pytree as pytree
from torch.fx.experimental.proxy_tensor import (
    disable_proxy_modes_tracing,
    ProxyTorchDispatchMode,
    track_tensor_tree,
)
from torch.utils._python_dispatch import (
    _get_current_dispatch_mode,
    _pop_mode_temporarily,
)
from torch._C import DispatchKey, _ExcludeDispatchKeyGuard, DispatchKeySet
from torch._functorch.eager_transforms import (
    _unwrap_all_tensors_from_functional,
    _wrap_all_tensors_to_functional,
)
from torch._ops import HigherOrderOperator
from torch._subclasses.fake_tensor import FakeTensorMode
from torch._prims_common import elementwise_dtypes, ELEMENTWISE_TYPE_PROMOTION_KIND


# TODO: figure out a more generic approach
ALLOWABLE_OPS = [
    torch.ops.aten.mm.default,
    torch.ops.aten.conv2d.default,
    torch.ops.aten.mul.Scalar,
]


class OutDtypeOperator(HigherOrderOperator):
    """
    The out_dtype operator takes an existing ATen functional operator, an
    `out_dtype` argument, and the arguments to the original operator; it
    executes the original operator and returns a Tensor with the `out_dtype`
    precision. This operator does not mandate a compute precision, so the
    representation does not have to be opinionated about the exact implementation.

    The general implementation for all operators will be the following:
        1. Promote input dtypes based on default PyTorch dtype promotion rules,
            using the dtypes of all input Tensors/Scalars and the `out_dtype`
            argument.
        2. Execute the operator
        3. Cast the output to `out_dtype`
    """

    def __init__(self):
        super().__init__("out_dtype")
        # TODO(ydwu4): Subclassing HigherOrderOperator causes __module__ to
        # become different (torch._higher_order_ops.out_dtype), which makes
        # torch.fx record the op incorrectly in the graph.
        self.__module__ = "torch.ops.higher_order"

    def __call__(self, op, output_dtype, *args):
        if not isinstance(op, torch._ops.OpOverload):
            raise ValueError("out_dtype's first argument must be an OpOverload")
        if op._schema.is_mutable:
            raise ValueError("out_dtype's first argument needs to be a functional operator")
        if not (
            len(op._schema.returns) == 1 and
            isinstance(op._schema.returns[0].type, torch.TensorType)
        ):
            raise ValueError(
                "out_dtype's can only apply to ops that return a single tensor. "
                f"Instead got {[r.type for r in op._schema.returns]}"
            )

        if op not in ALLOWABLE_OPS:
            raise ValueError(
                f"out_dtype only allows the following operators: {ALLOWABLE_OPS}."
            )

        res = super().__call__(op, output_dtype, *args)

        return res

out_dtype = OutDtypeOperator()
out_dtype.fallthrough(DispatchKey.PythonDispatcher)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.PythonTLSSnapshot)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.ADInplaceOrView)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.BackendSelect)  # type: ignore[attr-defined]
out_dtype.fallthrough(DispatchKey.AutocastCPU)  # type: ignore[attr-defined]


def trace_out_dtype(proxy_mode, func_overload, op, output_dtype, *args):
    with disable_proxy_modes_tracing():
        # This is a simplified implementation of this operator just for tracing.
        # The actual implementation may also first promote the arguments.
        out = op(*args).to(dtype=output_dtype)

    node_args = (op, output_dtype, *args)
    proxy_args = pytree.tree_map(proxy_mode.tracer.unwrap_proxy, node_args)
    out_proxy = proxy_mode.tracer.create_proxy(
        "call_function", func_overload, proxy_args, {}, name="out_dtype"
    )
    return track_tensor_tree(out, out_proxy, constant=None, tracer=proxy_mode.tracer)

@out_dtype.py_impl(DispatchKey.CompositeExplicitAutograd)
def out_dtype_dense(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    flat_inputs = pytree.tree_flatten(args)[0] + [torch.ones(1, dtype=output_dtype)]
    promote_dtype: torch.dtype = elementwise_dtypes(
        *flat_inputs,
        type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    )[0]

    casted_args = pytree.tree_map_only(
        torch.Tensor, lambda arg: arg.to(dtype=promote_dtype), args
    )
    res = op(*casted_args).to(dtype=output_dtype)
    return res

@out_dtype.py_impl(DispatchKey.Autograd)
def out_dtype_autograd(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    # TODO: maybe support autograd
    flat_operands, _ = pytree.tree_flatten(args)
    if torch.is_grad_enabled() and any(
        f.requires_grad for f in flat_operands if isinstance(f, torch.Tensor)
    ):
        raise RuntimeError("Autograd is not supported for out_dtype")

    with torch._C._AutoDispatchBelowAutograd():
        return out_dtype(op, output_dtype, *args)

@out_dtype.py_impl(ProxyTorchDispatchMode)
def out_dtype_proxy(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    mode = _get_current_dispatch_mode()
    assert mode is not None, "Mode should always be enabled for python fallback key"
    with _pop_mode_temporarily() as mode:
        if mode.enable_tracing:
            return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
        else:
            return out_dtype(op, output_dtype, *args)

@out_dtype.py_impl(FakeTensorMode)
def out_dtype_fake_tensor_mode(
    op: torch._ops.OpOverload,
    output_dtype: torch.dtype,
    *args
):
    return out_dtype_dense(op, output_dtype, *args)

@out_dtype.py_impl(DispatchKey.Functionalize)
def out_dtype_func1(op, output_dtype, *args):
    reapply_views = torch._C._functionalization_reapply_views_tls()
    # At this point we will see functionalized tensors, so we need to unwrap them first.
    unwrapped_args = tuple(
        _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
        for arg in args
    )
    # pyre-ignore
    with _ExcludeDispatchKeyGuard(DispatchKeySet(DispatchKey.Functionalize)):
        res = out_dtype(op, output_dtype, *unwrapped_args)
        return _wrap_all_tensors_to_functional(res, level=0)

@out_dtype.py_impl(torch._C._functorch.TransformType.Functionalize)
def out_dtype_func2(interpreter, op, output_dtype, *args):
    reapply_views = interpreter.functionalize_add_back_views()
    # At this point we will see functionalized tensors, so we need to unwrap them first.
    unwrapped_args = tuple(
        _unwrap_all_tensors_from_functional(arg, reapply_views=reapply_views)
        for arg in args
    )

    with interpreter.lower():
        res = out_dtype(op, output_dtype, *unwrapped_args)
        return _wrap_all_tensors_to_functional(res, level=interpreter.level())

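The OutDtypeOperator docstring above describes the general contract as promote, execute, cast. A small sketch of what that means for one of the allowed ops, torch.ops.aten.mul.Scalar, under the assumption that default type promotion picks float32 for an int8 tensor multiplied by a Python float (values and shapes here are illustrative only, not part of the commit):

import torch
from torch._higher_order_ops.out_dtype import out_dtype

x = torch.randint(-128, 127, (5, 5), dtype=torch.int8)

# out_dtype path: promote the inputs, run mul.Scalar, cast the result to int32.
res = out_dtype(torch.ops.aten.mul.Scalar, torch.int32, x, 3.0)

# Hand-written equivalent of steps 1-3 from the docstring (assumes the
# promoted computation dtype is float32 in this case).
promoted = x.to(torch.float32) * 3.0   # 1. promote inputs, 2. execute the op
ref = promoted.to(torch.int32)         # 3. cast the output to out_dtype
assert torch.allclose(res, ref)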
@@ -258,6 +258,9 @@ class TracerBase:
        elif isinstance(a, range):
            return range(self.create_arg(a.start), self.create_arg(a.stop), self.create_arg(a.step))

        elif isinstance(a, torch._ops.OpOverload):
            return a

        if isinstance(a, Proxy):
            # base case: we unwrap the Proxy object
            return a.node