mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Added strict=True to zip in aot_autograd (#110668)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110668 Approved by: https://github.com/ezyang ghstack dependencies: #110501, #110504, #110591
This commit is contained in:
parent
d279979102
commit
6d23193aab
|
|
@ -39,6 +39,21 @@ from . import config
|
|||
from .partitioners import default_partition
|
||||
from torch._guards import TracingContext, DuplicateInputs, Source
|
||||
|
||||
original_zip = zip
|
||||
|
||||
def strict_zip(*iterables, strict=True, **kwargs):
|
||||
if not strict:
|
||||
return original_zip(*iterables, **kwargs)
|
||||
|
||||
shortest_length = min(len(it) for it in iterables)
|
||||
for iterable in iterables:
|
||||
if len(iterable) != shortest_length:
|
||||
raise ValueError("The iterables have different lengths and strict mode is enabled.")
|
||||
|
||||
return original_zip(*iterables, **kwargs)
|
||||
|
||||
zip = strict_zip
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
aot_joint_log = getArtifactLogger(__name__, "aot_joint_graph")
|
||||
aot_graphs_log = getArtifactLogger(__name__, "aot_graphs")
|
||||
|
|
@ -1862,7 +1877,7 @@ def merge_view_inputs(
|
|||
# For now, I'm banning a bunch of cases. We expect dynamo to properly detect these cases
|
||||
# and error out. We can fix them later.
|
||||
# These checks are transitive, so we don't need to check every pair.
|
||||
for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:]):
|
||||
for idx1, idx2 in zip(aliased_input_indices, aliased_input_indices[1:], strict=False):
|
||||
view1 = fwd_inputs[idx1]
|
||||
view2 = fwd_inputs[idx2]
|
||||
# The "inputs that are aliased but have different differentiable bases" case
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user