"""
|
|
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
|
|
with test_sym_bool)
|
|
"""
|
|
|
|
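# (A sketch of why this matters: pytest's assertion rewriting replaces plain
# ``assert`` statements with instrumented code that introspects the asserted
# expression, which can force symbolic booleans -- e.g. ``assert x.size(0) in y``
# in test_sym_bool below -- to concretize during tracing.)
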
# Owner(s): ["oncall: export"]
import copy
import io
import math
import tempfile
import unittest
import zipfile
from collections import namedtuple
from pathlib import Path
from typing import NamedTuple

from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU
from torch.testing._internal.triton_utils import requires_gpu


if HAS_GPU:
    import triton
    import triton.language as tl

    from torch.library import wrap_triton
    from torch.utils._triton import has_triton

import torch
import torch._dynamo as torchdynamo
import torch._export.serde.schema as schema
import torch.export._trace
import torch.utils._pytree as pytree
from torch._export.db.case import ExportCase, SupportLevel
from torch._export.db.examples import all_examples
from torch._export.serde.schema import ArgumentKind
from torch._export.serde.serialize import (
    _dict_to_dataclass,
    _to_json_bytes,
    canonicalize,
    deserialize,
    ExportedProgramDeserializer,
    ExportedProgramSerializer,
    GraphModuleSerializer,
    serialize,
    SerializeError,
)
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch._subclasses.fake_tensor import FakeTensor, FakeTensorMode
from torch.export import Dim, export, load, save, unflatten
from torch.export.pt2_archive.constants import ARCHIVE_VERSION_PATH
from torch.fx.experimental.symbolic_shapes import is_concrete_int, ValueRanges
from torch.testing._internal.common_utils import (
    instantiate_parametrized_tests,
    IS_FBCODE,
    IS_MACOS,
    IS_WINDOWS,
    parametrize,
    run_tests,
    TemporaryFileName,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


def get_filtered_export_db_tests():
    return [
        (name, case)
        for name, case in all_examples().items()
        if case.support_level == SupportLevel.SUPPORTED
    ]


@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerialize(TestCase):
    def test_export_with_extension_op_serialization(self):
        class TestModule(torch.nn.Module):
            def forward(self, x):
                return x + x

        class FooExtensionOp:
            def __hash__(self):
                return 0

            def __eq__(self, other):
                return type(other) is type(self)

            def __call__(self, *args, **kwargs):
                return torch.ops.aten.add.Tensor(*args, **kwargs)

            @property
            def __name__(self):
                return "foo.my_op"

        class ExtensionVerifier(torch._export.verifier.Verifier):
            dialect = "FOO"

            def allowed_op_types(self):
                return super().allowed_op_types() + (FooExtensionOp,)

        class FooExtensionHandler(torch._export.serde.serialize.ExtensionHandler):
            @classmethod
            def namespace(cls):
                return "foo"

            @classmethod
            def to_op_name(cls, op):
                return "my_op"

            @classmethod
            def from_op_name(cls, name: str):
                self.assertEqual(name, "my_op")
                return FooExtensionOp()

            @classmethod
            def op_schema(cls, op):
                return torch.ops.aten.add.Tensor._schema

        inp = (torch.ones(10),)
        ep = export(TestModule(), inp, strict=True)

        # Register the custom op handler.
        foo_custom_op = FooExtensionOp()
        torch._export.serde.serialize.register_extension(
            FooExtensionOp, FooExtensionHandler
        )

        new_gm = copy.deepcopy(ep.graph_module)
        # Inject the custom operator.
        for node in new_gm.graph.nodes:
            if node.name == "add":
                node.target = foo_custom_op

        new_ep = ep._update(new_gm, ep.graph_signature, verifiers=[ExtensionVerifier])
        serialized = serialize(new_ep)
        deserialized = deserialize(serialized)
        self.assertEqual(
            len(
                deserialized.graph.find_nodes(op="call_function", target=foo_custom_op)
            ),
            1,
        )

    def test_predispatch_export_with_autograd_op(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                with torch.enable_grad():
                    return x + x

        inp = (torch.ones(10),)
        with torch.no_grad():
            from torch.export._trace import _export

            ep = _export(Foo(), inp, pre_dispatch=True)

        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)

        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)
        self.assertEqual(exp_out.requires_grad, actual_out.requires_grad)

    def test_export_example_inputs_preserved(self):
        class MyModule(torch.nn.Module):
"""A test module with that has multiple args and uses kwargs"""
|
|
|
|
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(2, 3))

            def forward(self, x, y, use_p=False):
                out = x + y
                if use_p:
                    out += self.p
                return out

        model = MyModule().eval()
        random_inputs = (torch.rand([2, 3]), torch.rand([2, 3]))
        exp_program = export(model, random_inputs, {"use_p": True}, strict=True)

        output_buffer = io.BytesIO()
        # Tests that example inputs are preserved when saving and loading module.
        torch.export.save(exp_program, output_buffer)
        loaded_model = torch.export.load(output_buffer)
        # Extract the example inputs from before and after saving.
        orig_args, orig_kwargs = exp_program.example_inputs
        loaded_args, loaded_kwargs = loaded_model.example_inputs
        # Run both modules and confirm that outputs match.
        orig_out = exp_program.module()(*orig_args, **orig_kwargs)
        loaded_out = loaded_model.module()(*loaded_args, **loaded_kwargs)
        self.assertEqual(orig_out, loaded_out)

    def test_metadata_run_decomp_serder(self):
        class M(torch.nn.Module):
            def forward(self, x):
                return x.sin()

        exp_program = export(M(), (torch.randn(4, 4),), strict=True)

        output_buffer = io.BytesIO()
        # Tests that example forward arg names are preserved when saving and loading module.
        torch.export.save(exp_program, output_buffer)
        loaded_model = torch.export.load(output_buffer)

        ep = loaded_model.run_decompositions({})
        # We should preserve the original argument name
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, x):
    sin = torch.ops.aten.sin.default(x); x = None
    return (sin,)""",
        )

    def test_metadata_parsing_with_layer_split(self):
        # Tests that modules with more complicated layer patterns can be serialized
        # and deserialized correctly.
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                )

            def forward(self, x):
                # Splitting layers of a sequential stack introduces commas and parens
                # into metadata trace.
                out_start, out_rest = self.layers[0], self.layers[1:]
                h = out_start(x)
                h = out_rest(h)
                return h

        inp = (torch.ones(10),)
        # Module will only be able to roundtrip if metadata
        # can be correctly parsed.
        ep = export(MyModule(), inp, strict=True)
        buffer = io.BytesIO()
        save(ep, buffer)
        loaded_ep = load(buffer)

        # Check that both modules run to confirm load was successful.
        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)

    def test_nested_layer_split(self):
        class Bar(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.layers = torch.nn.Sequential(
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                    torch.nn.SiLU(),
                )

            def forward(self, x):
                out_start, out_rest = self.layers[0], self.layers[1:]
                h = out_start(x)
                h = out_rest(h) + 2
                return h

        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.register_module("a[(1)]", Bar())
                self.register_module("b[(2)]", Bar())
                self.register_buffer("c:[22]", torch.randn(1))

            def forward(self, x):
                out_a, out_b = getattr(self, "a[(1)]"), getattr(self, "b[(2)]")
                out_c = getattr(self, "c:[22]")
                h = out_a(x)
                h = out_b(h)
                return h + out_c

        inp = (torch.ones(10),)
        ep = export(Foo(), inp, strict=True)
        buffer = io.BytesIO()
        save(ep, buffer)
        loaded_ep = load(buffer)

        # Check that both modules run to confirm load was successful.
        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)

    def test_serialize_param_mutation(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.parameter = torch.nn.Parameter(torch.ones(4, 4))

            def forward(self, x):
                with torch.no_grad():
                    self.parameter.div_(2)
                return x + self.parameter

        foo = Foo()
        ep = torch.export.export(foo, (torch.rand(4, 4),)).run_decompositions()
        buffer = io.BytesIO()
        save(ep, buffer)
        loaded_ep = load(buffer)
        val = loaded_ep.graph_signature.parameters_to_mutate
        self.assertEqual({"div": "parameter"}, val)

    def test_serialize_constant_outputs(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                # Along with the tensor output, return None
                # and a constant. Although these outputs aren't
                # very useful, they do show up in graphs.
                return x + 1, None, 1024

        # Check that module can be roundtripped, thereby confirming proper deserialization.
        inp = (torch.ones(10),)
        ep = export(MyModule(), inp, strict=True)
        buffer = io.BytesIO()
        save(ep, buffer)
        loaded_ep = load(buffer)

        exp_out = ep.module()(*inp)
        actual_out = loaded_ep.module()(*inp)
        self.assertEqual(exp_out, actual_out)

    def test_serialize_multiple_returns_from_node(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        exported_module = export(
            MyModule(),
            (
                torch.ones([512, 512], requires_grad=True),
                torch.ones([512]),
                torch.ones([512]),
            ),
            strict=True,
        ).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.native_layer_norm.default")
        # aten::native_layer_norm returns 3 tensors
        self.assertEqual(len(node.outputs), 3)

        # check the names are unique
        seen = set()
        for output in node.outputs:
            name = output.as_tensor.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_serialize_sym_int(self) -> None:
        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dim1_bc = torch.export.Dim("dim1_b")
        dynamic_shapes = {
            "a": {0: dim0_ac},
            "b": {1: dim1_bc},
            "c": {0: dim0_ac, 1: dim1_bc},
        }
        exported_module = export(
            DynamicShapeSimpleModel(),
            inputs,
            dynamic_shapes=dynamic_shapes,
            strict=True,
        ).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)
        sym_size_nodes = [
            node
            for node in serialized.exported_program.graph_module.graph.nodes
            if node.target == "torch.ops.aten.sym_size.int"
        ]
        for node in sym_size_nodes:
            self.assertEqual(node.inputs[0].name, "self")
            self.assertEqual(node.inputs[1].name, "dim")

    def test_serialize_sym_float(self) -> None:
        # TODO(rec): This doesn't seem to test anything!

        class DynamicFloatSimpleModel(torch.nn.Module):
            def __init__(self, multiplier: torch.SymFloat):
                super().__init__()
                self.multiplier = multiplier

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                e = d * self.multiplier
                e_s0 = e.shape[0]
                e_s1 = e.shape[1]
                e_s3 = e_s0 * e_s1
                f = e.view(e_s3)
                return torch.cat([f, f])

        multiplier_sym = torch.SymFloat("multiplier_sym")
        _model = DynamicFloatSimpleModel(multiplier_sym)
        _inputs = (
            torch.randn(2, 4),
            torch.randn(4, 7),
            torch.randn(2, 7),
        )
        _dim0_ac = Dim("dim0_ac")
        _dim1_bc = Dim("dim1_b")

    def test_serialize_infinite_sym_int(self) -> None:
        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dim1_bc = torch.export.Dim("dim1_b")
        dynamic_shapes = {
            "a": {0: dim0_ac},
            "b": {1: dim1_bc},
            "c": {0: dim0_ac, 1: dim1_bc},
        }
        exported_module = export(
            DynamicShapeSimpleModel(),
            inputs,
            dynamic_shapes=dynamic_shapes,
            strict=True,
        ).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)
        for v in serialized.exported_program.range_constraints.values():
            self.assertEqual(v.max_val, None)

    def test_symint_list(self):
        # This reflects the behavior from inductor's ExternFallbackNode
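        # (Sketch of the scenario: the extern fallback serializes its kernel
        # arguments directly, so a list that mixes concrete ints with a SymInt,
        # like [1, symint, 3] below, must serialize as "as_sym_ints" rather
        # than plain "as_ints".)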
        shape_env = torch.fx.experimental.symbolic_shapes.ShapeEnv()
        symint = shape_env.create_unbacked_symint()
        serializer = GraphModuleSerializer(None, None)  # type: ignore[arg-type]
        res = serializer.serialize_inputs(
            torch.ops.aten.ones.default, ([1, symint, 3],), {}
        )
        self.assertEqual(len(res), 1)
        self.assertEqual(res[0].arg._type, "as_sym_ints")

    def test_serialize_list_returns(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.split(x, 2)

        input = torch.arange(10.0).reshape(5, 2)
        exported_module = export(MyModule(), (input,), strict=True).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        # split.Tensor gets decomposed to split_with_sizes by the core ATen decomposition table
        self.assertEqual(node.target, "torch.ops.aten.split_with_sizes.default")
        self.assertEqual(len(node.outputs), 1)
        # Input looks like:
        # tensor([[0, 1],
        #         [2, 3],
        #         [4, 5],
        #         [6, 7],
        #         [8, 9]])
        # Output looks like:
        # (tensor([[0, 1],
        #          [2, 3]]),
        #  tensor([[4, 5],
        #          [6, 7]]),
        #  tensor([[8, 9]]))
        self.assertEqual(len(node.outputs[0].as_tensors), 3)

        # check the names are unique
        seen = set()
        for output in node.outputs[0].as_tensors:
            name = output.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_nonfinite_inputs(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                x = torch.ops.aten.add.Scalar(x, math.inf)
                x = torch.ops.aten.add.Scalar(x, -math.inf)
                return torch.ops.aten.add.Scalar(x, math.nan)

        fn = Module()
        ep = torch.export.export(
            fn,
            (torch.randn(3, 2),),
        )
        json_bytes = _to_json_bytes(
            ExportedProgramSerializer().serialize(ep).exported_program
        )
        import json

        def parse_constant(x):
            raise RuntimeError(f"Invalid JSON float: {x}")

        json.loads(json_bytes, parse_constant=parse_constant)

    def test_multi_return_some_unused(self) -> None:
        """
        Make sure the serialized output matches the op schema, even if some of
        the arguments are never used in the graph.
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                return torch.ops.aten.var_mean.correction(x, [1])[0]

        exported_module = export(
            MyModule(), (torch.ones([512, 512], requires_grad=True),), strict=True
        ).run_decompositions()

        serialized = ExportedProgramSerializer().serialize(exported_module)
        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.var_mean.correction")
        self.assertEqual(len(node.outputs), 2)

        # check the names are unique
        seen = set()
        for output in node.outputs:
            name = output.as_tensor.name
            self.assertNotIn(name, seen)
            seen.add(name)

    def test_rational_ranges(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                return x + x

        ep = export(
            M(), (torch.randn(4),), dynamic_shapes=({0: Dim("temp")},), strict=True
        )

        range_constraints = list(ep.range_constraints.keys())
        assert len(range_constraints) == 1
        symint = range_constraints[0]

        import sympy

        upper_range = sympy.Rational(10, 3)
        lower_range = sympy.Rational(10, 6)
        ep.range_constraints[symint] = ValueRanges(lower=lower_range, upper=upper_range)

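        # The symbol is integer-valued, so serialization should round the
        # rational bounds inward: lower 10/6 (~1.67) rounds up to 2 and upper
        # 10/3 (~3.33) rounds down to 3, which the assertions below encode.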
        serialized = ExportedProgramSerializer().serialize(ep)
        self.assertEqual(
            serialized.exported_program.range_constraints[symint.name].min_val, 2
        )
        self.assertEqual(
            serialized.exported_program.range_constraints[symint.name].max_val, 3
        )

    @unittest.skipIf(
        not torch.cuda.is_available() or not has_triton(), "requires cuda and triton"
    )
    def test_triton_hop(self) -> None:
        @triton.jit
        def add_kernel(
            in_ptr0,
            in_ptr1,
            out_ptr,
            n_elements,
            fval,
            ival,
            BLOCK_SIZE: "tl.constexpr",
        ):
            pid = tl.program_id(axis=0)
            block_start = pid * BLOCK_SIZE
            offsets = block_start + tl.arange(0, BLOCK_SIZE)
            mask = offsets < n_elements
            x = tl.load(in_ptr0 + offsets, mask=mask)
            y = tl.load(in_ptr1 + offsets, mask=mask)
            output = x + y + fval + ival
            tl.store(out_ptr + offsets, output, mask=mask)

        def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            wrap_triton(add_kernel)[grid](
                x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16
            )

            return output

        class MyModel(torch.nn.Module):
            def forward(self, x, y):
                return custom_add(x, y)

        def custom_add_autotune(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
            output = torch.empty_like(x)
            n_elements = output.numel()

            def grid(meta):
                return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

            wrap_triton(add_kernel)[grid](
                x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8
            )

            return output

        class MyModelAutotune(torch.nn.Module):
            def forward(self, x, y):
                return custom_add_autotune(x, y)

        device = "cuda"

        for m in [MyModel().to(device), MyModelAutotune().to(device)]:
            args = (torch.randn(3, device=device), torch.randn(3, device=device))
            ep = torch.export.export(m, args=args)
            ep = ep.run_decompositions(decompose_custom_triton_ops=False)
            assert torch.allclose(m(*args), ep.module()(*args))

            serialized = ExportedProgramSerializer().serialize(ep)

            for node in serialized.exported_program.graph_module.graph.nodes:
                if (
                    node.target
                    == "torch.ops.higher_order.triton_kernel_wrapper_functional"
                ):
                    triton_node = node

            self.assertIsNotNone(triton_node)

            args = []
            kwargs = {}

            for arg in triton_node.inputs:
                if arg.kind == ArgumentKind.POSITIONAL:
                    args.append(arg.arg)
                elif arg.kind == ArgumentKind.KEYWORD:
                    kwargs[arg.name] = arg.arg

            self.assertEqual(len(args), 6)
            # Always present: name, grid, output_indices, num_warps.
            # Triton version dependent: num_cpu_threads, shared_memory_bytes.
            self.assertTrue(len(kwargs) >= 4)

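            # (How the expected values below fall out of custom_add: the six
            # positional args are (in_ptr0, in_ptr1, out_ptr, n_elements=3,
            # fval=3.14, ival=42); grid = cdiv(3, BLOCK_SIZE=16) = 1, padded
            # to [1, 1, 1]; and `output` is positional index 2, hence
            # output_indices == [2].)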
            for i in range(3):
                self.assertIsNotNone(args[i].as_tensor)

            self.assertEqual(args[3].as_int, 3)
            self.assertAlmostEqual(args[4].as_float, 3.14, places=2)
            self.assertEqual(args[5].as_int, 42)
            kernel_name = kwargs["name"].as_string
            symbol_name = kernel_name.rpartition("_")[0]
            self.assertEqual(symbol_name, "add_kernel")
            self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1])
            self.assertEqual(kwargs["output_indices"].as_ints, [2])
            self.assertEqual(
                kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4
            )

            if "num_cpu_threads" in kwargs:
                self.assertEqual(kwargs["num_cpu_threads"].as_int, 0)
            if "shared_memory_bytes" in kwargs:
                self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0)

            self.assertEqual(len(triton_node.outputs), 1)
            self.assertIsNotNone(triton_node.outputs[0].as_tensors)
            self.assertEqual(
                len(triton_node.outputs[0].as_tensors),
                len(kwargs["output_indices"].as_ints),
            )
            self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")

            with self.assertRaisesRegex(
                SerializeError,
                "deserialize nyi for torch._higher_order_ops.triton_kernel_wrap.triton_kernel_wrapper_functional",
            ):
                ExportedProgramDeserializer().deserialize(
                    serialized.exported_program,
                    serialized.state_dict,
                    serialized.constants,
                    serialized.example_inputs,
                )

    def test_kwargs_default(self) -> None:
        """
        Tests that the kwargs default values are serialized even if they are not
        specified.
        """

        class Foo(torch.nn.Module):
            def forward(self, x: torch.Tensor) -> torch.Tensor:
                values = torch.randn(3, 2)
                return torch.searchsorted(x, values, side="right", right=True)

        f = Foo()

        x, _ = torch.sort(torch.randn(3, 4))
        exported_module = export(f, (x,), strict=True).run_decompositions()
        serialized = ExportedProgramSerializer().serialize(exported_module)

        node = serialized.exported_program.graph_module.graph.nodes[-1]
        self.assertEqual(node.target, "torch.ops.aten.searchsorted.Tensor")
        self.assertEqual(len(node.inputs), 4)
        self.assertEqual(node.inputs[2].name, "right")
        self.assertEqual(node.inputs[2].arg.as_bool, True)
        self.assertEqual(node.inputs[3].name, "side")
        self.assertEqual(node.inputs[3].arg.as_string, "right")

    def test_canonicalize(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
                a = y + x
                b = x + y
                return b + a

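        # canonicalize() is expected to renumber values deterministically, so
        # the tensor input name of the first node should sort before that of
        # the second; the assertLess below relies on that ordering.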
        ep = export(Module(), (torch.randn(3, 2), torch.randn(3, 2)), strict=True)
        s = ExportedProgramSerializer().serialize(ep)
        c = canonicalize(s.exported_program)
        g = c.graph_module.graph
        self.assertLess(
            g.nodes[0].inputs[0].arg.as_tensor.name,
            g.nodes[1].inputs[0].arg.as_tensor.name,
        )

    def test_int_list(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                return torch.ops.aten.sum.dim_IntList(x, [])

        ep = torch.export.export(M(), (torch.randn(3, 2),), strict=True)
        serialized = ExportedProgramSerializer().serialize(ep)
        for node in serialized.exported_program.graph_module.graph.nodes:
            if "aten.sum.dim_IntList" in node.target:
                self.assertEqual(node.inputs[1].arg.type, "as_ints")

    def test_empty_constant(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x)

        m = M()
        sample_inputs = (torch.randn(1, 4),)
        eager_out = m(*sample_inputs)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)
        ep_out = loaded_ep.module()(*sample_inputs)
        self.assertTrue(torch.allclose(eager_out, ep_out))
        self.assertEqual(len(loaded_ep.constants), 0)

    def test_empty_state_dict(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.randn(4, 4)

            def forward(self, x):
                return x + self.const

        m = M()
        sample_inputs = (torch.randn(4, 4),)
        eager_out = m(*sample_inputs)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)
        ep_out = loaded_ep.module()(*sample_inputs)
        self.assertTrue(torch.allclose(eager_out, ep_out))
        self.assertEqual(len(loaded_ep.state_dict), 0)

    def test_preserve_aliasing(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear1 = torch.nn.Linear(8, 8)
                self.linear2 = self.linear1  # alias of linear1
                self.register_buffer("buffer1", torch.randn(8, 8))
                self.register_buffer("buffer2", torch.randn(8, 8), persistent=False)
                self.const1 = torch.ones(8, 8)
                self.const2 = self.const1.diagonal()  # a partial view of const1

            def forward(self, x):
                return (
                    self.linear1(x)
                    + self.linear2(x)
                    + self.buffer1
                    + self.buffer2
                    + self.const1
                    + self.const2
                )

        m = M()
        sample_inputs = (torch.randn(1, 8),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        buffer.seek(0)
        loaded_ep = torch.export.load(buffer)
        eager_out = m(*sample_inputs)
        epm = loaded_ep.module()
        ep_out = epm(*sample_inputs)
        self.assertTrue(torch.allclose(eager_out, ep_out))

        # loaded_ep should preserve the aliasing info
        self.assertEqual(
            loaded_ep.state_dict["linear1.weight"].untyped_storage(),
            loaded_ep.state_dict["linear2.weight"].untyped_storage(),
        )
        self.assertEqual(
            loaded_ep.state_dict["linear1.bias"].untyped_storage(),
            loaded_ep.state_dict["linear2.bias"].untyped_storage(),
        )
        self.assertEqual(
            loaded_ep.constants["const1"].untyped_storage(),
            loaded_ep.constants["const2"].untyped_storage(),
        )
        # verify const1 and const2 share the same storage
        loaded_ep.constants["const1"][0][0] = 123
        self.assertEqual(loaded_ep.constants["const2"][0], 123)
        loaded_ep.constants["const2"][-1] = 321
        self.assertEqual(loaded_ep.constants["const1"][-1][-1], 321)

        # unlifted module should also preserve the aliasing info
        epm = loaded_ep.module()
        epm_state_dict = epm.state_dict()
        self.assertEqual(
            epm_state_dict["linear1.weight"].untyped_storage(),
            epm_state_dict["linear2.weight"].untyped_storage(),
        )
        self.assertEqual(
            epm_state_dict["linear1.bias"].untyped_storage(),
            epm_state_dict["linear2.bias"].untyped_storage(),
        )
        self.assertEqual(
            epm.const1.untyped_storage(),
            epm.const2.untyped_storage(),
        )
        # verify const1 and const2 share the same storage
        epm.const1[0][0] = 123
        self.assertEqual(epm.const2[0], 123)
        epm.const2[-1] = 321
        self.assertEqual(epm.const1[-1][-1], 321)

    def test_storage_offset(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.arange(8)[:4]
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x) + self.const

        m = M()
        sample_inputs = (torch.randn(1, 4),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))

    def test_1D_tensor_slicing(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.arange(8)[::2]

            def forward(self, x):
                return x + self.const

        m = M()
        sample_inputs = (torch.randn(4),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))

    def test_2D_tensor_slicing(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.const = torch.randn(4, 4)[:2, :2]

            def forward(self, x):
                return x + self.const

        m = M()
        sample_inputs = (torch.randn(2, 2),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))

    def test_non_float_weight(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.p = torch.nn.Parameter(
                    torch.ones(2, 2, dtype=torch.int8), requires_grad=False
                )

            def forward(self, x):
                return x + self.p

        m = M()
        sample_inputs = (torch.randn(2, 2),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))

    @requires_gpu
    def test_weight_sharing_gpu(self) -> None:
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.c2 = torch.ones(2, 4, device=GPU_TYPE)
                self.c1 = self.c2[0, :]
                self.linear = torch.nn.Linear(4, 4)

            def forward(self, x):
                return self.linear(x) + self.c1 + self.c2

        m = M().to(GPU_TYPE)
        sample_inputs = (torch.randn(2, 4, device=GPU_TYPE),)
        ep = torch.export.export(m, sample_inputs)
        # Check that c1 and c2 share the same storage
        self.assertEqual(
            ep.constants["c1"].untyped_storage(), ep.constants["c2"].untyped_storage()
        )
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        # Check that c1 and c2 share the same storage after serdes
        self.assertEqual(
            loaded_ep.constants["c1"].untyped_storage(),
            loaded_ep.constants["c2"].untyped_storage(),
        )
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))

    def test_complex_constant(self) -> None:
        class M(torch.nn.Module):
            def forward(self, x):
                s = torch.sin(x)
                y = (1 + 1j) * s
                z = 1j * s
                return y, z

        m = M()
        sample_inputs = (torch.randn(2, 2),)
        ep = torch.export.export(m, sample_inputs)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        self.assertEqual(m(*sample_inputs), loaded_ep.module()(*sample_inputs))


@unittest.skipIf(IS_WINDOWS, "Windows not supported for this test")
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestDeserialize(TestCase):
    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def _check_graph_nodes(self, gm1, gm2, _check_meta=True):
        # TODO: The _check_meta flag bypasses checking for
        # source_fn/nn_module_stack as there is an issue with
        # roundtripping the source_fn value on torch.ops.map nodes
        # original source_fn: <functorch.experimental._map.MapWrapper object at 0x7f80a0549930>
        # deserialized source_fn: 'functorch.experimental._map.map'

        self.assertEqual(len(gm1.graph.nodes), len(gm2.graph.nodes))

        for node1, node2 in zip(gm1.graph.nodes, gm2.graph.nodes):
            self.assertEqual(node1.op, node2.op)
            if node1.op == "call_function":
                # Check "val" metadata
                val1 = node1.meta.get("val", None)
                val2 = node2.meta.get("val", None)
                self.assertEqual(len(node1.args), len(node2.args))
                self.assertEqual(set(node1.kwargs.keys()), set(node2.kwargs.keys()))
                if val1 is None or val2 is None:
                    # Either both are None
                    self.assertEqual(val1, val2)
                elif isinstance(val1, FakeTensor) and isinstance(val2, FakeTensor):
                    # Or both are fake tensors with the same shape/dtype
                    self.assertEqual(len(val1.shape), len(val2.shape))
                    for s1, s2 in zip(val1.shape, val2.shape):
                        if is_concrete_int(s1) and is_concrete_int(s2):
                            self.assertEqual(s1, s2)
                        else:
                            self.assertEqual(str(s1), str(s2))
                    self.assertEqual(val1.dtype, val2.dtype)
                elif isinstance(val1, (list, tuple)) and isinstance(
                    val2, (list, tuple)
                ):
                    # Or both are lists of fake tensors with matching
                    # shapes/dtypes
                    for v1, v2 in zip(
                        pytree.tree_leaves(val1), pytree.tree_leaves(val2)
                    ):
                        if isinstance(v1, FakeTensor):
                            self.assertEqual(v1.shape, v2.shape)
                            self.assertEqual(v1.dtype, v2.dtype)
                else:
                    # Expressions like 's0 < 10' can only be compared via
                    # their string form
                    self.assertEqual(str(val1), str(val2))

# Check "stack_trace" metadata
|
|
self.assertEqual(
|
|
node1.meta.get("stack_trace", None),
|
|
node2.meta.get("stack_trace", None),
|
|
)
|
|
|
|
if node1.target == torch.ops.higher_order.cond:
|
|
true_graph1 = getattr(gm1, node1.args[1].target)
|
|
true_graph2 = getattr(gm2, node2.args[1].target)
|
|
self._check_graph_nodes(true_graph1, true_graph2)
|
|
|
|
false_graph1 = getattr(gm1, node1.args[2].target)
|
|
false_graph2 = getattr(gm2, node2.args[2].target)
|
|
self._check_graph_nodes(false_graph1, false_graph2)
|
|
elif node1.target == torch.ops.higher_order.map_impl:
|
|
map_graph1 = getattr(gm1, node1.args[0].target)
|
|
map_graph2 = getattr(gm2, node2.args[0].target)
|
|
self._check_graph_nodes(map_graph1, map_graph2, False)
|
|
|
|
if _check_meta and node1.op not in ("get_attr", "placeholder", "output"):
|
|
# Check "nn_module_stack" metadata
|
|
self.assertEqual(
|
|
node1.meta.get("nn_module_stack", None),
|
|
node2.meta.get("nn_module_stack", None),
|
|
)
|
|
# Check "source_fn_stack" metadata
|
|
self.assertEqual(
|
|
node1.meta.get("source_fn_stack", None),
|
|
node2.meta.get("source_fn_stack", None),
|
|
)
|
|
|
|
def check_graph(
|
|
self,
|
|
fn,
|
|
inputs,
|
|
dynamic_shapes=None,
|
|
_check_meta=True,
|
|
use_pre_dispatch=True,
|
|
strict=True,
|
|
) -> None:
|
|
"""Export a graph, serialize it, deserialize it, and compare the results."""
|
|
|
|
        def _deepcopy_inputs(inputs):
            # copy.deepcopy(inputs) can fail if tensor inputs have attributes
            # (i.e. __dict__), so we remove __dict__ while deepcopying.
            dict_mapping = dict()
            inputs_clone = ()
            for idx, i in enumerate(inputs):
                if isinstance(i, torch.Tensor) and hasattr(inputs[0], "__dict__"):
                    dict_mapping[idx] = i.__dict__
                    i.__dict__ = {}
                inputs_clone += (copy.deepcopy(i),)

            # Add __dict__ back.
            for k, v in dict_mapping.items():
                inputs[k].__dict__ = v
                inputs_clone[k].__dict__ = v
            return inputs_clone

        def _check_graph(pre_dispatch):
            if pre_dispatch:
                ep = torch.export.export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    strict=strict,
                )
            else:
                # We need this branch because PT2 inference
                # goes through this private export API.
                ep = torch.export._trace._export(
                    fn,
                    _deepcopy_inputs(inputs),
                    {},
                    dynamic_shapes=dynamic_shapes,
                    strict=strict,
                    pre_dispatch=False,
                )
            ep.graph.eliminate_dead_code()

            serialized_artifact = serialize(ep, opset_version={"aten": 0})
            deserialized_ep = deserialize(
                serialized_artifact, expected_opset_version={"aten": 0}
            )
            deserialized_ep.graph.eliminate_dead_code()

            orig_outputs = ep.module()(*_deepcopy_inputs(inputs))
            loaded_outputs = deserialized_ep.module()(*_deepcopy_inputs(inputs))

            flat_orig_outputs = pytree.tree_leaves(orig_outputs)
            flat_loaded_outputs = pytree.tree_leaves(loaded_outputs)

            for orig, loaded in zip(flat_orig_outputs, flat_loaded_outputs):
                self.assertEqual(type(orig), type(loaded))
                # torch.allclose doesn't work for float8
                if isinstance(orig, torch.Tensor) and orig.dtype not in [
                    torch.float8_e4m3fn,
                    torch.float8_e5m2,
                ]:
                    if orig.is_meta:
                        self.assertEqual(orig, loaded)
                    else:
                        self.assertTrue(torch.allclose(orig, loaded))
                else:
                    self.assertEqual(orig, loaded)
            self._check_graph_nodes(
                ep.graph_module, deserialized_ep.graph_module, _check_meta
            )

        if use_pre_dispatch:
            _check_graph(pre_dispatch=True)
            _check_graph(pre_dispatch=False)
        else:
            _check_graph(pre_dispatch=False)

    def test_optional_tuple(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo",
                "(Tensor a, Tensor b, Tensor? c) -> (Tensor, Tensor?)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo", "cpu", lib=lib)
            @torch.library.register_fake("mylib::foo")
            def foo_impl(a, b, c):
                res2 = None
                if c is not None:
                    res2 = c + a + b
                return a + b, res2

            class M(torch.nn.Module):
                def forward(self, a, b, c):
                    return torch.ops.mylib.foo(a, b, c)

            self.check_graph(M(), (torch.randn(3), torch.randn(3), torch.randn(3)))

    def test_unbacked_bindings_serialize(self):
        from torch._export.utils import _get_shape_env_from_gm
        from torch.utils._sympy.symbol import prefix_str, symbol_is_type, SymT

        class M(torch.nn.Module):
            def forward(self, x, y):
                x += 2
                n = x.item()
                n = n * 2 + y.item()
                return n + 2

        inps = (
            torch.tensor(4),
            torch.tensor(5),
        )
        for _strict in [True, False]:
            ep = torch.export.export(M(), inps, strict=_strict).run_decompositions()

            # check bindings after deserialization
            buffer = io.BytesIO()
            save(ep, buffer)
            buffer.seek(0)
            loaded_ep = load(buffer)
            bound = set()
            for old_node, new_node in zip(ep.graph.nodes, loaded_ep.graph.nodes):
                self.assertEqual(
                    "unbacked_bindings" in old_node.meta,
                    "unbacked_bindings" in new_node.meta,
                )
                bound.update(new_node.meta.get("unbacked_bindings", {}))

            # check ShapeEnv counters
            shape_env = _get_shape_env_from_gm(loaded_ep.graph_module)
            next_index = shape_env.unbacked_symint_counter
            shape_env.unbacked_symint_counter += 1
            for symbol in bound:
                self.assertTrue(symbol_is_type(symbol, SymT.UNBACKED_INT))
                self.assertTrue(
                    int(str(symbol)[len(prefix_str[SymT.UNBACKED_INT]) :]) < next_index
                )

    def test_sym_bool_dynamic_shapes(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, y):
                z = x[:, -y.shape[0] :, :]
                return z

        inputs = (torch.ones(4, 5, 10), torch.ones(3))
        dynamic_shapes = {"x": {}, "y": {0: Dim("seqlen", max=4)}}
        # Compile with dynamic_shapes set to get operator.neg involved
        self.check_graph(MyModule(), inputs, dynamic_shapes=dynamic_shapes)

    def test_auto_functionalize(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::foo1",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> Tensor",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo2",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> (Tensor, Tensor)",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )
            torch.library.define(
                "mylib::foo3",
                "(Tensor(a!) x, Tensor[] y, Tensor(b!) z, SymInt w, Tensor n) -> ()",
                tags=torch.Tag.pt2_compliant_tag,
                lib=lib,
            )

            @torch.library.impl("mylib::foo1", "cpu", lib=lib)
            @torch.library.register_fake("mylib::foo1")
            def foo1_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return n + n

            @torch.library.impl("mylib::foo2", "cpu", lib=lib)
            @torch.library.register_fake("mylib::foo2")
            def foo2_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return (n + n, n * n)

            @torch.library.impl("mylib::foo3", "cpu", lib=lib)
            @torch.library.register_fake("mylib::foo3")
            def foo3_impl(x, y, z, w, n):
                x.add_(y[0] + w)
                z.add_(y[1] + n)
                return

            class M(torch.nn.Module):
                def forward(self, x, y, z, n):
                    n = torch.ops.mylib.foo1(x, y, z, 2, n)
                    torch.ops.mylib.foo3(x, y, z, 2, n)
                    return torch.ops.mylib.foo2(x, y, z, 2, n)

            x = torch.randn(3)
            y = (torch.randn(3), torch.randn(3))
            z = torch.randn(3)
            n = torch.randn(3)
            orig_args = (x, y, z, n)

            # TODO Auto_functionalize is not supported on pre_dispatch IR
            self.check_graph(M(), orig_args, use_pre_dispatch=False)

    def test_hoo_symint_input(self):
        class Mod(torch.nn.Module):
            def __init__(self):
                super().__init__()

            def forward(self, a, b, c):
                num = c.item()
                return torch.cond(
                    pred=torch.tensor([True]),
                    true_fn=lambda a, b: a + b + num,
                    false_fn=lambda a, b: a - b - num,
                    operands=(a, b),
                )

        inp = (torch.ones(3, 3), torch.ones(3, 3), torch.tensor(2))
        self.check_graph(Mod(), inp, use_pre_dispatch=False)

    def test_none_input(self):
        """
        Testing a backwards-compatibility breakage where old models do not have
        an input spec with the node name.
        """

        class M(torch.nn.Module):
            def forward(self, x, y, z):
                return x + z

        ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3)))

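        # Simulate the old serialized format: overwrite the input spec with a
        # bare as_none argument (no node name) and check that deserialization
        # still handles it.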
        serialized_program = ExportedProgramSerializer(None, 2).serialize(ep)
        serialized_program.exported_program.graph_module.signature.input_specs[1] = (
            schema.InputSpec.create(
                user_input=schema.UserInputSpec(
                    arg=schema.Argument.create(as_none=True)
                )
            )
        )
        ep = ExportedProgramDeserializer(None).deserialize(
            serialized_program.exported_program, {}, {}, {}
        )
        ep.graph_module.recompile()
        unflattened = torch.export.unflatten(ep)
        inp = (torch.rand(3, 3), None, torch.rand(3, 3))
        self.assertEqual(unflattened(*inp), M()(*inp))

    def test_multi_return(self) -> None:
        """
        Test multiple return from a single node (ex. layer_norm has 2 outputs)
        """

        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x, w, b):
                return torch.nn.functional.layer_norm(
                    x,
                    x.size()[1:],
                    weight=w,
                    bias=b,
                    eps=1e-5,
                )

        inputs = (
            torch.ones([512, 512], requires_grad=True),
            torch.ones([512]),
            torch.ones([512]),
        )
        self.check_graph(MyModule(), inputs)

    def test_basic(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                x = x + x
                x = x * x
                x = x / x
                return x, x.clone()

        inputs = (torch.ones([512], requires_grad=True),)
        self.check_graph(MyModule(), inputs)

    def test_dynamic(self) -> None:
        class DynamicShapeSimpleModel(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, a, b, c) -> torch.Tensor:
                d = (torch.matmul(a, b) + c) / 2
                d_s0 = d.shape[0]
                d_s1 = d.shape[1]
                d_s3 = d_s0 * d_s1
                e = d.view(d_s3)
                return torch.cat([e, e])

        inputs = (torch.randn(2, 4), torch.randn(4, 7), torch.randn(2, 7))
        dim0_ac = torch.export.Dim("dim0_ac")
        dynamic_shapes = {"a": {0: dim0_ac}, "b": None, "c": {0: dim0_ac}}
        self.check_graph(DynamicShapeSimpleModel(), inputs, dynamic_shapes)

    def test_sym_bool(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                assert x.size(0) in y
                return x + y

        f = Module()
        self.check_graph(f, (torch.ones(1), torch.ones(3)))

    def test_sym_bool_torch_check_equal(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x.nonzero()
                z = y.size(0)
                torch._check(z == 2)
                return y

        self.check_graph(Module(), (torch.Tensor([1, 0, 1, 0]),))

    def test_sym_int_torch_check_equal(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                y = x.nonzero()
                z = y.size(0)
                torch._check(z % 3 == 0)
                torch._check(z == 3)
                return y

        self.check_graph(Module(), (torch.Tensor([1, 0, 1, 0, 1, 0]),))

    def test_shape(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                z, y = x.size()
                return z + y + x[0], z

        inputs = (torch.ones(2, 3),)
        dim0_x, dim1_x = torch.export.dims("dim0_x", "dim1_x")
        dynamic_shapes = {"x": (dim0_x, dim1_x)}
        self.check_graph(Foo(), inputs, dynamic_shapes)

    def test_module(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear1 = torch.nn.Linear(3, 3)
                self.relu = torch.nn.ReLU()
                self.linear2 = torch.nn.Linear(3, 5)

            def forward(self, x):
                x = self.linear1(x)
                x = self.linear1(x)
                x = torch.nn.functional.relu(x)
                x = self.linear2(x)
                return x

        inputs = (torch.randn(3, 3),)
        self.check_graph(M(), inputs)

    def test_module_meta(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.nn.Parameter(torch.ones(3, 3))

            def forward(self, x):
                return self.p + x

        with torch.device("meta"):
            mod = M()

        inputs = (torch.randn(3, 3, device="meta"),)
        self.check_graph(mod, inputs)

    def test_pytree_namedtuple(self):
        N1 = namedtuple("N1", ["a", "b"])

        class N2(NamedTuple):
            a: torch.Tensor
            b: torch.Tensor

        class M(torch.nn.Module):
            def forward(self, x, y):
                return N2(x.a + y.a, x.b * y.b)

        pytree._register_namedtuple(
            N1,
            serialized_type_name="test.export.test_serialize.test_pytree_namedtuple.N1",
        )
        pytree._register_namedtuple(
            N2,
            serialized_type_name="test.export.test_serialize.test_pytree_namedtuple.N2",
        )

        inp = (N1(torch.randn(3), torch.randn(3)), N1(torch.randn(3), torch.randn(3)))
        ep = torch.export.export(M(), inp)
        ep.example_inputs = None  # Can't pickle the input since the namedtuple class is not in a global namespace
        serialized = ExportedProgramSerializer().serialize(ep)
        self.assertEqual(
            len(serialized.exported_program.graph_module.treespec_namedtuple_fields), 2
        )
        deserialized = ExportedProgramDeserializer().deserialize(
            serialized.exported_program,
            serialized.state_dict,
            serialized.constants,
        )
        self.assertTrue("treespec_namedtuple_fields" in deserialized.graph_module.meta)
        self.assertEqual(
            deserialized.graph_module.meta["treespec_namedtuple_fields"],
            {
                "test.export.test_serialize.test_pytree_namedtuple.N1": ["a", "b"],
                "test.export.test_serialize.test_pytree_namedtuple.N2": ["a", "b"],
            },
        )

        unlifted = deserialized.module()
        self.assertTrue("treespec_namedtuple_fields" in unlifted.meta)
        self.assertEqual(len(unlifted.meta["treespec_namedtuple_fields"]), 2)

        unflattened = unflatten(deserialized)
        self.assertTrue("treespec_namedtuple_fields" in unflattened.meta)
        self.assertEqual(len(unflattened.meta["treespec_namedtuple_fields"]), 2)

    def test_cond(self):
        from functorch.experimental.control_flow import cond

        inputs = torch.ones(4, 3), torch.zeros(4, 3)

        class M(torch.nn.Module):
            def forward(self, x, y):
                def t(x, y):
                    return x + y

                def f(x, y):
                    return x - y

                return cond(x[0][0] > 4, t, f, [x, y])

        self.check_graph(M(), inputs)

    def test_sym_float(self):
        class M(torch.nn.Module):
            def forward(self, x):
                b = x.item()
                return b * 0.1

        self.check_graph(M(), (torch.tensor(1.0),))

    def test_arg_from(self):
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer("compress_weight", torch.ones((10, 10)))
                self.register_buffer("compress_bias", torch.ones(10))

            def forward(self) -> None:
                if self.compress_weight is None or self.compress_bias is None:
                    return
                torch.nn.init.kaiming_uniform_(self.compress_weight, a=math.sqrt(5))
                fan_in, _ = torch.nn.init._calculate_fan_in_and_fan_out(
                    self.compress_weight
                )
                bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0
                torch.nn.init.uniform_(self.compress_bias, -bound, bound)

        with torch.no_grad():
            self.check_graph(M(), ())

    def test_map(self):
        from functorch.experimental import control_flow

        def f(x, y):
            return x + y

        class Module(torch.nn.Module):
            def forward(self, xs, y):
                return control_flow.map(f, xs, y)

        g = Module()
        inputs = (torch.ones(3, 2, 2), torch.ones(2))
        self.check_graph(g, inputs, _check_meta=False)

    def test_positional_argument_with_default_value(self):
        class MyLinear(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.weight = torch.randn(10, 10)
                self.bias = torch.randn(10)

            def forward(self, x):
                # bias has a default value here, but it should be preserved
                # as a positional argument.
                return torch.ops.aten.linear.default(x, self.weight, self.bias)

self.check_graph(MyLinear(), (torch.randn(10, 10),))
|
|
|
|
def test_tensor_tensor_list(self):
|
|
with torch.library._scoped_library("_export", "FRAGMENT") as lib:
|
|
lib.define(
|
|
"_test_tensor_tensor_list_output(Tensor x, Tensor y) -> (Tensor, Tensor[])",
|
|
tags=torch.Tag.pt2_compliant_tag,
|
|
)
|
|
|
|
def _test_tensor_tensor_list_output(x, y):
|
|
return y, [x]
|
|
|
|
lib.impl(
|
|
"_test_tensor_tensor_list_output",
|
|
_test_tensor_tensor_list_output,
|
|
"CPU",
|
|
)
|
|
lib.impl(
|
|
"_test_tensor_tensor_list_output",
|
|
_test_tensor_tensor_list_output,
|
|
"Meta",
|
|
)
|
|
|
|
class M(torch.nn.Module):
|
|
def forward(self, x, y):
|
|
a, b = torch.ops._export._test_tensor_tensor_list_output.default(
|
|
x, y
|
|
)
|
|
return a + b[0]
|
|
|
|
self.check_graph(M(), (torch.rand(3, 2), torch.rand(3, 2)))
|
|
|
|
def test_list_of_optional_tensors(self) -> None:
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, x, y, z):
|
|
indices = [None, None, torch.tensor([1, 3, 5, 7])]
|
|
indexed = torch.ops.aten.index.Tensor(x + y, indices)
|
|
return indexed + z
|
|
|
|
inputs = (torch.rand(8, 8, 8), torch.rand(8, 8, 8), torch.rand(8, 8, 4))
|
|
self.check_graph(MyModule(), inputs)
|
|
|
|
def test_sym_ite(self):
|
|
class Foo(torch.nn.Module):
|
|
def forward(self, x):
|
|
b = x.shape[0] == 5
|
|
ret = torch.sym_ite(b, x.shape[0], x.shape[1])
|
|
return ret
|
|
|
|
dynamic_shapes = {"x": {0: Dim("dim0"), 1: Dim("dim1")}}
|
|
self.check_graph(Foo(), (torch.ones(4, 5),), dynamic_shapes=dynamic_shapes)
|
|
|
|
def test_multiple_getitem(self):
|
|
class M(torch.nn.Module):
|
|
def forward(self, x):
|
|
a, b = torch.topk(x, 2)
|
|
a = a * 2
|
|
return a, b
|
|
|
|
ep = torch.export.export(M(), (torch.ones(3),), strict=True)
|
|
|
|
# insert another getitem node
|
|
for node in ep.graph.nodes:
|
|
if node.op == "call_function" and node.target == torch.ops.aten.mul.Tensor:
|
|
getitem_0 = node.args[0]
|
|
with ep.graph.inserting_before(getitem_0):
|
|
getitem_copy = ep.graph.node_copy(getitem_0)
|
|
mul_node = ep.graph.call_function(
|
|
torch.ops.aten.mul.Tensor, (getitem_copy, 2)
|
|
)
|
|
mul_node.meta = copy.copy(getitem_copy.meta)
|
|
node.args = (getitem_0, mul_node)
|
|
|
|
deserialized_ep = deserialize(serialize(ep))
|
|
|
|
inp = (torch.randn(3),)
|
|
orig_res = ep.module()(*inp)
|
|
res = deserialized_ep.module()(*inp)
|
|
self.assertTrue(torch.allclose(orig_res[0], res[0]))
|
|
self.assertTrue(torch.allclose(orig_res[1], res[1]))
|
|
|
|
# The deserialized graph should have deduped getitem calls
|
|
self.assertExpectedInline(
|
|
deserialized_ep.graph_module.code.strip("\n"),
|
|
"""\
|
|
def forward(self, x):
|
|
topk_default = torch.ops.aten.topk.default(x, 2); x = None
|
|
getitem = topk_default[0]
|
|
getitem_1 = topk_default[1]; topk_default = None
|
|
mul_tensor = torch.ops.aten.mul.Tensor(getitem, 2)
|
|
mul = torch.ops.aten.mul.Tensor(getitem, mul_tensor); getitem = mul_tensor = None
|
|
return (mul, getitem_1)
|
|
""",
|
|
)
|
|
|
|
@parametrize(
|
|
"name,case",
|
|
get_filtered_export_db_tests(),
|
|
name_fn=lambda name, case: f"case_{name}",
|
|
)
|
|
def test_exportdb_supported(self, name: str, case: ExportCase) -> None:
|
|
model = case.model
|
|
_check_meta = "map" not in name
|
|
with torch._export.config.patch(use_new_tracer_experimental=True):
|
|
self.check_graph(model, case.example_args, _check_meta=_check_meta)
|
|
|
|
    def test_constraints(self):
        class Module(torch.nn.Module):
            def forward(self, x, y):
                n = x.item()
                torch._check(n >= 0)
                return y.sum() + torch.ones(n, 5).sum()

        f = Module()
        self.check_graph(f, (torch.tensor(3), torch.randn(4, 5)))

    def test_get_attr(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + torch.tensor(3)

        f = Module()
        self.check_graph(f, (torch.tensor(3),))

    def test_get_attr_list(self) -> None:
        class Module(torch.nn.Module):
            def forward(self, x):
                return torch.cat([x, torch.tensor([1, 1])])

        f = Module()
        self.check_graph(f, (torch.tensor([1, 1]),))

    @unittest.skipIf(not torch.cuda.is_available(), "Requires cuda")
    def test_device(self) -> None:
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.conv = torch.nn.Conv2d(3, 16, 3, stride=1, bias=True)
                self.relu = torch.nn.ReLU()

            def forward(self, x):
                conv = self.conv(x)
                relu = self.relu(conv)
                mul = relu * 0.5
                return mul

        inp = torch.randn((1, 3, 224, 224), dtype=torch.float).to("cuda")
        model = MyModule().eval().cuda()
        self.check_graph(model, (inp,))

    def test_custom_obj_tuple_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_tuple_return(self.attr, x)
                y = a[0] + a[1]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo(self.attr, x)
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, a)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

    def test_custom_obj_list_out(self):
        class MyModule(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                a = torch.ops._TorchScriptTesting.takes_foo_list_return(self.attr, x)
                y = a[0] + a[1] + a[2]
                b = torch.ops._TorchScriptTesting.takes_foo(self.attr, y)
                return x + b

        m = MyModule()
        inputs = (torch.ones(2, 3),)
        self.check_graph(m, inputs, strict=False)

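    # Serialization should handle programs with no user inputs and no example inputs.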
    def test_export_no_inputs(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.p = torch.ones(3, 3)

            def forward(self):
                return self.p * self.p

        ep = torch.export.export(M(), (), strict=True)
        ep._example_inputs = None
        roundtrip_ep = deserialize(serialize(ep))
        self.assertTrue(torch.allclose(ep.module()(), roundtrip_ep.module()()))

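    # Exercise serde for float8 dtypes.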
    def test_serialize_float8(self):
        for dtype in [torch.float8_e5m2, torch.float8_e4m3fn]:

            class MyModule(torch.nn.Module):
                def forward(self, x):
                    return x.to(dtype)

            m = MyModule()
            inputs = (torch.ones(2, 3),)
            self.check_graph(m, inputs, strict=False)

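    # Forward compatibility: unknown fields added by a newer schema should be
    # ignored when deserializing into the current dataclasses.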
    def test_forward_compatibility(self):
        self.assertEqual(
            schema.TensorArgument(
                name="x",
            ),
            _dict_to_dataclass(
                schema.TensorArgument,
                {
                    "shiny_new_field": "hello world",
                    "name": "x",
                },
            ),
        )


instantiate_parametrized_tests(TestDeserialize)


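# Deserialization should reject programs whose serialized schema version does
# not match the deserializer's version.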
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSchemaVersioning(TestCase):
    def test_error(self):
        class Module(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Module()
        ep = export(f, (torch.randn(1, 3),), strict=True)

        serialized_program = ExportedProgramSerializer().serialize(ep)
        serialized_program.exported_program.schema_version.major = -1
        with self.assertRaisesRegex(
            SerializeError, r"Serialized schema version .* does not match our current"
        ):
            ExportedProgramDeserializer().deserialize(
                serialized_program.exported_program,
                serialized_program.state_dict,
                serialized_program.constants,
                serialized_program.example_inputs,
            )


# We haven't set up kwargs inputs yet
unittest.expectedFailure(TestDeserialize.test_exportdb_supported_case_fn_with_kwargs)


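# Round-trip tests for saving and loading exported programs via buffers,
# file paths, and extra files.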
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSaveLoad(TestCase):
    def test_save_buffer(self):
        inp = (torch.tensor([0.1, 0.1]),)

        class Module(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                x = x + 1
                y = x.t()
                y = y.relu()
                y = self.linear(y)
                return y

        ep = export(Module(), inp, strict=True)

        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    @unittest.skipIf(IS_WINDOWS, "Cannot modify file in windows")
    def test_save_file(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x

        f = Foo()

        inp = (torch.randn(2, 2),)
        ep = export(f, inp, strict=True)

        with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
            save(ep, f.name)
            f.seek(0)
            loaded_ep = load(f.name)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_path(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                return x + y

        f = Foo()

        inp = (torch.tensor([6]), torch.tensor([7]))
        ep = export(f, inp, strict=True)

        with TemporaryFileName(suffix=".pt2") as fname:
            path = Path(fname)
            save(ep, path)
            loaded_ep = load(path)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_extra(self):
        inp = (torch.tensor([0.1, 0.1]),)

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x * x + x

        f = Foo()

        ep = export(f, inp, strict=True)

        buffer = io.BytesIO()
        save(ep, buffer, extra_files={"extra.txt": "moo"})
        buffer.seek(0)
        extra_files = {"extra.txt": ""}
        loaded_ep = load(buffer, extra_files=extra_files)

        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))
        self.assertEqual(extra_files["extra.txt"], "moo")

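    # Tamper with the archive version stored inside the saved .pt2 zip and
    # confirm that load() rejects the mismatched archive.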
    @unittest.skipIf(
        IS_FBCODE or IS_MACOS or IS_WINDOWS, "The file path is different in fbcode CI"
    )
    def test_version_error(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        ep = export(f, (torch.randn(1, 3),), strict=True)

        with self.assertRaisesRegex(
            ValueError, r"Saved archive version -1 does not match our current"
        ):
            with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
                save(ep, f.name)
                f.seek(0)
                file_prefix = f.name.split("/")[2].split(".")[0]

                # Create a new file and copy things over, but modify the
                # archive version
                with tempfile.NamedTemporaryFile(suffix=".pt2") as fnew:
                    with zipfile.ZipFile(f, "r") as zin:
                        with zipfile.ZipFile(fnew, "w") as zout:
                            for item in zin.infolist():
                                if (
                                    item.filename
                                    != f"{file_prefix}/{ARCHIVE_VERSION_PATH}"
                                ):
                                    zout.writestr(item, zin.read(item.filename))
                            zout.writestr(f"{file_prefix}/{ARCHIVE_VERSION_PATH}", "-1")

                    f.seek(0)
                    load(fnew.name)

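    # Tensor constants (module attributes and in-graph literals) should be
    # preserved across save/load.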
    def test_save_constants(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.a = torch.tensor(3)

            def forward(self, x):
                list_tensor = [torch.tensor(3), torch.tensor(4)]
                return x + self.a + list_tensor[0] + list_tensor[1]

        ep = export(Foo(), (torch.tensor(1),), strict=True)
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        inp = (torch.tensor(1),)
        self.assertTrue(torch.allclose(ep.module()(*inp), loaded_ep.module()(*inp)))

    def test_save_load_with_multiple_empty_tensors(self) -> None:
        # Test scenario where models have multiple empty tensors
        # but with different data types.
        class M(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.register_buffer(
                    "int_buffer",
                    torch.zeros([0], dtype=torch.uint8),
                )
                self.register_buffer(
                    "int_buffer2",
                    torch.zeros([0], dtype=torch.uint8),
                )
                self.register_buffer(
                    "float_buffer",
                    torch.zeros([0], dtype=torch.float32),
                )

            def forward(self, t: torch.Tensor) -> torch.Tensor:
                return t + self.int_buffer + self.float_buffer + self.int_buffer2

        m = M()
        inp = torch.rand([0])

        ep = torch.export.export(m, (inp,))

        buffer = io.BytesIO()
        torch.export.save(ep, buffer)
        model_bytes = buffer.getvalue()

        # First two buffers are duplicates, but not the third one.
        # So in the serialized model, there will be two physical tensors.
        self.assertTrue(b"weight_0" in model_bytes)
        self.assertTrue(b"weight_1" in model_bytes)
        self.assertFalse(b"weight_2" in model_bytes)

        buffer = io.BytesIO(model_bytes)
        buffer.seek(0)
        dep = torch.export.load(buffer)
        unf = torch.export.unflatten(dep)
        self.assertEqual(unf.int_buffer.dtype, torch.uint8)
        self.assertEqual(unf.int_buffer2.dtype, torch.uint8)
        self.assertEqual(unf.float_buffer.dtype, torch.float32)


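# Serde tests for torchbind custom classes (ScriptObjects) appearing as
# constants and as direct inputs to custom ops.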
@unittest.skipIf(not torchdynamo.is_dynamo_supported(), "dynamo doesn't support")
class TestSerializeCustomClass(TestCase):
    def setUp(self):
        super().setUp()
        init_torchbind_implementations()

    def test_custom_class(self):
        custom_obj = torch.classes._TorchScriptTesting._PickleTester([3, 4])

        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs, strict=True)

        # Replace one of the values with an instance of our custom class
        for node in ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                with ep.graph.inserting_before(node):
                    custom_node = ep.graph.call_function(
                        torch.ops._TorchScriptTesting.take_an_instance.default,
                        (custom_obj,),
                    )
                    custom_node.meta["val"] = torch.ones(4, 4)
                    custom_node.meta["torch_fn"] = (
                        "take_an_instance",
                        "take_an_instance",
                    )
                    arg0, _ = node.args
                    node.args = (arg0, custom_node)

        serialized_vals = serialize(ep)

        ep_str = serialized_vals.exported_program.decode("utf-8")
        assert "class_fqn" in ep_str
        assert custom_obj._type().qualified_name() in ep_str

        deserialized_ep = deserialize(serialized_vals)

        for node in deserialized_ep.graph.nodes:
            if (
                node.op == "call_function"
                and node.target
                == torch.ops._TorchScriptTesting.take_an_instance.default
            ):
                arg = node.args[0]
                self.assertTrue(isinstance(arg, torch._C.ScriptObject))
                self.assertEqual(arg._type(), custom_obj._type())
                self.assertEqual(arg.__getstate__(), custom_obj.__getstate__())
                self.assertEqual(arg.top(), 7)

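    # A custom class constructed under FakeTensorMode should round-trip with
    # its contained tensor still fake.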
    def test_custom_class_containing_fake_tensor(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.custom_obj = torch.classes._TorchScriptTesting._ContainsTensor(
                    torch.rand(2, 3)
                )

            def forward(self, x):
                return x + self.custom_obj.get()

        with FakeTensorMode():
            f = Foo()

        inputs = (torch.zeros(2, 3),)
        with enable_torchbind_tracing():
            ep = export(f, inputs, strict=False)

        serialized_vals = serialize(ep)
        ep = deserialize(serialized_vals)
        self.assertTrue(isinstance(ep.constants["custom_obj"].get(), FakeTensor))

    def test_custom_class_input_to_function(self):
        class Foo(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x)

        with FakeTensorMode():
            f = Foo()

        inputs = (torch.zeros(2, 3),)
        with enable_torchbind_tracing():
            ep = export(f, inputs, strict=False)

        serialized_vals = serialize(ep)
        ep = deserialize(serialized_vals)
        self.assertExpectedInline(
            str(ep.graph_module.code).strip(),
            """\
def forward(self, obj_attr, x):
    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(obj_attr, x); obj_attr = None
    add = torch.ops.aten.add.Tensor(x, takes_foo); x = takes_foo = None
    return (add,)""",
        )
        self.assertTrue(isinstance(ep.constants["attr"], torch.ScriptObject))
        gm = ep.module()
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, x):
    x, = fx_pytree.tree_flatten_spec(([x], {}), self._in_spec)
    attr = self.attr
    _guards_fn = self._guards_fn(x); _guards_fn = None
    takes_foo = torch.ops._TorchScriptTesting.takes_foo.default(attr, x); attr = None
    add = torch.ops.aten.add.Tensor(x, takes_foo); x = takes_foo = None
    return pytree.tree_unflatten((add,), self._out_spec)""",
        )
        self.assertTrue(isinstance(gm.attr, torch.ScriptObject))

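    # User-provided "custom" metadata on the graph module and on individual
    # nodes should survive serialization.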
    def test_custom_tag_metadata_serialization(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs, strict=True)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_ep = ep._update(new_gm, ep.graph_signature)
        serialized_vals = serialize(new_ep)
        new_ep = deserialize(serialized_vals)

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
        self.assertEqual(counter, 1)

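    # Custom metadata should also propagate from an op to the nodes produced
    # when it is decomposed.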
    def test_custom_tag_metadata_decomp(self):
        class Foo(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear = torch.nn.Linear(2, 2)

            def forward(self, x):
                return self.linear(x)

        f = Foo()

        inputs = (torch.ones(2, 2),)
        ep = export(f, inputs, strict=True)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        counter = 0
        for node in new_gm.graph.nodes:
            if (
                node.op == "call_function"
                and node.target == torch.ops.aten.linear.default
            ):
                counter += 1
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"
        self.assertEqual(counter, 1)

        new_ep = ep._update(new_gm, ep.graph_signature)
        new_ep = new_ep.run_decompositions()

        self.assertEqual(new_ep.graph_module.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_ep.graph.nodes:
            if node.op == "call_function":
                counter += 1
                self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
        self.assertTrue(counter > 1)

    def test_custom_tag_metadata_copy(self):
        class Foo(torch.nn.Module):
            def forward(self, x):
                return x + x

        f = Foo()

        inputs = (torch.zeros(4, 4),)
        ep = export(f, inputs, strict=True)

        new_gm = copy.deepcopy(ep.graph_module)
        new_gm.meta["custom"] = {}
        new_gm.meta["custom"]["f"] = "bar"

        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                node.meta["custom"] = {}
                node.meta["custom"]["quantization_tag"] = "foo"

        new_gm = copy.deepcopy(new_gm)

        self.assertEqual(new_gm.meta["custom"]["f"], "bar")
        counter = 0
        for node in new_gm.graph.nodes:
            if node.op == "call_function" and node.target == torch.ops.aten.add.Tensor:
                counter += 1
                self.assertTrue(node.meta["custom"]["quantization_tag"] == "foo")
        self.assertEqual(counter, 1)

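    # Value ranges recorded for unbacked symbols (via torch._check) should be
    # identical in the deserialized program's shape_env.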
    def test_unbacked_range_serdes(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y):
                n = x.item()
                torch._check(n >= 0)
                torch._check(n < y.size(0))
                return torch.empty(n), y[n]

        ep = torch.export.export(
            Foo(),
            (torch.tensor([5]), torch.randn(10)),
            dynamic_shapes={
                "x": None,
                "y": (Dim.DYNAMIC,),
            },
        )
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)

        # Compare symbol ranges from the pre-serialization ep with the loaded one
        pre_shape_env = torch._guards.detect_fake_mode(
            [node.meta.get("val") for node in ep.graph.nodes]
        ).shape_env
        post_shape_env = torch._guards.detect_fake_mode(
            [node.meta.get("val") for node in loaded_ep.graph.nodes]
        ).shape_env
        self.assertEqual(pre_shape_env.var_to_range, post_shape_env.var_to_range)

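    # With backed_size_oblivious, backed symbols get a lower bound of 0; the
    # deserialized shape_env should preserve that range.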
    def test_backed_size_oblivious_serdes(self):
        class Foo(torch.nn.Module):
            def forward(self, x, y, z):
                return x + y + z.item()

        with torch.fx.experimental._config.patch(backed_size_oblivious=True):
            ep = torch.export.export(
                Foo(),
                (torch.randn(1), torch.randn(1), torch.tensor([5])),
                dynamic_shapes={
                    "x": (Dim.DYNAMIC,),
                    "y": (Dim.DYNAMIC,),
                    "z": None,
                },
            )
        buffer = io.BytesIO()
        save(ep, buffer)
        buffer.seek(0)
        loaded_ep = load(buffer)
        shape_env = torch._guards.detect_fake_mode(
            [node.meta.get("val") for node in loaded_ep.graph.nodes]
        ).shape_env
        s0 = next(iter(ep.graph.nodes)).meta["val"].size(0)
        self.assertEqual(shape_env.var_to_range[s0.node.expr].lower, 0)


if __name__ == "__main__":
    run_tests()