Set non-strict export as default mode (#148790)

Summary:
- Flip the default value of the strict argument in torch.export.export from True to False (see the sketch below)
- Update test infra to cope with the change; some tests assumed strict mode was the default
- Disable some tests that fail in non-strict mode
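
A minimal sketch (not part of this commit) of what the flipped default means for callers; the toy module and inputs below are illustrative:

    import torch

    class M(torch.nn.Module):
        def forward(self, x):
            return x + 1

    args = (torch.randn(3),)

    # After this change, omitting strict traces in non-strict mode (strict=False).
    ep_default = torch.export.export(M(), args)

    # Callers that relied on the old default must now opt in to strict tracing.
    ep_strict = torch.export.export(M(), args, strict=True)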

Test Plan: Sandcastle

Differential Revision: D70228628

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148790
Approved by: https://github.com/angelayi
Yanan Cao (PyTorch) 2025-03-12 21:10:54 +00:00 committed by PyTorch MergeBot
parent e3ebf61589
commit ab45aaca97
7 changed files with 121 additions and 92 deletions

View File

@@ -40,6 +40,7 @@ class ExampleTests(TestCase):
args_export,
kwargs_export,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
exported_program.graph_module.print_readable()
@@ -72,6 +73,7 @@ class ExampleTests(TestCase):
case.example_args,
case.example_kwargs,
dynamic_shapes=case.dynamic_shapes,
strict=True,
)
exportdb_not_supported_rewrite_cases = [

View File

@@ -199,20 +199,25 @@ class Inp3:
p: torch.Tensor
NON_STRICT_SUFFIX = "_non_strict"
RETRACEABILITY_STRICT_SUFFIX = "_retraceability"
RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict"
SERDES_SUFFIX = "_serdes"
SERDES_NON_STRICT_SUFFIX = "_serdes_non_strict"
NON_STRICT_SUFFIX = "_nonstrict"
STRICT_SUFFIX = "_strict"
RETRACEABILITY_STRICT_SUFFIX = "_retraceability_strict"
RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_nonstrict"
SERDES_STRICT_SUFFIX = "_serdes_strict"
SERDES_NON_STRICT_SUFFIX = "_serdes_nonstrict"
PREDISPATCH_SUFFIX = "_pre_dispatch"
TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"
TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
LEGACY_EXPORT_STRICT_SUFFIX = "_legacy_export_strict"
LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_non_strict"
LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_nonstrict"
CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"
# Now the default mode is non-strict, so original unamended test names
# should be treated as non-strict
def is_non_strict_test(test_name):
return test_name.endswith(NON_STRICT_SUFFIX)
return not test_name.endswith(STRICT_SUFFIX)
def is_non_strict_legacy_test(test_name):
@@ -226,7 +231,7 @@ def is_retracebility_test(test_name):
def is_serdes_test(test_name):
return test_name.endswith(SERDES_SUFFIX) or test_name.endswith(
return test_name.endswith(SERDES_STRICT_SUFFIX) or test_name.endswith(
SERDES_NON_STRICT_SUFFIX
)
@@ -237,6 +242,12 @@ def is_training_ir_test(test_name):
)
def is_cpp_runtime_test(test_name):
return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith(
CPP_RUNTIME_NONSTRICT_SUFFIX
)
def get_hop_schema(ep: torch.export.ExportedProgram):
hop_node = next(
node
@@ -782,6 +793,8 @@ graph():
self.assertEqual(exp_out, ep.module()(*args))
@requires_gpu
@testing.expectedFailureLegacyExportNonStrict # Old export graph contains auto_functionalize not Triton wrapper
@testing.expectedFailureLegacyExportStrict # Old export graph contains auto_functionalize not Triton wrapper
def test_export_custom_triton_kernel_mutable(self):
@triton.jit
def add_kernel(
@@ -1196,6 +1209,8 @@ graph():
self.assertEqual(orig_res, ep_res)
def test_unbacked_to_cond(self):
strict = True
class M(torch.nn.Module):
def forward(self, a):
az = a.nonzero()
@@ -1210,9 +1225,11 @@ graph():
return r * 2
M()(torch.randn(7))
torch.export.export(M(), (torch.randn(7),))
torch.export.export(M(), (torch.randn(7),), strict=strict)
def test_unbacked_to_cond_passthrough(self):
strict = True
class M(torch.nn.Module):
def forward(self, a):
az = a.nonzero()
@@ -1227,7 +1244,7 @@ graph():
return r * 2
M()(torch.randn(7))
torch.export.export(M(), (torch.randn(7),))
torch.export.export(M(), (torch.randn(7),), strict=strict)
@torch._dynamo.config.patch(capture_scalar_outputs=True)
def test_cond_contains_unbacked_no_escape(self):
@@ -1752,6 +1769,7 @@ graph():
# Bug: ep.run_decompositions() doesn't propagate real tensors
@testing.expectedFailureTrainingIRToRunDecompNonStrict
def test_draft_export_infers_fake_kernel(self):
strict = True
with torch.library._scoped_library("export", "FRAGMENT") as lib:
lib.define("bar(Tensor x) -> Tensor")
lib.impl("bar", lambda x: x[0].clone(), "CPU")
@@ -1767,7 +1785,7 @@ graph():
model = Foo()
inputs = (torch.randn(1, 3), torch.randn(2, 1))
with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
ep = export(model, inputs)
ep = export(model, inputs, strict=strict)
# expecttest only works for the base TestExport class.
if self.__class__ != TestExport:
@@ -3727,9 +3745,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
if node.op == "placeholder":
self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")
# retracing doesn't seem to like dataclass registration,
# raising a dynamo error in fx_pytree.tree_flatten_spec
@testing.expectedFailureRetraceability # T186979579
@testing.expectedFailureRetraceability
def test_dynamic_shapes_builder_pytree(self):
torch.export.register_dataclass(
Inp1,
@@ -4396,8 +4412,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
ep = export(Foo(), inputs, dynamic_shapes=shapes)
ep.module()(torch.randn(6, 3), torch.randn(7, 4))
@testing.expectedFailureRetraceability # T183144629
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureRetraceability
def test_map(self):
class Module(torch.nn.Module):
def forward(self, xs, y, z):
@@ -4746,11 +4762,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
@testing.expectedFailureSerDer # we don't save placeholder metadata
@testing.expectedFailureCppSerDes # we don't save placeholder metadata
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureNonStrict
@testing.expectedFailureTrainingIRToRunDecompNonStrict # source_fn_stack failure
@testing.expectedFailureRetraceabilityNonStrict
@testing.expectedFailureLegacyExportNonStrict
def test_linear_conv(self):
strict = True
class MyLinear(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -4771,7 +4785,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
x_linear = self.linear(x_conv)
return x_linear.cos()
ep = export(Foo(), (torch.randn(20, 16, 50, 100),))
ep = export(Foo(), (torch.randn(20, 16, 50, 100),), strict=strict)
for node in ep.graph.nodes:
if (
node.op == "placeholder"
@@ -4780,7 +4794,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
):
self.assertTrue("source_fn_stack" in node.meta)
@testing.expectedFailureRetraceability # T186979579
@testing.expectedFailureRetraceability
def test_dynamic_shapes_dataclass(self):
torch.export.register_dataclass(
Inp2,
@@ -4808,9 +4822,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
)
@testing.expectedFailureCppSerDes
def test_export_method(self):
from torch._export.utils import sync_state, wrap_method
strict = True
class M(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -4835,6 +4852,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
wrap_method(em.foo),
(ex,),
dynamic_shapes={"x": (Dim.DYNAMIC,)},
strict=strict,
).module()
# ...bar
@@ -4842,6 +4860,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
wrap_method(em.bar),
(ex,),
dynamic_shapes=((Dim.DYNAMIC,),),
strict=strict,
).module()
if is_serdes_test(self._testMethodName):
@@ -6764,7 +6783,7 @@ def forward(self, b_a_buffer, x):
ep = export(m, ())
self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])
@testing.expectedFailureRetraceability # T186979579
@testing.expectedFailureRetraceability
def test_preserve_shape_dynamism_for_unused_inputs(self):
torch.export.register_dataclass(
Inp3,
@@ -9977,6 +9996,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
).run(ep.graph_module.code)
def test_replace_unbacked_with_very_large_upperbound(self):
strict = True
# beyond 2^53 where python floats lose precision
VERY_LARGE_INT = 1000000007999999992
@@ -9996,7 +10016,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
"x": (Dim.AUTO, Dim.STATIC),
"t": (Dim.STATIC,),
}
ep = export(Model(), inp, dynamic_shapes=spec)
ep = export(Model(), inp, dynamic_shapes=spec, strict=strict)
self.assertTrue(torch.allclose(Model()(*inp), ep.module()(*inp)))
def test_predispatch_cond(self):
@@ -12002,13 +12022,11 @@ def forward(self, x):
with self.assertRaises(RuntimeError):
ep.module()(torch.randn(4, 2))
@testing.expectedFailureNonStrict
@testing.expectedFailureTrainingIRToRunDecompNonStrict # unbacked symint not tracked?
@testing.expectedFailureSerDer # T195866111
@testing.expectedFailureSerDerNonStrict
@testing.expectedFailureRetraceabilityNonStrict
@testing.expectedFailureLegacyExportNonStrict
def test_hints_wrapper(self):
strict = True
class M(torch.nn.Module):
def __init__(self) -> None:
super().__init__()
@@ -12036,7 +12054,7 @@ def forward(self, x):
x = torch.randn(2, 4)
y = torch.ones(4)
ep_for_training = torch.export.export_for_training(M(), (x, y))
ep_for_training = torch.export.export_for_training(M(), (x, y), strict=strict)
self.assertExpectedInline(
normalize_gm(
ep_for_training.graph_module.print_readable(print_output=False)
@@ -12069,7 +12087,7 @@ class GraphModule(torch.nn.Module):
""",
)
ep = export(M(), (x, y)).run_decompositions({})
ep = export(M(), (x, y), strict=strict).run_decompositions({})
export_res = ep.module()(x, y)
ref_res = M()(x, y)
self.assertEqual(export_res, ref_res)
@@ -12600,30 +12618,6 @@ def forward(self, x):
return (add, add_1)""",
)
def test_logging_logger(self):
logger = logging.getLogger(__name__)
class M(torch.nn.Module):
def forward(self, x):
logger.log("start")
x1 = x + x
logger.debug(x1)
x2 = x1 * x1
logger.info(1, 2, 3)
x3 = x2 + x2
return (x1, x3)
gm = export(M(), (torch.randn(3, 3),)).graph_module
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x):
add = torch.ops.aten.add.Tensor(x, x); x = None
mul = torch.ops.aten.mul.Tensor(add, add)
add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
return (add, add_1)""",
)
@unittest.skipIf(not TEST_TRANSFORMERS, "No transformers")
def test_hf_logging_logger(self):
import transformers
@@ -12666,6 +12660,31 @@ def forward(self, x):
return (add,)""",
)
def test_logging_logger(self):
strict = True
logger = logging.getLogger(__name__)
class M(torch.nn.Module):
def forward(self, x):
logger.log("start")
x1 = x + x
logger.debug(x1)
x2 = x1 * x1
logger.info(1, 2, 3)
x3 = x2 + x2
return (x1, x3)
gm = export(M(), (torch.randn(3, 3),), strict=strict).graph_module
self.assertExpectedInline(
gm.code.strip(),
"""\
def forward(self, x):
add = torch.ops.aten.add.Tensor(x, x); x = None
mul = torch.ops.aten.mul.Tensor(add, add)
add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
return (add, add_1)""",
)
def test_constant_fqn(self):
class Nested(torch.nn.Module):
def __init__(self) -> None:
@@ -12929,19 +12948,20 @@ class TestExportCustomClass(TorchTestCase):
x, y = torch.randn(3, 2), torch.randn(3, 2)
mod = Mod()
# TODO: strict mode doesn't work because dynamo add_mod is treated as a
# user defined variable. We might need to add a CustomModule variable to support it.
if self._testMethodName == "test_export_script_module":
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "UserDefined with non-function"
):
ep = export(mod, (x, y))
else:
if is_non_strict_test(self._testMethodName):
ep = export(mod, (x, y))
self.assertEqual(ep.module()(x, y), mod(x, y))
FileCheck().check_count("torch.ops.aten.add.Tensor", 1, exactly=True).run(
ep.graph_module.code
)
return
# TODO: strict mode doesn't work because dynamo add_mod is treated as a
# user defined variable. We might need to add a CustomModule variable to support it.
with self.assertRaisesRegex(
torch._dynamo.exc.Unsupported, "UserDefined with non-function"
):
ep = export(mod, (x, y))
def test_preserve_non_cia_op(self):
class M(torch.nn.Module):

View File

@@ -12,22 +12,22 @@ from torch.export import export
test_classes = {}
def mocked_non_strict_export(*args, **kwargs):
# If user already specified strict, don't make it non-strict
def mocked_strict_export(*args, **kwargs):
# If user already specified strict, don't make it strict
if "strict" in kwargs:
return export(*args, **kwargs)
return export(*args, **kwargs, strict=False)
return export(*args, **kwargs, strict=True)
def make_dynamic_cls(cls):
cls_prefix = "NonStrictExport"
cls_prefix = "StrictExport"
test_class = testing.make_test_cls_with_mocked_export(
cls,
cls_prefix,
test_export.NON_STRICT_SUFFIX,
mocked_non_strict_export,
xfail_prop="_expected_failure_non_strict",
test_export.STRICT_SUFFIX,
mocked_strict_export,
xfail_prop="_expected_failure_strict",
)
test_classes[test_class.__name__] = test_class

View File

@@ -13,20 +13,11 @@ test_classes = {}
def mocked_retraceability_export_strict(*args, **kwargs):
ep = export(*args, **kwargs)
if "dynamic_shapes" in kwargs:
if isinstance(kwargs["dynamic_shapes"], dict):
kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
ep = export(ep.module(), *(args[1:]), **kwargs)
return ep
def mocked_retraceability_export_non_strict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=False)
ep = export(*args, **kwargs, strict=True)
if "dynamic_shapes" in kwargs:
if isinstance(kwargs["dynamic_shapes"], dict):
kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
@@ -34,7 +25,17 @@ def mocked_retraceability_export_non_strict(*args, **kwargs):
if "strict" in kwargs:
ep = export(ep.module(), *(args[1:]), **kwargs)
else:
ep = export(ep.module(), *(args[1:]), **kwargs, strict=False)
ep = export(ep.module(), *(args[1:]), **kwargs, strict=True)
return ep
def mocked_retraceability_export_non_strict(*args, **kwargs):
ep = export(*args, **kwargs)
if "dynamic_shapes" in kwargs:
if isinstance(kwargs["dynamic_shapes"], dict):
kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())
ep = export(ep.module(), *(args[1:]), **kwargs)
return ep

View File

@@ -16,7 +16,11 @@ test_classes = {}
def mocked_serder_export_strict(*args, **kwargs):
ep = export(*args, **kwargs)
if "strict" not in kwargs:
ep = export(*args, **kwargs, strict=True)
else:
ep = export(*args, **kwargs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
@@ -25,10 +29,7 @@ def mocked_serder_export_strict(*args, **kwargs):
def mocked_serder_export_non_strict(*args, **kwargs):
if "strict" in kwargs:
ep = export(*args, **kwargs)
else:
ep = export(*args, **kwargs, strict=False)
ep = export(*args, **kwargs)
buffer = io.BytesIO()
save(ep, buffer)
buffer.seek(0)
@@ -41,7 +42,7 @@ def make_dynamic_cls(cls, strict):
test_class = testing.make_test_cls_with_mocked_export(
cls,
"SerDesExport",
test_export.SERDES_SUFFIX,
test_export.SERDES_STRICT_SUFFIX,
mocked_serder_export_strict,
xfail_prop="_expected_failure_serdes",
)

View File

@@ -257,9 +257,9 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
return fn
# Controls tests generated in test/export/test_export_nonstrict.py
def expectedFailureNonStrict(fn):
fn._expected_failure_non_strict = True
# Controls tests generated in test/export/test_export_strict.py
def expectedFailureStrict(fn):
fn._expected_failure_strict = True
return fn
@@ -307,6 +307,11 @@ def expectedFailureCppRuntime(fn):
return fn
def expectedFailureCppRuntimeNonStrict(fn):
fn._expected_failure_cpp_runtime_non_strict = True
return fn
# Controls tests generated in test/export/test_export_legacy.py
def expectedFailureLegacyExportStrict(fn):
fn._expected_failure_legacy_export = True

View File

@@ -258,7 +258,7 @@ def export(
kwargs: Optional[dict[str, Any]] = None,
*,
dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
strict: bool = True,
strict: bool = False,
preserve_module_call_signature: tuple[str, ...] = (),
) -> ExportedProgram:
"""