Mirror of https://github.com/zebrajr/pytorch.git
Synced 2025-12-06 12:20:52 +01:00
Set non-strict export as default mode (#148790)
Summary:
- Flip the default value of the strict argument in torch.export.export from True to False.
- Update the test infra to cope with the change; some of it assumed strict mode as the default.
- Disable some tests that fail in non-strict mode.

Test Plan: Sandcastle

Differential Revision: D70228628
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148790
Approved by: https://github.com/angelayi
This commit is contained in:
parent
e3ebf61589
commit
ab45aaca97
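The flip described in the summary changes what callers get when they omit strict. A minimal sketch of the before/after behavior, with an illustrative toy module (not from the patch):

import torch

class M(torch.nn.Module):
    def forward(self, x):
        return x + 1

m, args = M(), (torch.randn(3),)

# After this commit, omitting strict is equivalent to strict=False:
# the module is traced in non-strict mode, without Dynamo graph capture.
ep_default = torch.export.export(m, args)

# Callers that relied on the old strict default must now opt in explicitly.
ep_strict = torch.export.export(m, args, strict=True)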
@@ -40,6 +40,7 @@ class ExampleTests(TestCase):
                 args_export,
                 kwargs_export,
                 dynamic_shapes=case.dynamic_shapes,
+                strict=True,
             )
             exported_program.graph_module.print_readable()

@@ -72,6 +73,7 @@ class ExampleTests(TestCase):
                 case.example_args,
                 case.example_kwargs,
                 dynamic_shapes=case.dynamic_shapes,
+                strict=True,
             )

 exportdb_not_supported_rewrite_cases = [

@@ -199,20 +199,25 @@ class Inp3:
     p: torch.Tensor


-NON_STRICT_SUFFIX = "_non_strict"
-RETRACEABILITY_STRICT_SUFFIX = "_retraceability"
-RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_non_strict"
-SERDES_SUFFIX = "_serdes"
-SERDES_NON_STRICT_SUFFIX = "_serdes_non_strict"
+NON_STRICT_SUFFIX = "_nonstrict"
+STRICT_SUFFIX = "_strict"
+RETRACEABILITY_STRICT_SUFFIX = "_retraceability_strict"
+RETRACEABILITY_NON_STRICT_SUFFIX = "_retraceability_nonstrict"
+SERDES_STRICT_SUFFIX = "_serdes_strict"
+SERDES_NON_STRICT_SUFFIX = "_serdes_nonstrict"
 PREDISPATCH_SUFFIX = "_pre_dispatch"
-TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp"
-TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_non_strict"
+TRAINING_IR_DECOMP_STRICT_SUFFIX = "_training_ir_to_decomp_strict"
+TRAINING_IR_DECOMP_NON_STRICT_SUFFIX = "_training_ir_to_decomp_nonstrict"
 LEGACY_EXPORT_STRICT_SUFFIX = "_legacy_export_strict"
-LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_non_strict"
+LEGACY_EXPORT_NONSTRICT_SUFFIX = "_legacy_export_nonstrict"
 CPP_RUNTIME_STRICT_SUFFIX = "_cpp_runtime_strict"
 CPP_RUNTIME_NONSTRICT_SUFFIX = "_cpp_runtime_nonstrict"


+# Now the default mode is non-strict, so original unamended test names
+# should be treated as non-strict.
 def is_non_strict_test(test_name):
-    return test_name.endswith(NON_STRICT_SUFFIX)
+    return not test_name.endswith(STRICT_SUFFIX)


 def is_non_strict_legacy_test(test_name):

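Under the renamed suffixes, any generated test name without an explicit _strict suffix is now classified as non-strict, mirroring the new default. A small sketch of that convention, using made-up test names:

STRICT_SUFFIX = "_strict"

def is_non_strict_test(test_name):
    # Anything not explicitly marked strict is treated as non-strict,
    # including original, unamended test names.
    return not test_name.endswith(STRICT_SUFFIX)

assert is_non_strict_test("test_linear_conv")             # unamended name
assert is_non_strict_test("test_linear_conv_nonstrict")   # explicit non-strict
assert not is_non_strict_test("test_linear_conv_strict")  # explicit strict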
@@ -226,7 +231,7 @@ def is_retracebility_test(test_name):


 def is_serdes_test(test_name):
-    return test_name.endswith(SERDES_SUFFIX) or test_name.endswith(
+    return test_name.endswith(SERDES_STRICT_SUFFIX) or test_name.endswith(
         SERDES_NON_STRICT_SUFFIX
     )

@@ -237,6 +242,12 @@ def is_training_ir_test(test_name):
     )


+def is_cpp_runtime_test(test_name):
+    return test_name.endswith(CPP_RUNTIME_STRICT_SUFFIX) or test_name.endswith(
+        CPP_RUNTIME_NONSTRICT_SUFFIX
+    )
+
+
 def get_hop_schema(ep: torch.export.ExportedProgram):
     hop_node = next(
         node

@@ -782,6 +793,8 @@ graph():
         self.assertEqual(exp_out, ep.module()(*args))

     @requires_gpu
+    @testing.expectedFailureLegacyExportNonStrict  # Old export graph contains auto_functionalize not Triton wrapper
+    @testing.expectedFailureLegacyExportStrict  # Old export graph contains auto_functionalize not Triton wrapper
     def test_export_custom_triton_kernel_mutable(self):
         @triton.jit
         def add_kernel(

@@ -1196,6 +1209,8 @@ graph():
         self.assertEqual(orig_res, ep_res)

     def test_unbacked_to_cond(self):
+        strict = True
+
         class M(torch.nn.Module):
             def forward(self, a):
                 az = a.nonzero()

@@ -1210,9 +1225,11 @@ graph():
                 return r * 2

         M()(torch.randn(7))
-        torch.export.export(M(), (torch.randn(7),))
+        torch.export.export(M(), (torch.randn(7),), strict=strict)

     def test_unbacked_to_cond_passthrough(self):
+        strict = True
+
         class M(torch.nn.Module):
             def forward(self, a):
                 az = a.nonzero()

@@ -1227,7 +1244,7 @@ graph():
                 return r * 2

         M()(torch.randn(7))
-        torch.export.export(M(), (torch.randn(7),))
+        torch.export.export(M(), (torch.randn(7),), strict=strict)

     @torch._dynamo.config.patch(capture_scalar_outputs=True)
     def test_cond_contains_unbacked_no_escape(self):

@@ -1752,6 +1769,7 @@ graph():
     # Bug: ep.run_decompositions() doesn't propagate real tensors
     @testing.expectedFailureTrainingIRToRunDecompNonStrict
     def test_draft_export_infers_fake_kernel(self):
+        strict = True
         with torch.library._scoped_library("export", "FRAGMENT") as lib:
             lib.define("bar(Tensor x) -> Tensor")
             lib.impl("bar", lambda x: x[0].clone(), "CPU")

@@ -1767,7 +1785,7 @@ graph():
         model = Foo()
         inputs = (torch.randn(1, 3), torch.randn(2, 1))
         with torch._functorch.config.patch(fake_tensor_propagate_real_tensors=True):
-            ep = export(model, inputs)
+            ep = export(model, inputs, strict=strict)

         # expecttest only works for the base TestExport class.
         if self.__class__ != TestExport:

@@ -3727,9 +3745,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             if node.op == "placeholder":
                 self.assertEqual(str(tuple(node.meta["val"].shape)), f"({sym},)")

-    # retracing doesn't seem to like dataclass registration,
-    # raising a dynamo error in fx_pytree.tree_flatten_spec
-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_dynamic_shapes_builder_pytree(self):
         torch.export.register_dataclass(
             Inp1,

@@ -4396,8 +4412,8 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         ep = export(Foo(), inputs, dynamic_shapes=shapes)
         ep.module()(torch.randn(6, 3), torch.randn(7, 4))

-    @testing.expectedFailureRetraceability  # T183144629
     @testing.expectedFailureSerDerNonStrict
+    @testing.expectedFailureRetraceability
     def test_map(self):
         class Module(torch.nn.Module):
             def forward(self, xs, y, z):

@@ -4746,11 +4762,9 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
     @testing.expectedFailureSerDer  # we don't save placeholder metadata
     @testing.expectedFailureCppSerDes  # we don't save placeholder metadata
-    @testing.expectedFailureSerDerNonStrict
-    @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecompNonStrict  # source_fn_stack failure
-    @testing.expectedFailureRetraceabilityNonStrict
-    @testing.expectedFailureLegacyExportNonStrict
     def test_linear_conv(self):
+        strict = True
+
         class MyLinear(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()

@@ -4771,7 +4785,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             x_linear = self.linear(x_conv)
             return x_linear.cos()

-        ep = export(Foo(), (torch.randn(20, 16, 50, 100),))
+        ep = export(Foo(), (torch.randn(20, 16, 50, 100),), strict=strict)
         for node in ep.graph.nodes:
             if (
                 node.op == "placeholder"

@@ -4780,7 +4794,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
         ):
             self.assertTrue("source_fn_stack" in node.meta)

-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_dynamic_shapes_dataclass(self):
         torch.export.register_dataclass(
             Inp2,

@@ -4808,9 +4822,12 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             ["torch.Size([s0, 2, 3])", "torch.Size([s0, 3, 4])"],
         )

+    @testing.expectedFailureCppSerDes
     def test_export_method(self):
         from torch._export.utils import sync_state, wrap_method

+        strict = True
+
         class M(torch.nn.Module):
             def __init__(self):
                 super().__init__()

@@ -4835,6 +4852,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             wrap_method(em.foo),
             (ex,),
             dynamic_shapes={"x": (Dim.DYNAMIC,)},
+            strict=strict,
         ).module()

         # ...bar

@@ -4842,6 +4860,7 @@ def forward(self, p_linear_weight, p_linear_bias, b_buffer, x):
             wrap_method(em.bar),
             (ex,),
             dynamic_shapes=((Dim.DYNAMIC,),),
+            strict=strict,
         ).module()

         if is_serdes_test(self._testMethodName):

@@ -6764,7 +6783,7 @@ def forward(self, b_a_buffer, x):
         ep = export(m, ())
         self.assertEqual(ep.graph_signature.lifted_tensor_constants, ["x"])

-    @testing.expectedFailureRetraceability  # T186979579
+    @testing.expectedFailureRetraceability
     def test_preserve_shape_dynamism_for_unused_inputs(self):
         torch.export.register_dataclass(
             Inp3,

@@ -9977,6 +9996,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
         ).run(ep.graph_module.code)

     def test_replace_unbacked_with_very_large_upperbound(self):
+        strict = True
         # beyond 2^53 where python floats lose precision
         VERY_LARGE_INT = 1000000007999999992

@@ -9996,7 +10016,7 @@ def forward(self, p_bar_linear_weight, p_bar_linear_bias, x):
             "x": (Dim.AUTO, Dim.STATIC),
             "t": (Dim.STATIC,),
         }
-        ep = export(Model(), inp, dynamic_shapes=spec)
+        ep = export(Model(), inp, dynamic_shapes=spec, strict=strict)
         self.assertTrue(torch.allclose(Model()(*inp), ep.module()(*inp)))

     def test_predispatch_cond(self):

@@ -12002,13 +12022,11 @@ def forward(self, x):
         with self.assertRaises(RuntimeError):
             ep.module()(torch.randn(4, 2))

-    @testing.expectedFailureNonStrict
     @testing.expectedFailureTrainingIRToRunDecompNonStrict  # unbacked symint not tracked?
     @testing.expectedFailureSerDer  # T195866111
-    @testing.expectedFailureSerDerNonStrict
-    @testing.expectedFailureRetraceabilityNonStrict
-    @testing.expectedFailureLegacyExportNonStrict
     def test_hints_wrapper(self):
+        strict = True
+
         class M(torch.nn.Module):
             def __init__(self) -> None:
                 super().__init__()

@@ -12036,7 +12054,7 @@ def forward(self, x):
         x = torch.randn(2, 4)
         y = torch.ones(4)

-        ep_for_training = torch.export.export_for_training(M(), (x, y))
+        ep_for_training = torch.export.export_for_training(M(), (x, y), strict=strict)
         self.assertExpectedInline(
             normalize_gm(
                 ep_for_training.graph_module.print_readable(print_output=False)

@@ -12069,7 +12087,7 @@ class GraphModule(torch.nn.Module):
             """,
         )

-        ep = export(M(), (x, y)).run_decompositions({})
+        ep = export(M(), (x, y), strict=strict).run_decompositions({})
         export_res = ep.module()(x, y)
         ref_res = M()(x, y)
         self.assertEqual(export_res, ref_res)

@@ -12600,30 +12618,6 @@ def forward(self, x):
     return (add, add_1)""",
         )

-    def test_logging_logger(self):
-        logger = logging.getLogger(__name__)
-
-        class M(torch.nn.Module):
-            def forward(self, x):
-                logger.log("start")
-                x1 = x + x
-                logger.debug(x1)
-                x2 = x1 * x1
-                logger.info(1, 2, 3)
-                x3 = x2 + x2
-                return (x1, x3)
-
-        gm = export(M(), (torch.randn(3, 3),)).graph_module
-        self.assertExpectedInline(
-            gm.code.strip(),
-            """\
-def forward(self, x):
-    add = torch.ops.aten.add.Tensor(x, x); x = None
-    mul = torch.ops.aten.mul.Tensor(add, add)
-    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
-    return (add, add_1)""",
-        )
-
     @unittest.skipIf(not TEST_TRANSFORMERS, "No transformers")
     def test_hf_logging_logger(self):
         import transformers

@@ -12666,6 +12660,31 @@ def forward(self, x):
     return (add,)""",
         )

+    def test_logging_logger(self):
+        strict = True
+        logger = logging.getLogger(__name__)
+
+        class M(torch.nn.Module):
+            def forward(self, x):
+                logger.log("start")
+                x1 = x + x
+                logger.debug(x1)
+                x2 = x1 * x1
+                logger.info(1, 2, 3)
+                x3 = x2 + x2
+                return (x1, x3)
+
+        gm = export(M(), (torch.randn(3, 3),), strict=strict).graph_module
+        self.assertExpectedInline(
+            gm.code.strip(),
+            """\
+def forward(self, x):
+    add = torch.ops.aten.add.Tensor(x, x); x = None
+    mul = torch.ops.aten.mul.Tensor(add, add)
+    add_1 = torch.ops.aten.add.Tensor(mul, mul); mul = None
+    return (add, add_1)""",
+        )
+
     def test_constant_fqn(self):
         class Nested(torch.nn.Module):
             def __init__(self) -> None:

@@ -12929,19 +12948,20 @@ class TestExportCustomClass(TorchTestCase):

         x, y = torch.randn(3, 2), torch.randn(3, 2)
         mod = Mod()
-        # TODO: strict mode doesn't work because dynamo add_mod is treated as a
-        # user defined variable. We might need to add a CustomModule variable to support it.
-        if self._testMethodName == "test_export_script_module":
-            with self.assertRaisesRegex(
-                torch._dynamo.exc.Unsupported, "UserDefined with non-function"
-            ):
-                ep = export(mod, (x, y))
-        else:
+        if is_non_strict_test(self._testMethodName):
             ep = export(mod, (x, y))
             self.assertEqual(ep.module()(x, y), mod(x, y))
             FileCheck().check_count("torch.ops.aten.add.Tensor", 1, exactly=True).run(
                 ep.graph_module.code
             )
+            return
+
+        # TODO: strict mode doesn't work because dynamo add_mod is treated as a
+        # user defined variable. We might need to add a CustomModule variable to support it.
+        with self.assertRaisesRegex(
+            torch._dynamo.exc.Unsupported, "UserDefined with non-function"
+        ):
+            ep = export(mod, (x, y))

     def test_preserve_non_cia_op(self):
         class M(torch.nn.Module):

@@ -12,22 +12,22 @@ from torch.export import export
 test_classes = {}


-def mocked_non_strict_export(*args, **kwargs):
-    # If user already specified strict, don't make it non-strict
+def mocked_strict_export(*args, **kwargs):
+    # If user already specified strict, don't make it strict
     if "strict" in kwargs:
         return export(*args, **kwargs)
-    return export(*args, **kwargs, strict=False)
+    return export(*args, **kwargs, strict=True)


 def make_dynamic_cls(cls):
-    cls_prefix = "NonStrictExport"
+    cls_prefix = "StrictExport"

     test_class = testing.make_test_cls_with_mocked_export(
         cls,
         cls_prefix,
-        test_export.NON_STRICT_SUFFIX,
-        mocked_non_strict_export,
-        xfail_prop="_expected_failure_non_strict",
+        test_export.STRICT_SUFFIX,
+        mocked_strict_export,
+        xfail_prop="_expected_failure_strict",
     )

     test_classes[test_class.__name__] = test_class

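The mocked export only injects strict=True when the test did not pick a mode itself, so a test that pins strict=False stays non-strict even inside the generated StrictExport* classes. A runnable sketch of that dispatch, with an illustrative toy module:

import torch
from torch.export import export

class M(torch.nn.Module):
    def forward(self, x):
        return x * 2

def mocked_strict_export(*args, **kwargs):
    # Respect an explicit choice; otherwise force strict mode.
    if "strict" in kwargs:
        return export(*args, **kwargs)
    return export(*args, **kwargs, strict=True)

ep_forced = mocked_strict_export(M(), (torch.randn(2),))                # strict
ep_pinned = mocked_strict_export(M(), (torch.randn(2),), strict=False)  # non-strict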
@@ -13,28 +13,29 @@ test_classes = {}


def mocked_retraceability_export_strict(*args, **kwargs):
    if "strict" in kwargs:
        ep = export(*args, **kwargs)
    else:
        ep = export(*args, **kwargs, strict=True)

    if "dynamic_shapes" in kwargs:
        if isinstance(kwargs["dynamic_shapes"], dict):
            kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())

    if "strict" in kwargs:
        ep = export(ep.module(), *(args[1:]), **kwargs)
    else:
        ep = export(ep.module(), *(args[1:]), **kwargs, strict=True)
    return ep


def mocked_retraceability_export_non_strict(*args, **kwargs):
    if "strict" in kwargs:
        ep = export(*args, **kwargs)
    else:
        ep = export(*args, **kwargs, strict=False)
    if "dynamic_shapes" in kwargs:
        if isinstance(kwargs["dynamic_shapes"], dict):
            kwargs["dynamic_shapes"] = tuple(kwargs["dynamic_shapes"].values())

    if "strict" in kwargs:
        ep = export(ep.module(), *(args[1:]), **kwargs)
    else:
        ep = export(ep.module(), *(args[1:]), **kwargs, strict=False)
    return ep

@@ -16,7 +16,11 @@ test_classes = {}


def mocked_serder_export_strict(*args, **kwargs):
    if "strict" not in kwargs:
        ep = export(*args, **kwargs, strict=True)
    else:
        ep = export(*args, **kwargs)

    buffer = io.BytesIO()
    save(ep, buffer)
    buffer.seek(0)

@@ -25,10 +29,7 @@ def mocked_serder_export_strict(*args, **kwargs):


def mocked_serder_export_non_strict(*args, **kwargs):
    if "strict" in kwargs:
        ep = export(*args, **kwargs)
    else:
        ep = export(*args, **kwargs, strict=False)
    buffer = io.BytesIO()
    save(ep, buffer)
    buffer.seek(0)

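Both serdes wrappers round-trip the exported program through torch.export's save/load before the test assertions run, so failures in these variants point at serialization rather than export itself. A minimal sketch of that round trip, with an illustrative toy module:

import io

import torch
from torch.export import export, load, save

class M(torch.nn.Module):
    def forward(self, x):
        return x.sin()

ep = export(M(), (torch.randn(4),))  # non-strict by default after this change

buffer = io.BytesIO()
save(ep, buffer)          # serialize the ExportedProgram
buffer.seek(0)
loaded_ep = load(buffer)  # tests then assert against the deserialized program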
@@ -41,7 +42,7 @@ def make_dynamic_cls(cls, strict):
     test_class = testing.make_test_cls_with_mocked_export(
         cls,
         "SerDesExport",
-        test_export.SERDES_SUFFIX,
+        test_export.SERDES_STRICT_SUFFIX,
         mocked_serder_export_strict,
         xfail_prop="_expected_failure_serdes",
     )

@@ -257,9 +257,9 @@ def expectedFailureTrainingIRToRunDecompNonStrict(fn):
     return fn


-# Controls tests generated in test/export/test_export_nonstrict.py
-def expectedFailureNonStrict(fn):
-    fn._expected_failure_non_strict = True
+# Controls tests generated in test/export/test_export_strict.py
+def expectedFailureStrict(fn):
+    fn._expected_failure_strict = True
     return fn

@@ -307,6 +307,11 @@ def expectedFailureCppRuntime(fn):
     return fn


+def expectedFailureCppRuntimeNonStrict(fn):
+    fn._expected_failure_cpp_runtime_non_strict = True
+    return fn
+
+
 # Controls tests generated in test/export/test_export_legacy.py
 def expectedFailureLegacyExportStrict(fn):
     fn._expected_failure_legacy_export = True

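These expectedFailure* decorators only set a marker attribute; the generated test classes find it through the xfail_prop argument seen in the earlier hunks and wrap matching tests as expected failures. A hedged sketch of that mechanism, with a made-up test class and test name:

import unittest

def expectedFailureCppRuntimeNonStrict(fn):
    # Marker consumed by the test-class generator via
    # xfail_prop="_expected_failure_cpp_runtime_non_strict".
    fn._expected_failure_cpp_runtime_non_strict = True
    return fn

class SomeExportTests(unittest.TestCase):
    @expectedFailureCppRuntimeNonStrict
    def test_made_up(self):
        self.fail("known failure in the cpp-runtime non-strict variant")

# A generator would check the marker and apply unittest.expectedFailure:
fn = SomeExportTests.test_made_up
if getattr(fn, "_expected_failure_cpp_runtime_non_strict", False):
    SomeExportTests.test_made_up = unittest.expectedFailure(fn)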
@@ -258,7 +258,7 @@ def export(
     kwargs: Optional[dict[str, Any]] = None,
     *,
     dynamic_shapes: Optional[Union[dict[str, Any], tuple[Any], list[Any]]] = None,
-    strict: bool = True,
+    strict: bool = False,
     preserve_module_call_signature: tuple[str, ...] = (),
 ) -> ExportedProgram:
     """