diff --git a/torch/_functorch/aot_autograd.py b/torch/_functorch/aot_autograd.py
index 084352b141e..f2b4915dc10 100644
--- a/torch/_functorch/aot_autograd.py
+++ b/torch/_functorch/aot_autograd.py
@@ -39,6 +39,26 @@ from . import config
 from .partitioners import default_partition
 from torch._guards import TracingContext, DuplicateInputs, Source
 
+original_zip = zip
+
+def strict_zip(*iterables, strict=True, **kwargs):
+    # Backport of ``zip(..., strict=True)`` (PEP 618, Python 3.10+).
+    # NOTE: the length check relies on ``len()``, so all arguments must be
+    # sized sequences (lists/tuples), not one-shot iterators/generators.
+    if not strict:
+        return original_zip(*iterables, **kwargs)
+
+    # ``default=0`` keeps the zero-argument call consistent with the builtin
+    # ``zip()``, which returns an empty iterator instead of raising.
+    shortest_length = min((len(it) for it in iterables), default=0)
+    for iterable in iterables:
+        if len(iterable) != shortest_length:
+            raise ValueError("The iterables have different lengths and strict mode is enabled.")
+
+    return original_zip(*iterables, **kwargs)
+
+zip = strict_zip
+
 log = logging.getLogger(__name__)
 aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
 aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
@@ -1862,7 +1882,7 @@ def merge_view_inputs(
     # For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
     # and error out. We can fix them later.
     # These checks are transitive, so we don't need to check every pair.
-    for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]):
+    for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:], strict=False):
     view1 = fwd_inputs[idx1]
     view2 = fwd_inputs[idx2]
     # The "inputs that are aliased but have different differentiable bases" case