Set non-strict export as default mode (#148790)

Summary:
- Flip the default value of the strict argument in torch.export.export from True to False.
- Update the test infra to cope with the change; some of it assumed strict mode was the default.
- Disable some tests that fail in non-strict mode.

Test Plan: Sandcastle

Differential Revision: D70228628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148790
Approved by: https://github.com/angelayi
This commit is contained in:
parent e3ebf61589
commit ab45aaca97
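The user-visible effect of the flip is that torch.export.export now traces in non-strict mode unless told otherwise. A minimal sketch of what that means for callers (the module below is a made-up placeholder, not code from this PR):

import torch

class Add(torch.nn.Module):
    def forward(self, x):
        return x + 1

args = (torch.randn(3),)

# After this change, the default is non-strict tracing (strict=False).
ep_default = torch.export.export(Add(), args)

# The previous default behavior now has to be requested explicitly.
ep_strict = torch.export.export(Add(), args, strict=True)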
@@ -40,6 +40,7 @@ class ExampleTests(TestCase):
             args_export,
             kwargs_export,
             dynamic_shapes=case.dynamic_shapes,
+            strict=True,
         )
         exported_program.graph_module.print_readable()

@@ -72,6 +73,7 @@ class ExampleTests(TestCase):
                 case.example_args,
                 case.example_kwargs,
                 dynamic_shapes=case.dynamic_shapes,
+                strict=True,
             )

 exportdb_not_supported_rewrite_cases = [
@@ -199,20 +199,25 @@ class Inp3:
     p: torch.Tensor


-NON_STRICT_SUFFIX = "_non_strict"
-RETRACEABILITY_STRICT_SUFFIX = "_retraceability"
-RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict"
-SERDES_SUFFIX = "_serdes"
-SERDES_NON_STRICT_SUFFIX = "_serdes_non_strict"
+NON_STRICT_SUFFIX = "_nonstrict"
+STRICT_SUFFIX = "_strict"
+RETRACEABILITY_STRICT_SUFFIX = "_retraceability_strict"
+RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_nonstrict"
+SERDES_STRICT_SUFFIX = "_serdes_strict"
+SERDES_NON_STRICT_SUFFIX = "_serdes_nonstrict"
 PREDISPATCH_SUFFIX = "_pre_dispatch"
-TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
-TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"
+TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
+TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
 LEGACY_EXPORT_STRICT_SUFFIX = "_legacy_export_strict"
-LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_non_strict"
+LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_nonstrict"
+CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
+CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"


+# Now default mode is non strict, so original unammended test names
+# should be treated as non-strict
 def is_non_strict_test(test_name):
-    return test_name.endswith(NON_STRICT_SUFFIX)
+    return not test_name.endswith(STRICT_SUFFIX)


 def is_non_strict_legacy_test(test_name):
@@ -226,7 +231,7 @@ def is_retracebility_test(test_name):


 def is_serdes_test(test_name):
-    return test_name.endswith(SERDES_SUFFIX) or test_name.endswith(
+    return test_name.endswith(SERDES_STRICT_SUFFIX) or test_name.endswith(
         SERDES_NON_STRICT_SUFFIX
     )

@@ -237,6 +242,12 @@ def is_training_ir_test(test_name):
     )


+def is_cpp_runtime_test(test_name):
+    return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith(
+        CPP_RUNTIME_NONSTRICT_SUFFIX
+    )
+
+
 def get_hop_schema(ep: torch.export.ExportedProgram):
     hop_node = next(
         node
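To see how the renamed suffixes and the rewritten predicate classify generated test names after this change, here is a small sketch using the constants above (the test names are invented for illustration):

STRICT_SUFFIX = "_strict"

def is_non_strict_test(test_name):
    # With non-strict as the default, anything not explicitly marked strict
    # counts as non-strict, including the original, unsuffixed test names.
    return not test_name.endswith(STRICT_SUFFIX)

assert is_non_strict_test("test_linear_conv")             # unsuffixed -> non-strict
assert is_non_strict_test("test_linear_conv_nonstrict")   # generated non-strict variant
assert not is_non_strict_test("test_linear_conv_strict")  # generated strict variant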
@@ -782,6 +793,8 @@ graph():
         self.assertEqual(exp_out, ep.module()(*args))

     @requires_gpu
+    @testing.expectedFailureLegacyExportNonStrict  # Old export graph contains auto_functionalize not Triton wrapper
+    @testing.expectedFailureLegacyExportStrict  # Old export graph contains auto_functionalize not Triton wrapper
     def test_export_custom_triton_kernel_mutable(self):
         @triton.jit
         def add_kernel(
@@ -1196,6 +1209,8 @@ graph():
         self.assertEqual(orig_res, ep_res)

     def test_unbacked_to_cond(self):
+        strict = True
+
         class M(torch.nn.Module):
             def forward(self, a):
                 az = a.nonzero()
@@ -1210,9 +1225,11 @@ graph():
                 return r * 2

         M()(torch.randn(7))
-        torch.export.export(M(), (torch.randn(7),))
+        torch.export.export(M(), (torch.randn(7),), strict=strict)

     def test_unbacked_to_cond_passthrough(self):
+        strict = True
+
         class M(torch.nn.Module):
             def forward(self, a):
                 az = a.nonzero()
@@ -1227,7 +1244,7 @@ graph():
                 return r * 2

         M()(torch.randn(7))
-        torch.export.export(M(), (torch.randn(7),))
+        torch.export.export(M(), (torch.randn(7),), strict=strict)

     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_cond_contains_unbacked_no_escape(self):
@@ -1752,6 +1769,7 @@ graph():
     # Bug: ep.run_decompositions() doesn't propagate real tensors
     @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_draft_export_infers_fake_kernel(self):
+        strict = True
         with torch.library._scoped_library("export", "FRAGMENT") as lib:
             lib.define("bar(Tensor x) -> Tensor")
             lib.impl("bar", lambda x: x[0].clone(), "CPU")
@@ -1767,7 +1785,7 @@ graph():
             model = Foo()
             inputs = (torch.randn(1, 3), torch.randn(2, 1))
             with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
-                ep = export(model, inputs)
+                ep = export(model, inputs, strict=strict)

                 # expecttest only works for the base TestExport class.
                 if self.__class__ != TestExport:
@@ -3727,9 +3745,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             if node.op == "placeholder":
                 self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

-    # retracing doesn't seem to like dataclass registration,
-    # raising a dynamo error in fx_pytree.tree_flatten_spec
-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_dynamic_shapes_builder_pytree(self):
         torch.export.register_dataclass(
             Inp1,
@@ -4396,8 +4412,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         ep = export(Foo(), inputs, dynamic_shapes=shapes)
         ep.module()(torch.randn(6, 3), torch.randn(7, 4))

-    @testing.expectedFailureRetraceability  # T183144629
     @testing.expectedFailureSerDerNonStrict
+    @testing.expectedFailureRetraceability
     def test_map(self):
         class Module(torch.nn.Module):
             def forward(self, xs, y, z):
@@ -4746,11 +4762,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     @testing.expectedFailureSerDer  # we don't save placeholder metadata
     @testing.expectedFailureCppSerDes  # we don't save placeholder metadata
     @testing.expectedFailureSerDerNonStrict
-    @testing.expectedFailureNonStrict
-    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
-    @testing.expectedFailureRetraceabilityNonStrict
-    @testing.expectedFailureLegacyExportNonStrict
     def test_linear_conv(self):
+        strict = True
+
         class MyLinear(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -4771,7 +4785,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
                 x_linear = self.linear(x_conv)
                 return x_linear.cos()

-        ep = export(Foo(), (torch.randn(20, 16, 50, 100),))
+        ep = export(Foo(), (torch.randn(20, 16, 50, 100),), strict=strict)
         for node in ep.graph.nodes:
             if (
                 node.op == "placeholder"
@@ -4780,7 +4794,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             ):
                 self.assertTrue("source_fn_stack" in node.meta)

-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_dynamic_shapes_dataclass(self):
         torch.export.register_dataclass(
             Inp2,
@@ -4808,9 +4822,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
         )

+    @testing.expectedFailureCppSerDes
     def test_export_method(self):
         from torch._export.utils import sync_state, wrap_method

+        strict = True
+
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()
@@ -4835,6 +4852,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             wrap_method(em.foo),
             (ex,),
             dynamic_shapes={"x": (Dim.DYNAMIC,)},
+            strict=strict,
         ).module()

         # ...bar
@@ -4842,6 +4860,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             wrap_method(em.bar),
             (ex,),
             dynamic_shapes=((Dim.DYNAMIC,),),
+            strict=strict,
         ).module()

         if is_serdes_test(self._testMethodName):
@@ -6764,7 +6783,7 @@ def forward(self, b_a_buffer, x):
         ep = export(m, ())
         self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])

-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_preserve_shape_dynamism_for_unused_inputs(self):
         torch.export.register_dataclass(
             Inp3,
@@ -9977,6 +9996,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
         ).run(ep.graph_module.code)

     def test_replace_unbacked_with_very_large_upperbound(self):
+        strict = True
         # beyond 2^53 where python floats lose precision
         VERY_LARGE_INT = 1000000007999999992

@@ -9996,7 +10016,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
             "x": (Dim.AUTO, Dim.STATIC),
             "t": (Dim.STATIC,),
         }
-        ep = export(Model(), inp, dynamic_shapes=spec)
+        ep = export(Model(), inp, dynamic_shapes=spec, strict=strict)
         self.assertTrue(torch.allclose(Model()(*inp), ep.module()(*inp)))

     def test_predispatch_cond(self):
@@ -12002,13 +12022,11 @@ def forward(self, x):
         with self.assertRaises(RuntimeError):
             ep.module()(torch.randn(4, 2))

-    @testing.expectedFailureNonStrict
-    @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
     @testing.expectedFailureSerDer  # T195866111
     @testing.expectedFailureSerDerNonStrict
-    @testing.expectedFailureRetraceabilityNonStrict
-    @testing.expectedFailureLegacyExportNonStrict
     def test_hints_wrapper(self):
+        strict = True
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -12036,7 +12054,7 @@ def forward(self, x):
         x = torch.randn(2, 4)
         y = torch.ones(4)

-        ep_for_training = torch.export.export_for_training(M(), (x, y))
+        ep_for_training = torch.export.export_for_training(M(), (x, y), strict=strict)
         self.assertExpectedInline(
             normalize_gm(
                 ep_for_training.graph_module.print_readable(print_output=False)
@@ -12069,7 +12087,7 @@ class GraphModule(torch.nn.Module):
 """,
         )

-        ep = export(M(), (x, y)).run_decompositions({})
+        ep = export(M(), (x, y), strict=strict).run_decompositions({})
         export_res = ep.module()(x, y)
         ref_res = M()(x, y)
         self.assertEqual(export_res, ref_res)
@@ -12600,30 +12618,6 @@ def forward(self, x):
     return (add, add_1)""",
         )

-    def test_logging_logger(self):
-        logger = logging.getLogger(__name__)
-
-        class M(torch.nn.Module):
-            def forward(self, x):
-                logger.log("start")
-                x1 = x + x
-                logger.debug(x1)
-                x2 = x1 * x1
-                logger.info(1, 2, 3)
-                x3 = x2 + x2
-                return (x1, x3)
-
-        gm = export(M(), (torch.randn(3, 3),)).graph_module
-        self.assertExpectedInline(
-            gm.code.strip(),
-            """\
-def forward(self, x):
-    add = torch.ops.aten.add.Tensor(x, x); x = None
-    mul = torch.ops.aten.mul.Tensor(add, add)
-    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
-    return (add, add_1)""",
-        )
-
     @unittest.skipIf(not TEST_TRANSFORMERS, "No transformers")
     def test_hf_logging_logger(self):
         import transformers
@@ -12666,6 +12660,31 @@ def forward(self, x):
     return (add,)""",
         )

+    def test_logging_logger(self):
+        strict = True
+        logger = logging.getLogger(__name__)
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                logger.log("start")
+                x1 = x + x
+                logger.debug(x1)
+                x2 = x1 * x1
+                logger.info(1, 2, 3)
+                x3 = x2 + x2
+                return (x1, x3)
+
+        gm = export(M(), (torch.randn(3, 3),), strict=strict).graph_module
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, x):
+    add = torch.ops.aten.add.Tensor(x, x); x = None
+    mul = torch.ops.aten.mul.Tensor(add, add)
+    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
+    return (add, add_1)""",
+        )
+
     def test_constant_fqn(self):
         class Nested(torch.nn.Module):
             def __init__(self) -> None:
@@ -12929,19 +12948,20 @@ class TestExportCustomClass(TorchTestCase):

         x, y = torch.randn(3, 2), torch.randn(3, 2)
         mod = Mod()
-        # TODO: strict mode doesn't work because dynamo add_mod is treated as a
-        # user defined variable. We might need to add a CustomModule variable to support it.
-        if self._testMethodName == "test_export_script_module":
-            with self.assertRaisesRegex(
-                torch._dynamo.exc.Unsupported, "UserDefined with non-function"
-            ):
-                ep = export(mod, (x, y))
-        else:
+        if is_non_strict_test(self._testMethodName):
             ep = export(mod, (x, y))
             self.assertEqual(ep.module()(x, y), mod(x, y))
             FileCheck().check_count("torch.ops.aten.add.Tensor", 1, exactly=True).run(
                 ep.graph_module.code
             )
+            return
+
+        # TODO: strict mode doesn't work because dynamo add_mod is treated as a
+        # user defined variable. We might need to add a CustomModule variable to support it.
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported, "UserDefined with non-function"
+        ):
+            ep = export(mod, (x, y))

     def test_preserve_non_cia_op(self):
         class M(torch.nn.Module):
@@ -12,22 +12,22 @@ from torch.export import export
 test_classes = {}


-def mocked_non_strict_export(*args, **kwargs):
-    # If user already specified strict, don't make it non-strict
+def mocked_strict_export(*args, **kwargs):
+    # If user already specified strict, don't make it strict
     if "strict" in kwargs:
         return export(*args, **kwargs)
-    return export(*args, **kwargs, strict=False)
+    return export(*args, **kwargs, strict=True)


 def make_dynamic_cls(cls):
-    cls_prefix = "NonStrictExport"
+    cls_prefix = "StrictExport"

     test_class = testing.make_test_cls_with_mocked_export(
         cls,
         cls_prefix,
-        test_export.NON_STRICT_SUFFIX,
-        mocked_non_strict_export,
-        xfail_prop="_expected_failure_non_strict",
+        test_export.STRICT_SUFFIX,
+        mocked_strict_export,
+        xfail_prop="_expected_failure_strict",
     )

     test_classes[test_class.__name__] = test_class
@@ -13,20 +13,11 @@ test_classes = {}


 def mocked_retraceability_export_strict(*args, **kwargs):
-    ep = export(*args, **kwargs)
-    if "dynamic_shapes" in kwargs:
-        if isinstance(kwargs["dynamic_shapes"], dict):
-            kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
-
-    ep = export(ep.module(), *(args[1:]), **kwargs)
-    return ep
-
-
-def mocked_retraceability_export_non_strict(*args, **kwargs):
     if "strict" in kwargs:
         ep = export(*args, **kwargs)
     else:
-        ep = export(*args, **kwargs, strict=False)
+        ep = export(*args, **kwargs, strict=True)

     if "dynamic_shapes" in kwargs:
         if isinstance(kwargs["dynamic_shapes"], dict):
             kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
@@ -34,7 +25,17 @@ def mocked_retraceability_export_non_strict(*args, **kwargs):
     if "strict" in kwargs:
         ep = export(ep.module(), *(args[1:]), **kwargs)
     else:
-        ep = export(ep.module(), *(args[1:]), **kwargs, strict=False)
+        ep = export(ep.module(), *(args[1:]), **kwargs, strict=True)
+    return ep
+
+
+def mocked_retraceability_export_non_strict(*args, **kwargs):
+    ep = export(*args, **kwargs)
+    if "dynamic_shapes" in kwargs:
+        if isinstance(kwargs["dynamic_shapes"], dict):
+            kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
+
+    ep = export(ep.module(), *(args[1:]), **kwargs)
     return ep
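The retraceability wrappers above effectively export twice: once from the original module and once from the module produced by the first export. As a rough standalone illustration of that round trip (using a throwaway module, not the test infra itself):

import torch
from torch.export import export

class Twice(torch.nn.Module):
    def forward(self, x):
        return x * 2

args = (torch.randn(4),)
ep = export(Twice(), args, strict=True)       # first export, strict as in the mocked wrapper
ep2 = export(ep.module(), args, strict=True)  # re-export the produced module (the "retrace")
print(ep2.module()(*args))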
@@ -16,7 +16,11 @@ test_classes = {}


 def mocked_serder_export_strict(*args, **kwargs):
-    ep = export(*args, **kwargs)
+    if "strict" not in kwargs:
+        ep = export(*args, **kwargs, strict=True)
+    else:
+        ep = export(*args, **kwargs)

     buffer = io.BytesIO()
     save(ep, buffer)
     buffer.seek(0)
@@ -25,10 +29,7 @@ def mocked_serder_export_strict(*args, **kwargs):


 def mocked_serder_export_non_strict(*args, **kwargs):
-    if "strict" in kwargs:
-        ep = export(*args, **kwargs)
-    else:
-        ep = export(*args, **kwargs, strict=False)
+    ep = export(*args, **kwargs)

     buffer = io.BytesIO()
     save(ep, buffer)
     buffer.seek(0)
@@ -41,7 +42,7 @@ def make_dynamic_cls(cls, strict):
     test_class = testing.make_test_cls_with_mocked_export(
         cls,
         "SerDesExport",
-        test_export.SERDES_SUFFIX,
+        test_export.SERDES_STRICT_SUFFIX,
         mocked_serder_export_strict,
         xfail_prop="_expected_failure_serdes",
     )
@@ -257,9 +257,9 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
     return fn


-# Controls tests generated in test/export/test_export_nonstrict.py
-def expectedFailureNonStrict(fn):
-    fn._expected_failure_non_strict = True
+# Controls tests generated in test/export/test_export_strict.py
+def expectedFailureStrict(fn):
+    fn._expected_failure_strict = True
     return fn

@@ -307,6 +307,11 @@ def expectedFailureCppRuntime(fn):
     return fn


+def expectedFailureCppRuntimeNonStrict(fn):
+    fn._expected_failure_cpp_runtime_non_strict = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_legacy.py
 def expectedFailureLegacyExportStrict(fn):
     fn._expected_failure_legacy_export = True
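For context, the decorators above and the xfail_prop arguments in the generated-test files fit together roughly as sketched below; make_strict_variant is a hypothetical stand-in for the real make_test_cls_with_mocked_export machinery, shown only to illustrate the wiring:

import unittest

def expectedFailureStrict(fn):
    # Mark the test for the strict-mode variant, as the decorator above does.
    fn._expected_failure_strict = True
    return fn

def make_strict_variant(test_fn, xfail_prop="_expected_failure_strict"):
    # Hypothetical helper: wrap a marked test as an expected failure, the way
    # the generated StrictExport* classes consume xfail_prop.
    if getattr(test_fn, xfail_prop, False):
        return unittest.expectedFailure(test_fn)
    return test_fn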
@@ -258,7 +258,7 @@ def export(
     kwargs: Optional[dict[str, Any]] = None,
     *,
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
-    strict: bool = True,
+    strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """
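Closing note: on a build that includes this change, the flipped default is visible directly on the public signature; a quick sanity check (assuming such a build is installed):

import inspect
import torch

sig = inspect.signature(torch.export.export)
print(sig.parameters["strict"].default)  # False after this change (was True)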