Added strict=True to zip in aot_autograd (#110668)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110668
Approved by: https://github.com/ezyang
ghstack dependencies: #110501, #110504, #110591
This commit is contained in:
chilli 2023-10-05 17:28:44 -07:00 committed by PyTorch MergeBot
parent d279979102
commit 6d23193aab

View File

@ -39,6 +39,21 @@ from . import config
from .partitioners import default_partition
from torch._guards import TracingContext, DuplicateInputs, Source
original_zip = zip
def strict_zip(*iterables, strict=True, **kwargs):
if not strict:
return original_zip(*iterables, **kwargs)
shortest_length = min(len(it) for it in iterables)
for iterable in iterables:
if len(iterable) != shortest_length:
raise ValueError("The iterables have different lengths and strict mode is enabled.")
return original_zip(*iterables, **kwargs)
zip = strict_zip
log = logging.getLogger(__name__)
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
@ -1862,7 +1877,7 @@ def merge_view_inputs(
# For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
# and error out. We can fix them later.
# These checks are transitive, so we don't need to check every pair.
for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]):
for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:], strict=False):
view1 = fwd_inputs[idx1]
view2 = fwd_inputs[idx2]
# The "inputs that are aliased but have different differentiable bases" case