[2/N] Add strict parameter to Python zip calls (#166257)

This PR adds `strict=True`/`strict=False` to `zip()` calls throughout the torch library; `strict=True` is passed wherever the zipped iterables are known to have equal length.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
Authored by Yuanyuan Chen on 2025-10-30 08:10:10 +00:00; committed by PyTorch MergeBot.
parent 2829d48bd1
commit 39e5cdddf7
53 changed files with 211 additions and 131 deletions
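
For context: with the default behavior (and with an explicit `strict=False`), `zip()` silently stops at the shortest input, whereas `strict=True` (Python 3.10+) raises a `ValueError` on a length mismatch, turning silent element-dropping into a loud failure. A minimal sketch of the difference, using throwaway lists:

a = [1, 2, 3]
b = ["x", "y"]

# Default zip (and explicit strict=False) truncates to the shortest input.
assert list(zip(a, b)) == [(1, "x"), (2, "y")]
assert list(zip(a, b, strict=False)) == [(1, "x"), (2, "y")]

# strict=True raises instead of silently dropping the trailing 3.
try:
    list(zip(a, b, strict=True))
except ValueError as exc:
    print(exc)  # e.g. "zip() argument 2 is shorter than argument 1"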


@ -509,7 +509,9 @@ def merge_chunks(
values_to_cat = []
chunk_start_idx = 0
assert len(partial_values) == len(meta_chunks)
for partial_value, meta_chunk in zip(partial_values, meta_chunks):
for partial_value, meta_chunk in zip(
partial_values, meta_chunks, strict=True
):
chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)
slice_indices = [slice(None, None, None)] * partial_value.ndim


@ -738,7 +738,7 @@ class BlockMask:
(slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
if isinstance(i, int)
else i
for i, n in zip(padded, sizes)
for i, n in zip(padded, sizes, strict=True)
)
new_kv_num_blocks = self.kv_num_blocks[index]
new_kv_indices = self.kv_indices[index]


@ -3304,7 +3304,8 @@ def gaussian_nll_loss(
# or input.size = (4, 3, 32, 32), var.size = (4, 1, 32, 32)
elif (
input.ndim == var.ndim
and sum(y for x, y in zip(input.size(), var.size()) if x != y) == 1
and sum(y for x, y in zip(input.size(), var.size(), strict=True) if x != y)
== 1
): # Heteroscedastic case
pass


@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
import itertools
from collections import namedtuple
from collections.abc import Sequence
@ -273,7 +274,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]
for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
for i, (start_idx, stop_idx) in enumerate(itertools.pairwise(self.cutoffs)):
cluster_output = self.tail[i](input)
cluster_logprob = F.log_softmax(cluster_output, dim=1)
output_logprob = cluster_logprob + head_logprob[
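
Side note on the hunk above: the self-zip `zip(self.cutoffs, self.cutoffs[1:])` is replaced by `itertools.pairwise` (Python 3.10+), which produces the same successive pairs lazily, without building a shifted slice and without any strictness question to answer. A quick equivalence check with an illustrative list:

import itertools

cutoffs = [4, 10, 20]

# Zipping a list against its shifted copy pairs each element with its successor...
assert list(zip(cutoffs, cutoffs[1:])) == [(4, 10), (10, 20)]

# ...which is exactly what itertools.pairwise yields.
assert list(itertools.pairwise(cutoffs)) == [(4, 10), (10, 20)]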


@ -150,7 +150,9 @@ class Sequential(Module):
delattr(self, key)
# To preserve numbering
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
self._modules = OrderedDict(
zip(str_indices, self._modules.values(), strict=True)
)
@_copy_to_script_wrapper
def __len__(self) -> int:
@ -395,7 +397,9 @@ class ModuleList(Module):
delattr(self, self._get_abs_string_index(idx))
# To preserve numbering, self._modules is being reconstructed with modules after deletion
str_indices = [str(i) for i in range(len(self._modules))]
self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
self._modules = OrderedDict(
zip(str_indices, self._modules.values(), strict=True)
)
@_copy_to_script_wrapper
def __len__(self) -> int:
@ -432,7 +436,9 @@ class ModuleList(Module):
lines = []
main_str = self._get_name() + "("
for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
for (start_id, end_id), b in zip(
start_end_indices, repeated_blocks, strict=True
):
local_repr = f"({start_id}): {b}" # default repr
if start_id != end_id:


@ -142,7 +142,10 @@ class _ConvNd(Module):
self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
if padding == "same":
for d, k, i in zip(
dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)
dilation,
kernel_size,
range(len(kernel_size) - 1, -1, -1),
strict=False,
):
total_padding = d * (k - 1)
left_pad = total_padding // 2


@ -196,7 +196,7 @@ class RNNBase(Module):
param_names += ["weight_hr_l{}{}"]
param_names = [x.format(layer, suffix) for x in param_names]
for name, param in zip(param_names, layer_params):
for name, param in zip(param_names, layer_params, strict=True):
setattr(self, name, param)
self._flat_weights_names.extend(param_names)
self._all_weights.append(param_names)
@ -352,7 +352,9 @@ class RNNBase(Module):
# Returns True if the weight tensors have changed since the last forward pass.
# This is the case when used with torch.func.functional_call(), for example.
weights_changed = False
for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
for ref, name in zip(
self._flat_weight_refs, self._flat_weights_names, strict=True
):
weight = getattr(self, name) if hasattr(self, name) else None
if weight is not None and ref is not None and ref() is not weight:
weights_changed = True


@ -41,7 +41,8 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
if len(defaults) <= len(out_size):
raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
return [
v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
v if v is not None else d
for v, d in zip(out_size, defaults[-len(out_size) :], strict=False)
]


@ -141,18 +141,18 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
output = []
ref_order = []
# process sparse ones first since they may have different sizes on different gpus
for tensor_at_gpus in zip(*inputs):
for tensor_at_gpus in zip(*inputs, strict=True):
if all(t.is_sparse for t in tensor_at_gpus):
result = reduce_add(tensor_at_gpus, destination) # this will be sparse too
output.append(result)
ref_order.append(tensor_at_gpus[0])
else:
for coll, t in zip(dense_tensors, tensor_at_gpus):
for coll, t in zip(dense_tensors, tensor_at_gpus, strict=True):
coll.append(t.to_dense() if t.is_sparse else t)
ref_order.append(dense_tensors[0][-1])
itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
# now the dense ones, which have consistent sizes
for chunks in zip(*itrs):
for chunks in zip(*itrs, strict=True):
flat_tensors = [
_flatten_dense_tensors(chunk) for chunk in chunks
] # (num_gpus,)


@ -115,7 +115,7 @@ def parallel_apply(
target=_worker, args=(i, module, input, kwargs, device, stream)
)
for i, (module, input, kwargs, device, stream) in enumerate(
zip(modules, inputs, kwargs_tup, devices, streams)
zip(modules, inputs, kwargs_tup, devices, streams, strict=True)
)
]


@ -57,16 +57,20 @@ def scatter(inputs, target_gpus, dim=0):
return Scatter.apply(target_gpus, None, dim, obj)
if _is_namedtuple(obj):
# pyrefly: ignore [no-matching-overload]
return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
return [
type(obj)(*args) for args in zip(*map(scatter_map, obj), strict=False)
]
if isinstance(obj, tuple) and len(obj) > 0:
# pyrefly: ignore [no-matching-overload]
return list(zip(*map(scatter_map, obj)))
return list(zip(*map(scatter_map, obj), strict=False))
if isinstance(obj, list) and len(obj) > 0:
# pyrefly: ignore [no-matching-overload]
return [list(i) for i in zip(*map(scatter_map, obj))]
return [list(i) for i in zip(*map(scatter_map, obj), strict=False)]
if isinstance(obj, dict) and len(obj) > 0:
# pyrefly: ignore [no-matching-overload]
return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
return [
type(obj)(i) for i in zip(*map(scatter_map, obj.items()), strict=False)
]
return [obj for _ in target_gpus]
# After scatter_map is called, a scatter_map cell will exist. This cell
@ -131,9 +135,9 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
if _is_namedtuple(out):
# pyrefly: ignore [no-matching-overload]
return type(out)._make(map(gather_map, zip(*outputs)))
return type(out)._make(map(gather_map, zip(*outputs, strict=True)))
# pyrefly: ignore [no-matching-overload]
return type(out)(map(gather_map, zip(*outputs)))
return type(out)(map(gather_map, zip(*outputs, strict=True)))
# Recursive function calls like this create reference cycles.
# Setting the function to None clears the refcycle.


@ -28,7 +28,7 @@ def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
kwargs = expanded_args_and_kwargs[
len(expanded_args_and_kwargs) - len(kwarg_names) :
]
kwargs = dict(zip(kwarg_names, kwargs))
kwargs = dict(zip(kwarg_names, kwargs, strict=True))
return conv_normalizer(*args, **kwargs)


@ -32,7 +32,7 @@ def standard_kwargs(kwarg_names, expanded_args):
expanded_args_without_kwargs = expanded_args[
: len(expanded_args) - len(kwarg_names)
]
expanded_kwargs = dict(zip(kwarg_names, kwarg_values))
expanded_kwargs = dict(zip(kwarg_names, kwarg_values, strict=True))
return expanded_args_without_kwargs, expanded_kwargs


@ -250,7 +250,7 @@ class NamedMemberAccessor:
values = list(values)
assert len(names) == len(values), "names and values must have the same length"
for name, value in zip(names, values):
for name, value in zip(names, values, strict=True):
self.set_tensor(name, value)
def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None:
@ -298,7 +298,7 @@ class NamedMemberAccessor:
return [
self.swap_tensor(name, value, allow_missing=allow_missing)
for name, value in zip(names, values)
for name, value in zip(names, values, strict=True)
]
def swap_tensors_dict(


@ -528,7 +528,7 @@ def unpad_sequence(
max_length = padded_sequences.shape[1]
idx = torch.arange(max_length, device=lengths.device)
for seq, length in zip(padded_sequences, lengths):
for seq, length in zip(padded_sequences, lengths, strict=True):
mask = idx < length
unpacked_seq = seq[mask]
unpadded_sequences.append(unpacked_seq)


@ -532,7 +532,7 @@ def _multi_tensor_adafactor(
alphas = [
max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
for p, r in zip(device_params, rho_ts)
for p, r in zip(device_params, rho_ts, strict=True)
]
# Perform stepweight decay
@ -566,7 +566,9 @@ def _multi_tensor_adafactor(
var_estimates = [
row_var @ col_var
for row_var, col_var in zip(device_row_vars, device_col_vars)
for row_var, col_var in zip(
device_row_vars, device_col_vars, strict=True
)
]
row_var_means = [
row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars
@ -594,7 +596,7 @@ def _multi_tensor_adafactor(
alphas = [
-a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
for a, update in zip(alphas, updates)
for a, update in zip(alphas, updates, strict=True)
]
torch._foreach_mul_(updates, alphas)
torch._foreach_add_(device_params, updates)


@ -266,7 +266,7 @@ def _single_tensor_adadelta(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
@ -276,7 +276,7 @@ def _single_tensor_adadelta(
lr = _to_scalar(lr)
for param, grad, square_avg, acc_delta, step in zip(
params, grads, square_avgs, acc_deltas, state_steps
params, grads, square_avgs, acc_deltas, state_steps, strict=True
):
step += 1
grad = grad if not maximize else -grad
@ -329,7 +329,7 @@ def _multi_tensor_adadelta(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."


@ -343,7 +343,9 @@ def _single_tensor_adagrad(
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
for param, grad, state_sum, step_t in zip(
params, grads, state_sums, state_steps, strict=True
):
# update step
step_t += 1
step = _get_value(step_t)


@ -608,7 +608,7 @@ def _multi_tensor_adam(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."


@ -334,7 +334,7 @@ def _multi_tensor_adamax(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."


@ -307,7 +307,7 @@ def _multi_tensor_asgd(
if not all(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
for p, mu, eta, step in zip(params, mus, etas, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params, mus, etas, and state_steps must be on "


@ -314,7 +314,7 @@ class LBFGS(Optimizer):
return [p.clone(memory_format=torch.contiguous_format) for p in self._params]
def _set_param(self, params_data):
for p, pdata in zip(self._params, params_data):
for p, pdata in zip(self._params, params_data, strict=True):
p.copy_(pdata)
def _directional_evaluate(self, closure, x, t, d):


@ -302,7 +302,7 @@ class LRScheduler:
else:
values = self.get_lr()
for param_group, lr in zip(self.optimizer.param_groups, values):
for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
_update_param_group_val(param_group, "lr", lr)
self._last_lr: list[float | Tensor] = _param_groups_val_list(
@ -472,7 +472,7 @@ class LambdaLR(LRScheduler):
return [
base_lr * lmbda(self.last_epoch)
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)
for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs, strict=True)
]
@ -592,7 +592,9 @@ class MultiplicativeLR(LRScheduler):
if not self._is_initial:
return [
group["lr"] * lmbda(self.last_epoch)
for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
for lmbda, group in zip(
self.lr_lambdas, self.optimizer.param_groups, strict=True
)
]
else:
return _param_groups_val_list(self.optimizer, "lr")
@ -1441,13 +1443,17 @@ class CosineAnnealingLR(LRScheduler):
+ (base_lr - self.eta_min)
* (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
/ 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
for base_lr, group in zip(
self.base_lrs, self.optimizer.param_groups, strict=True
)
]
elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
return [
group["lr"]
+ (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
for base_lr, group in zip(
self.base_lrs, self.optimizer.param_groups, strict=True
)
]
return [
(1 + math.cos(math.pi * self.last_epoch / self.T_max))
@ -1906,7 +1912,7 @@ class CyclicLR(LRScheduler):
base_lrs = _format_param("base_lr", optimizer, base_lr)
if last_epoch == -1:
for lr, group in zip(base_lrs, optimizer.param_groups):
for lr, group in zip(base_lrs, optimizer.param_groups, strict=True):
_update_param_group_val(group, "lr", lr)
self.max_lrs = _format_param("max_lr", optimizer, max_lr)
@ -1949,7 +1955,10 @@ class CyclicLR(LRScheduler):
self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
if last_epoch == -1:
for m_momentum, b_momentum, group in zip(
self.max_momentums, self.base_momentums, optimizer.param_groups
self.max_momentums,
self.base_momentums,
optimizer.param_groups,
strict=True,
):
if self.use_beta1:
group["betas"] = (m_momentum, *group["betas"][1:])
@ -2033,7 +2042,7 @@ class CyclicLR(LRScheduler):
scale_factor = (x - 1) / (self.step_ratio - 1)
lrs = []
for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
for base_lr, max_lr in zip(self.base_lrs, self.max_lrs, strict=True):
base_height = (max_lr - base_lr) * scale_factor
if self.scale_mode == "cycle":
lr = base_lr + base_height * self.scale_fn(cycle)
@ -2044,7 +2053,7 @@ class CyclicLR(LRScheduler):
if self.cycle_momentum:
momentums = []
for base_momentum, max_momentum in zip(
self.base_momentums, self.max_momentums
self.base_momentums, self.max_momentums, strict=True
):
base_height = (max_momentum - base_momentum) * scale_factor
if self.scale_mode == "cycle":
@ -2054,7 +2063,9 @@ class CyclicLR(LRScheduler):
self.last_epoch
)
momentums.append(momentum)
for param_group, momentum in zip(self.optimizer.param_groups, momentums):
for param_group, momentum in zip(
self.optimizer.param_groups, momentums, strict=True
):
if self.use_beta1:
param_group["betas"] = (momentum, *param_group["betas"][1:])
else:
@ -2260,7 +2271,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
self.last_epoch = math.floor(epoch)
with _enable_get_lr_call(self):
for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
for param_group, lr in zip(
self.optimizer.param_groups, self.get_lr(), strict=True
):
_update_param_group_val(param_group, "lr", lr)
self._last_lr = _param_groups_val_list(self.optimizer, "lr")
@ -2500,7 +2513,7 @@ class OneCycleLR(LRScheduler):
base_momentums = _format_param("base_momentum", optimizer, base_momentum)
if last_epoch == -1:
for m_momentum, b_momentum, group in zip(
max_momentums, base_momentums, optimizer.param_groups
max_momentums, base_momentums, optimizer.param_groups, strict=True
):
if self.use_beta1:
group["betas"] = (m_momentum, *group["betas"][1:])


@ -412,7 +412,7 @@ def _multi_tensor_nadam(
if not all(
p.device.type == mp.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mp, step in zip(params, mu_products, state_steps)
for p, mp, step in zip(params, mu_products, state_steps, strict=True)
):
raise AssertionError(
"If capturable=True, "
@ -570,7 +570,7 @@ def _multi_tensor_nadam(
step_size_grads = _stack_if_compiling(
[
(_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
for mu_product, mu in zip(grouped_mu_products, mus)
for mu_product, mu in zip(grouped_mu_products, mus, strict=True)
]
)
step_size_expavg = _stack_if_compiling(
@ -581,7 +581,9 @@ def _multi_tensor_nadam(
/ (1.0 - _get_value(mu_product) * mu_next)
)
* -1
for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
for mu_product, mu_next in zip(
grouped_mu_products, mu_nexts, strict=True
)
]
)


@ -941,7 +941,9 @@ class Optimizer:
)
param_lens = (len(g["params"]) for g in groups)
saved_lens = (len(g["params"]) for g in saved_groups)
if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
if any(
p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)
):
raise ValueError(
"loaded state dict contains a parameter group "
"that doesn't match the size of optimizer's group"
@ -952,6 +954,7 @@ class Optimizer:
zip(
chain.from_iterable(g["params"] for g in saved_groups),
chain.from_iterable(g["params"] for g in groups),
strict=True,
)
)
@ -1005,7 +1008,9 @@ class Optimizer:
new_group["param_names"] = group["param_names"]
return new_group
param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
param_groups = [
update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)
]
self.__setstate__({"state": state, "param_groups": param_groups})
for post_hook in self._optimizer_load_state_dict_post_hooks.values():


@ -392,7 +392,7 @@ def _multi_tensor_radam(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
@ -501,7 +501,8 @@ def _multi_tensor_radam(
# TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
rect = [
torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)
torch.where(rho_t > 5.0, n, 0.0)
for n, rho_t in zip(num, rho_t_list, strict=True)
]
del num
del rho_t_list
@ -544,11 +545,14 @@ def _multi_tensor_radam(
1 - beta1 ** _get_value(step) for step in grouped_state_steps
]
unrect_step_size = [
(lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
(lr * rect / bc) * -1
for rect, bc in zip(unrectified, bias_correction1, strict=True)
]
bias_correction2 = [
((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
for step, rect, bc in zip(
grouped_state_steps, rect, bias_correction1, strict=True
)
]
buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)


@ -370,7 +370,7 @@ def _multi_tensor_rmsprop(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."


@ -319,7 +319,7 @@ def _multi_tensor_rprop(
if not all(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
for p, step in zip(params, state_steps, strict=True)
):
raise AssertionError(
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."


@ -143,7 +143,9 @@ class SGD(Optimizer): # noqa: D101
if group["momentum"] != 0:
# update momentum_buffers in state
for p, momentum_buffer in zip(params, momentum_buffer_list):
for p, momentum_buffer in zip(
params, momentum_buffer_list, strict=True
):
state = self.state[p]
state["momentum_buffer"] = momentum_buffer


@ -50,7 +50,7 @@ def get_ema_multi_avg_fn(decay=0.999):
):
torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
else:
for p_ema, p_model in zip(ema_param_list, current_param_list):
for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True):
p_ema.copy_(p_ema * decay + p_model * (1 - decay))
return ema_update
@ -264,7 +264,7 @@ class AveragedModel(Module):
self_param_detached: list[Optional[Tensor]] = []
model_param_detached: list[Optional[Tensor]] = []
copy_param = bool(self.n_averaged == 0)
for p_averaged, p_model in zip(self_param, model_param):
for p_averaged, p_model in zip(self_param, model_param, strict=True):
p_model_ = p_model.detach().to(p_averaged.device)
self_param_detached.append(p_averaged.detach())
model_param_detached.append(p_model_)
@ -297,12 +297,14 @@ class AveragedModel(Module):
else:
avg_fn = get_swa_avg_fn()
n_averaged = self.n_averaged.to(device)
for p_averaged, p_model in zip(self_params, model_params): # type: ignore[assignment]
for p_averaged, p_model in zip( # type: ignore[assignment]
self_params, model_params, strict=True
):
# pyrefly: ignore [missing-attribute]
p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
else:
for p_averaged, p_model in zip( # type: ignore[assignment]
self_param_detached, model_param_detached
self_param_detached, model_param_detached, strict=True
):
# pyrefly: ignore [missing-attribute]
n_averaged = self.n_averaged.to(p_averaged.device)
@ -315,7 +317,9 @@ class AveragedModel(Module):
if not self.use_buffers:
# If not apply running averages to the buffers,
# keep the buffers in sync with the source model.
for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
for b_swa, b_model in zip(
self.module.buffers(), model.buffers(), strict=True
):
b_swa.detach().copy_(b_model.detach().to(b_swa.device))
self.n_averaged += 1
@ -432,7 +436,7 @@ class SWALR(LRScheduler):
last_epoch=-1,
): # noqa: D107
swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True):
group["swa_lr"] = swa_lr
if anneal_strategy not in ["cos", "linear"]:
raise ValueError(
@ -509,7 +513,7 @@ class SWALR(LRScheduler):
alpha = self.anneal_func(t)
return [
group["swa_lr"] * alpha + lr * (1 - alpha)
for group, lr in zip(self.optimizer.param_groups, prev_lrs)
for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True)
]
def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]):


@ -254,7 +254,9 @@ class SchemaMatcher:
def matches(schema) -> bool:
return len(schema.arguments) == len(signature) and all(
cls._types_match(observed, schema_arg.type)
for observed, schema_arg in zip(signature, schema.arguments)
for observed, schema_arg in zip(
signature, schema.arguments, strict=True
)
)
return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))
@ -377,7 +379,9 @@ class SizeMap:
key = TensorKey.from_tensor(t)
if key is not None and t is not None and t.layout == torch.strided:
# Scalars are represented as zero dim Tensors
n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))
n = max(
i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1], strict=True)
)
num_bytes = n * _element_size(t.dtype)
assert num_bytes >= 0, f"{num_bytes}"
@ -430,7 +434,7 @@ class DataFlowNode:
mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {}
for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
for op_input, mutable in zip(
op.inputs, SchemaMatcher.inputs_are_mutable(op)
op.inputs, SchemaMatcher.inputs_are_mutable(op), strict=True
):
# Tensor
if isinstance(op_input, _TensorMetadata):


@ -276,7 +276,7 @@ class ForLoopIndexingPattern(Pattern):
def same_ops(list1, list2) -> bool:
if len(list1) != len(list2):
return False
for op1, op2 in zip(list1, list2):
for op1, op2 in zip(list1, list2, strict=True):
if op1.name != op2.name:
return False
return True


@ -336,7 +336,7 @@ class BasicEvaluation:
event_list = [
event
for _, event in sorted(
zip(heuristic_score_list, event_list),
zip(heuristic_score_list, event_list, strict=True),
key=operator.itemgetter(0),
reverse=True,
)


@ -121,7 +121,7 @@ def slicer(dim, slice_range, *tensors):
def multidim_slicer(dims, slices, *tensors):
for t in tensors:
s = [slice(None)] * t.dim()
for d, d_slice in zip(dims, slices):
for d, d_slice in zip(dims, slices, strict=False):
if d is not None:
s[d] = d_slice
yield t[tuple(s)]
@ -140,7 +140,7 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
import itertools
def generate_grid_points():
for fg, mg in zip(full_grid, grid_blocks):
for fg, mg in zip(full_grid, grid_blocks, strict=False):
yield range(0, fg, mg)
def generate_sliced_tensors(slices):
@ -149,9 +149,10 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
for grid_point in itertools.product(*generate_grid_points()):
grid = [
min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)
min(fg - gp, mg)
for fg, gp, mg in zip(full_grid, grid_point, grid_blocks, strict=False)
]
slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid, strict=False)]
# grid_points are iterated in a "contiguous" order, i.e.
# left dimensions traversed slower than right dimensions.
# This order is reversed for CUDA grids.
@ -173,7 +174,8 @@ def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
return max(1, min(g, mg))
grid_blocks = tuple(
valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
valid_grid_dim(g, mg)
for g, mg in zip(grid_blocks, cuda_max_grid, strict=False)
) # type: ignore[assignment]
for grid, *sliced_tensors in grid_partitioner(


@ -155,7 +155,11 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
matching_data = {}
if "*" in key:
for op_key in op_data:
if [None for k1, k2 in zip(op_key, key) if k2 != "*" and k1 != k2]:
if [
None
for k1, k2 in zip(op_key, key, strict=True)
if k2 != "*" and k1 != k2
]:
continue
matching_data[op_key] = op_data[op_key]
else:
@ -173,10 +177,14 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
"num_stages",
"num_warps",
)
meta = dict(zip(names, values))
meta = dict(zip(names, values, strict=True))
elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}:
meta = dict(
zip(("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), values)
zip(
("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"),
values,
strict=True,
)
)
else:
raise NotImplementedError(f"names for {op=}")
@ -289,7 +297,7 @@ def minimize(
return tuple(parameters[k] for k in sorted(parameters))
def from_key(key, parameters):
return dict(zip(sorted(parameters), key))
return dict(zip(sorted(parameters), key, strict=True))
if all_values is None:
all_values = {}
@ -347,7 +355,7 @@ def minimize(
for i, (_, d_tuple) in enumerate(all_directions):
pbar.update(1)
next_parameters = parameters.copy()
for name, direction in zip(names, d_tuple):
for name, direction in zip(names, d_tuple, strict=True):
value = next_parameters[name]
if direction == 0:
continue


@ -364,11 +364,15 @@ def register_dataclass(
def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
flat_names, none_names = context
return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
return cls(
**dict(zip(flat_names, values, strict=True)), **dict.fromkeys(none_names)
)
def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
flattened, (flat_names, _none_names) = _flatten_fn(obj) # type: ignore[misc]
return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
return [
(GetAttrKey(k), v) for k, v in zip(flat_names, flattened, strict=True)
], flat_names
_private_register_pytree_node(
cls,
@ -788,11 +792,11 @@ def _dict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _dict_flatten(d)
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context
return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context
def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]:
return dict(zip(context, values))
return dict(zip(context, values, strict=True))
def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]:
@ -805,7 +809,10 @@ def _namedtuple_flatten_with_keys(
values, context = _namedtuple_flatten(d)
# pyrefly: ignore [bad-return]
return (
[(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
[
(GetAttrKey(field), v)
for field, v in zip(context._fields, values, strict=True)
],
context,
)
@ -854,14 +861,14 @@ def _ordereddict_flatten_with_keys(
) -> tuple[list[tuple[KeyEntry, T]], Context]:
values, context = _ordereddict_flatten(d)
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(context, values)], context
return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context
def _ordereddict_unflatten(
values: Iterable[T],
context: Context,
) -> OrderedDict[Any, T]:
return OrderedDict((key, value) for key, value in zip(context, values))
return OrderedDict((key, value) for key, value in zip(context, values, strict=True))
_odict_flatten = _ordereddict_flatten
@ -879,7 +886,9 @@ def _defaultdict_flatten_with_keys(
values, context = _defaultdict_flatten(d)
_, dict_context = context
# pyrefly: ignore [bad-return]
return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
return [
(MappingKey(k), v) for k, v in zip(dict_context, values, strict=True)
], context
def _defaultdict_unflatten(
@ -1197,7 +1206,7 @@ class TreeSpec:
f"expected {treespec.context!r}, but got {context!r}.", # namedtuple type mismatch
)
for subtree, subspec in zip(children, treespec.children_specs):
for subtree, subspec in zip(children, treespec.children_specs, strict=True):
helper(subspec, subtree, subtrees)
subtrees: list[PyTree] = []
@ -1761,7 +1770,7 @@ def _broadcast_to_and_flatten(
# Recursively flatten the children
result: list[Any] = []
for child, child_spec in zip(child_pytrees, treespec.children_specs):
for child, child_spec in zip(child_pytrees, treespec.children_specs, strict=True):
flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
if flat is not None:
result += flat
@ -2063,9 +2072,9 @@ def tree_map_with_path(
``xs`` is the tuple of values at corresponding nodes in ``rests``.
"""
keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
keypath_leaves = list(zip(*keypath_leaves))
keypath_leaves = list(zip(*keypath_leaves, strict=True))
all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves, strict=True))
def keystr(kp: KeyPath) -> str:


@ -1223,7 +1223,8 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
# When all strides are integral, we can sort, and the size for the
# largest stride doesn't matter and can be arbitrarily symbolic
s_sizes, s_strides = zip(
*sorted(zip(sizes, strides), key=operator.itemgetter(1))
*sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1)),
strict=True,
)
# Put something arbitrary in the max size spot, it'll be ignored
if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):
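
The `IsNonOverlappingAndDenseIndicator` hunk above applies the zip/sort/unzip idiom: pair sizes with strides, sort the pairs by stride, then transpose back, now with `strict=True` on both the pairing and the transposition. A small sketch with made-up shape values:

import operator

sizes, strides = (4, 2, 3), (12, 1, 4)

# Pair each size with its stride and sort the pairs by stride.
pairs = sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1))
assert pairs == [(2, 1), (3, 4), (4, 12)]

# Unzip (transpose) back into separate, stride-ordered tuples.
s_sizes, s_strides = zip(*pairs, strict=True)
assert s_sizes == (2, 3, 4) and s_strides == (1, 4, 12)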


@ -32,7 +32,7 @@ def run(n, stmt, fuzzer_cls):
float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
raw_results = []
for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)):
for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter, strict=True)):
float_tensors, float_tensor_params, float_params = float_values
int_tensors, int_tensor_params, int_params = int_values
@ -89,7 +89,7 @@ def run(n, stmt, fuzzer_cls):
for t_float, t_int, rel_diff, descriptions in results:
time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"]
time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
for t_str, (name, shape, order, steps) in zip(time_str, descriptions):
for t_str, (name, shape, order, steps) in zip(time_str, descriptions, strict=True):
name = f"{name}:".ljust(name_len + 1)
shape = shape.ljust(shape_len + 10)
order = order.ljust(order_len)


@ -29,7 +29,7 @@ def run(n, stmt, fuzzer_cls):
float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
double_iter = fuzzer_cls(seed=0, dtype=torch.float64).take(n)
raw_results = []
for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter)):
for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter, strict=True)):
float_tensors, float_tensor_params, float_params = float_values
int_tensors, int_tensor_params, int_params = int_values
@ -84,7 +84,7 @@ def run(n, stmt, fuzzer_cls):
for t_float, t_int, rel_diff, descriptions in results:
time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"]
time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions):
for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions, strict=True):
name = f"{name}:".ljust(name_len + 1)
shape = shape.ljust(shape_len + 10)
sparse_dim = sparse_dim.ljust(sparse_dim_len)


@ -51,7 +51,7 @@ class _Column:
unit_digits = max(d for d in leading_digits if d is not None)
decimal_digits = min(
max(m.significant_figures - digits, 0)
for digits, m in zip(leading_digits, self._flat_results)
for digits, m in zip(leading_digits, self._flat_results, strict=True)
if (m is not None) and (digits is not None)
) if self._trim_significant_figures else 1
length = unit_digits + decimal_digits + (1 if decimal_digits else 0)
@ -99,7 +99,7 @@ class _Row:
env = f"({concrete_results[0].env})" if self._render_env else ""
env = env.ljust(self._env_str_len + 4)
output = [" " + env + concrete_results[0].as_row_name]
for m, col in zip(self._results, self._columns or ()):
for m, col in zip(self._results, self._columns or (), strict=False):
if m is None:
output.append(col.num_to_str(None, 1, None))
else:
@ -141,7 +141,7 @@ class _Row:
]
row_contents = [column_strings[0].ljust(col_widths[0])]
for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values):
for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values, strict=False):
col_str = col_str.center(width)
if self._colorize != Colorize.NONE and result is not None and best_value is not None:
col_str = self.color_segment(col_str, result.median, best_value)
@ -206,7 +206,7 @@ class Table:
prior_env = ""
row_group = -1
rows_by_group: list[list[list[Optional[common.Measurement]]]] = []
for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True):
thread_transition = (num_threads != prior_num_threads)
if thread_transition:
prior_num_threads = num_threads
@ -250,10 +250,10 @@ class Table:
for sr in string_rows:
sr.extend(["" for _ in range(num_cols - len(sr))])
col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
col_widths = [max(len(j) for j in i) for i in zip(*string_rows, strict=True)]
finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths, strict=True))]
overall_width = len(finalized_columns[0])
for string_row, row in zip(string_rows[1:], self.rows):
for string_row, row in zip(string_rows[1:], self.rows, strict=True):
finalized_columns.extend(row.row_separator(overall_width))
finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths)))


@ -295,7 +295,7 @@ class FuzzedTensor:
raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))
slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
slices = [slice(0, size * step, step) for size, step in zip(size, steps, strict=True)]
tensor = raw_tensor[tuple(slices)]
properties = {
@ -326,7 +326,7 @@ class FuzzedTensor:
size = resolve(self._size, dim)
steps = resolve(self._steps or (), dim)
allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps))
allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True))
return size, steps, allocation_size
def satisfies_constraints(self, params):


@ -196,7 +196,7 @@ def set_device_states(devices, states, *, device_type=None) -> None:
if device_type == "meta":
return
device_module = _get_device_module(device_type)
for device, state in zip(devices, states):
for device, state in zip(devices, states, strict=False):
with device_module.device(device):
device_module.set_rng_state(state)
@ -794,7 +794,7 @@ class _NoopSaveInputs(torch.autograd.Function):
# Only tensors can be saved with ctx.save_for_backward, everything else
# is captured by get_args, which is saved directly on ctx
tensor_indices, tensors = zip(
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
*[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)], strict=False
)
idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
# args but with tensors replaced with None as placeholders
@ -1020,7 +1020,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
def get_str_tb(label, capture_logs):
out = ""
total_len = len(capture_logs.logs)
for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs, strict=False)):
out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
found_torch_dispatch = False
for line in tb:


@ -2947,7 +2947,7 @@ e.
# Emit one build rule per source to enable incremental build.
build = []
for source_file, object_file in zip(sources, objects):
for source_file, object_file in zip(sources, objects, strict=True):
is_cuda_source = _is_cuda_file(source_file) and with_cuda
is_sycl_source = _is_sycl_file(source_file) and with_sycl
if is_cuda_source:


@ -197,7 +197,7 @@ def collate(
return elem_type(
*(
collate(samples, collate_fn_map=collate_fn_map)
for samples in zip(*batch)
for samples in zip(*batch, strict=False)
)
)
elif isinstance(elem, collections.abc.Sequence):
@ -207,7 +207,9 @@ def collate(
# pyrefly: ignore [not-iterable]
if not all(len(elem) == elem_size for elem in it):
raise RuntimeError("each element in list of batch should be of equal size")
transposed = list(zip(*batch)) # It may be accessed twice, so we use a list.
transposed = list(
zip(*batch, strict=False)
) # It may be accessed twice, so we use a list.
if isinstance(elem, tuple):
return [


@ -184,7 +184,7 @@ def _issubtype_with_constraints(variant, constraints, recursive=True):
and len(v_args) == len(c_args)
and all(
issubtype(v_arg, c_arg)
for v_arg, c_arg in zip(v_args, c_args)
for v_arg, c_arg in zip(v_args, c_args, strict=True)
)
):
return True
@ -207,7 +207,7 @@ def issubinstance(data, data_type):
return True
if len(dt_args) != len(data):
return False
return all(issubinstance(d, t) for d, t in zip(data, dt_args))
return all(issubinstance(d, t) for d, t in zip(data, dt_args, strict=True))
elif isinstance(data, (list, set)):
if dt_args is None or len(dt_args) == 0:
return True


@ -101,7 +101,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
filter_res.append(self.filter_fn(df.iloc[i]))
buffer = []
for df, res in zip(all_buffer, filter_res):
for df, res in zip(all_buffer, filter_res, strict=True):
if res:
buffer.append(df)
if len(buffer) == size:


@ -705,7 +705,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):
def __iter__(self) -> Iterator[tuple[_T_co]]:
iterators = [iter(datapipe) for datapipe in self.datapipes]
yield from zip(*iterators)
yield from zip(*iterators, strict=False)
def __len__(self) -> int:
if all(isinstance(dp, Sized) for dp in self.datapipes):


@ -267,10 +267,10 @@ class StackDataset(Dataset[_T_stack]):
"Nested dataset's output size mismatch."
f" Expected {len(indices)}, got {len(items)}"
)
for data, d_sample in zip(items, dict_batch):
for data, d_sample in zip(items, dict_batch, strict=True):
d_sample[k] = data
else:
for idx, d_sample in zip(indices, dict_batch):
for idx, d_sample in zip(indices, dict_batch, strict=True):
d_sample[k] = dataset[idx]
return dict_batch
@ -284,10 +284,10 @@ class StackDataset(Dataset[_T_stack]):
"Nested dataset's output size mismatch."
f" Expected {len(indices)}, got {len(items)}"
)
for data, t_sample in zip(items, list_batch):
for data, t_sample in zip(items, list_batch, strict=True):
t_sample.append(data)
else:
for idx, t_sample in zip(indices, list_batch):
for idx, t_sample in zip(indices, list_batch, strict=True):
t_sample.append(dataset[idx])
tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch]
return tuple_batch
@ -477,5 +477,5 @@ def random_split(
lengths = cast(Sequence[int], lengths)
return [
Subset(dataset, indices[offset - length : offset])
for offset, length in zip(itertools.accumulate(lengths), lengths)
for offset, length in zip(itertools.accumulate(lengths), lengths, strict=True)
]
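
A quick check on the `random_split` hunk above: `itertools.accumulate` yields exactly one running total per entry of `lengths`, so the zip is inherently length-matched and `strict=True` is safe; each resulting (offset, length) pair selects `indices[offset - length : offset]`. Illustrative numbers:

import itertools

lengths = [3, 2, 5]

# Running end offsets of each split: 3, 5, 10 (one per length).
pairs = list(zip(itertools.accumulate(lengths), lengths, strict=True))
assert pairs == [(3, 3), (5, 2), (10, 5)]

# Each (offset, length) pair maps to the half-open slice [offset - length, offset).
assert [(off - ln, off) for off, ln in pairs] == [(0, 3), (3, 5), (5, 10)]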


@ -335,7 +335,7 @@ class BatchSampler(Sampler[list[int]]):
if self.drop_last:
# Create multiple references to the same iterator
args = [sampler_iter] * self.batch_size
for batch_droplast in zip(*args):
for batch_droplast in zip(*args, strict=False):
yield [*batch_droplast]
else:
batch = [*itertools.islice(sampler_iter, self.batch_size)]


@ -341,7 +341,7 @@ def _unpack_flash_attention_nested_shapes(
raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True):
new_query_shape = (1, h_q, seq_q_len, d_q)
new_key_shape = (1, h_k, seq_k_len, d_k)
new_value_shape = (1, h_v, seq_k_len, d_v)
@ -396,7 +396,7 @@ def _unpack_efficient_attention_nested_shapes(
"cu_seqlens_q and cu_seqlens_k must have the same shape")
seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
for len_q, len_k in zip(seqlens_q, seqlens_k):
for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True):
new_query_shape = (1, h_q, len_q, d_q)
new_key_shape = (1, h_k, len_k, d_k)
new_value_shape = (1, h_v, len_k, d_v)


@ -114,7 +114,7 @@ class BackwardHook:
def _pack_with_none(self, indices, values, size):
res = [None] * size
for idx, val in zip(indices, values):
for idx, val in zip(indices, values, strict=True):
res[idx] = val
return tuple(res)
@ -180,7 +180,7 @@ class BackwardHook:
fn(grad_fns[0])
arg_list = list(args)
for idx, val in zip(tensors_idx, new_tensors):
for idx, val in zip(tensors_idx, new_tensors, strict=True):
arg_list[idx] = val
if type(args) is tuple:


@ -167,7 +167,7 @@ class GraphPy:
def populate_namespace_from_OP_to_IO(self):
for node in self.nodes_op:
for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
for node_output, outputSize in zip(node.outputs, node.outputstensor_size, strict=True):
self.scope_name_appeared.append(node.scopeName)
self.nodes_io[node_output] = NodeBase(
node_output,


@ -212,7 +212,7 @@ def object_annotation(obj):
"""
def format_sequence(obj):
body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj))
body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8])
if len(obj) > 8:
body = f'{body}, ...{len(obj) - 8}'
return body