Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
[2/N] Add strict parameter to Python zip calls (#166257)
This PR adds `strict=True/False` to zip calls in test utils. strict=True is passed when possible.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166257
Approved by: https://github.com/janeyx99
This commit is contained in:
parent 2829d48bd1
commit 39e5cdddf7
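For context, a minimal sketch of the semantics this change opts into (the `strict` flag exists since Python 3.10): `zip(..., strict=True)` raises `ValueError` as soon as its iterables turn out to have unequal lengths, while `strict=False` keeps the historical behavior of silently truncating to the shortest input. The names below are illustrative only, not taken from the diff:

    # Minimal illustration of zip()'s strict parameter (Python 3.10+).
    names = ["lr", "momentum"]   # two elements
    values = [0.1, 0.9, 0.99]    # three elements: a latent bug

    # strict=False (the long-standing default behavior): silently truncates.
    print(dict(zip(names, values, strict=False)))  # {'lr': 0.1, 'momentum': 0.9}

    # strict=True: surfaces the mismatch instead of hiding it.
    try:
        dict(zip(names, values, strict=True))
    except ValueError as err:
        print(err)  # zip() argument 2 is longer than argument 1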
@@ -509,7 +509,9 @@ def merge_chunks(
         values_to_cat = []
         chunk_start_idx = 0
         assert len(partial_values) == len(meta_chunks)
-        for partial_value, meta_chunk in zip(partial_values, meta_chunks):
+        for partial_value, meta_chunk in zip(
+            partial_values, meta_chunks, strict=True
+        ):
             chunk_end_idx = chunk_start_idx + meta_chunk.size(arg.split_dim)

             slice_indices = [slice(None, None, None)] * partial_value.ndim

@@ -738,7 +738,7 @@ class BlockMask:
             (slice(i + n, i + n + 1) if -n <= i < 0 else slice(i, i + 1))
             if isinstance(i, int)
             else i
-            for i, n in zip(padded, sizes)
+            for i, n in zip(padded, sizes, strict=True)
         )
         new_kv_num_blocks = self.kv_num_blocks[index]
         new_kv_indices = self.kv_indices[index]

@@ -3304,7 +3304,8 @@ def gaussian_nll_loss(
     # or input.size = (4, 3, 32, 32), var.size = (4, 1, 32, 32)
     elif (
         input.ndim == var.ndim
-        and sum(y for x, y in zip(input.size(), var.size()) if x != y) == 1
+        and sum(y for x, y in zip(input.size(), var.size(), strict=True) if x != y)
+        == 1
     ):  # Heteroscedastic case
         pass

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs

+import itertools
 from collections import namedtuple
 from collections.abc import Sequence

@@ -273,7 +274,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):

         out[:, : self.shortlist_size] = head_logprob[:, : self.shortlist_size]

-        for i, (start_idx, stop_idx) in enumerate(zip(self.cutoffs, self.cutoffs[1:])):
+        for i, (start_idx, stop_idx) in enumerate(itertools.pairwise(self.cutoffs)):
             cluster_output = self.tail[i](input)
             cluster_logprob = F.log_softmax(cluster_output, dim=1)
             output_logprob = cluster_logprob + head_logprob[

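The hunk above also swaps the `zip(seq, seq[1:])` idiom for `itertools.pairwise` (likewise Python 3.10+), which yields the same consecutive pairs without materializing the shifted slice. A quick equivalence check on an illustrative list:

    import itertools

    cutoffs = [10, 100, 1000]
    assert list(itertools.pairwise(cutoffs)) == list(zip(cutoffs, cutoffs[1:]))
    # both yield [(10, 100), (100, 1000)]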
@@ -150,7 +150,9 @@ class Sequential(Module):
             delattr(self, key)
         # To preserve numbering
         str_indices = [str(i) for i in range(len(self._modules))]
-        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+        self._modules = OrderedDict(
+            zip(str_indices, self._modules.values(), strict=True)
+        )

     @_copy_to_script_wrapper
     def __len__(self) -> int:

@@ -395,7 +397,9 @@ class ModuleList(Module):
             delattr(self, self._get_abs_string_index(idx))
         # To preserve numbering, self._modules is being reconstructed with modules after deletion
         str_indices = [str(i) for i in range(len(self._modules))]
-        self._modules = OrderedDict(list(zip(str_indices, self._modules.values())))
+        self._modules = OrderedDict(
+            zip(str_indices, self._modules.values(), strict=True)
+        )

     @_copy_to_script_wrapper
     def __len__(self) -> int:

@@ -432,7 +436,9 @@ class ModuleList(Module):

         lines = []
         main_str = self._get_name() + "("
-        for (start_id, end_id), b in zip(start_end_indices, repeated_blocks):
+        for (start_id, end_id), b in zip(
+            start_end_indices, repeated_blocks, strict=True
+        ):
             local_repr = f"({start_id}): {b}"  # default repr

             if start_id != end_id:

@@ -142,7 +142,10 @@ class _ConvNd(Module):
             self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
             if padding == "same":
                 for d, k, i in zip(
-                    dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)
+                    dilation,
+                    kernel_size,
+                    range(len(kernel_size) - 1, -1, -1),
+                    strict=False,
                 ):
                     total_padding = d * (k - 1)
                     left_pad = total_padding // 2

@@ -196,7 +196,7 @@ class RNNBase(Module):
                     param_names += ["weight_hr_l{}{}"]
                 param_names = [x.format(layer, suffix) for x in param_names]

-                for name, param in zip(param_names, layer_params):
+                for name, param in zip(param_names, layer_params, strict=True):
                     setattr(self, name, param)
                 self._flat_weights_names.extend(param_names)
                 self._all_weights.append(param_names)

@@ -352,7 +352,9 @@ class RNNBase(Module):
         # Returns True if the weight tensors have changed since the last forward pass.
         # This is the case when used with torch.func.functional_call(), for example.
         weights_changed = False
-        for ref, name in zip(self._flat_weight_refs, self._flat_weights_names):
+        for ref, name in zip(
+            self._flat_weight_refs, self._flat_weights_names, strict=True
+        ):
             weight = getattr(self, name) if hasattr(self, name) else None
             if weight is not None and ref is not None and ref() is not weight:
                 weights_changed = True

@@ -41,7 +41,8 @@ def _list_with_default(out_size: list[int], defaults: list[int]) -> list[int]:
     if len(defaults) <= len(out_size):
         raise ValueError(f"Input dimension should be at least {len(out_size) + 1}")
     return [
-        v if v is not None else d for v, d in zip(out_size, defaults[-len(out_size) :])
+        v if v is not None else d
+        for v, d in zip(out_size, defaults[-len(out_size) :], strict=False)
     ]

@@ -141,18 +141,18 @@ def reduce_add_coalesced(inputs, destination=None, buffer_size=10485760):
     output = []
     ref_order = []
     # process sparse ones first since they may have different sizes on different gpus
-    for tensor_at_gpus in zip(*inputs):
+    for tensor_at_gpus in zip(*inputs, strict=True):
         if all(t.is_sparse for t in tensor_at_gpus):
             result = reduce_add(tensor_at_gpus, destination)  # this will be sparse too
             output.append(result)
             ref_order.append(tensor_at_gpus[0])
         else:
-            for coll, t in zip(dense_tensors, tensor_at_gpus):
+            for coll, t in zip(dense_tensors, tensor_at_gpus, strict=True):
                 coll.append(t.to_dense() if t.is_sparse else t)
             ref_order.append(dense_tensors[0][-1])
     itrs = [_take_tensors(tensors, buffer_size) for tensors in dense_tensors]
     # now the dense ones, which have consistent sizes
-    for chunks in zip(*itrs):
+    for chunks in zip(*itrs, strict=True):
         flat_tensors = [
             _flatten_dense_tensors(chunk) for chunk in chunks
         ]  # (num_gpus,)

@@ -115,7 +115,7 @@ def parallel_apply(
                 target=_worker, args=(i, module, input, kwargs, device, stream)
             )
             for i, (module, input, kwargs, device, stream) in enumerate(
-                zip(modules, inputs, kwargs_tup, devices, streams)
+                zip(modules, inputs, kwargs_tup, devices, streams, strict=True)
             )
         ]

@@ -57,16 +57,20 @@ def scatter(inputs, target_gpus, dim=0):
             return Scatter.apply(target_gpus, None, dim, obj)
         if _is_namedtuple(obj):
             # pyrefly: ignore [no-matching-overload]
-            return [type(obj)(*args) for args in zip(*map(scatter_map, obj))]
+            return [
+                type(obj)(*args) for args in zip(*map(scatter_map, obj), strict=False)
+            ]
         if isinstance(obj, tuple) and len(obj) > 0:
             # pyrefly: ignore [no-matching-overload]
-            return list(zip(*map(scatter_map, obj)))
+            return list(zip(*map(scatter_map, obj), strict=False))
         if isinstance(obj, list) and len(obj) > 0:
             # pyrefly: ignore [no-matching-overload]
-            return [list(i) for i in zip(*map(scatter_map, obj))]
+            return [list(i) for i in zip(*map(scatter_map, obj), strict=False)]
         if isinstance(obj, dict) and len(obj) > 0:
             # pyrefly: ignore [no-matching-overload]
-            return [type(obj)(i) for i in zip(*map(scatter_map, obj.items()))]
+            return [
+                type(obj)(i) for i in zip(*map(scatter_map, obj.items()), strict=False)
+            ]
         return [obj for _ in target_gpus]

     # After scatter_map is called, a scatter_map cell will exist. This cell

@@ -131,9 +135,9 @@ def gather(outputs: Any, target_device: Union[int, torch.device], dim: int = 0)
             return type(out)((k, gather_map([d[k] for d in outputs])) for k in out)
         if _is_namedtuple(out):
             # pyrefly: ignore [no-matching-overload]
-            return type(out)._make(map(gather_map, zip(*outputs)))
+            return type(out)._make(map(gather_map, zip(*outputs, strict=True)))
         # pyrefly: ignore [no-matching-overload]
-        return type(out)(map(gather_map, zip(*outputs)))
+        return type(out)(map(gather_map, zip(*outputs, strict=True)))

     # Recursive function calls like this create reference cycles.
     # Setting the function to None clears the refcycle.

@@ -28,7 +28,7 @@ def conv_args_and_kwargs(kwarg_names, expanded_args_and_kwargs):
     kwargs = expanded_args_and_kwargs[
         len(expanded_args_and_kwargs) - len(kwarg_names) :
     ]
-    kwargs = dict(zip(kwarg_names, kwargs))
+    kwargs = dict(zip(kwarg_names, kwargs, strict=True))

     return conv_normalizer(*args, **kwargs)

@@ -32,7 +32,7 @@ def standard_kwargs(kwarg_names, expanded_args):
     expanded_args_without_kwargs = expanded_args[
         : len(expanded_args) - len(kwarg_names)
     ]
-    expanded_kwargs = dict(zip(kwarg_names, kwarg_values))
+    expanded_kwargs = dict(zip(kwarg_names, kwarg_values, strict=True))
     return expanded_args_without_kwargs, expanded_kwargs

@@ -250,7 +250,7 @@ class NamedMemberAccessor:
             values = list(values)
         assert len(names) == len(values), "names and values must have the same length"

-        for name, value in zip(names, values):
+        for name, value in zip(names, values, strict=True):
             self.set_tensor(name, value)

     def set_tensors_dict(self, named_tensors: dict[str, torch.Tensor]) -> None:

@@ -298,7 +298,7 @@ class NamedMemberAccessor:

         return [
             self.swap_tensor(name, value, allow_missing=allow_missing)
-            for name, value in zip(names, values)
+            for name, value in zip(names, values, strict=True)
         ]

     def swap_tensors_dict(

@@ -528,7 +528,7 @@ def unpad_sequence(
     max_length = padded_sequences.shape[1]
     idx = torch.arange(max_length, device=lengths.device)

-    for seq, length in zip(padded_sequences, lengths):
+    for seq, length in zip(padded_sequences, lengths, strict=True):
         mask = idx < length
         unpacked_seq = seq[mask]
         unpadded_sequences.append(unpacked_seq)

@@ -532,7 +532,7 @@ def _multi_tensor_adafactor(

         alphas = [
             max(eps2, p.norm(2).item() / (p.numel() ** 0.5)) * r
-            for p, r in zip(device_params, rho_ts)
+            for p, r in zip(device_params, rho_ts, strict=True)
         ]

         # Perform stepweight decay

@@ -566,7 +566,9 @@ def _multi_tensor_adafactor(

             var_estimates = [
                 row_var @ col_var
-                for row_var, col_var in zip(device_row_vars, device_col_vars)
+                for row_var, col_var in zip(
+                    device_row_vars, device_col_vars, strict=True
+                )
             ]
             row_var_means = [
                 row_var.mean(dim=-2, keepdim=True) for row_var in device_row_vars

@@ -594,7 +596,7 @@ def _multi_tensor_adafactor(

         alphas = [
             -a / (max(1.0, update.norm(2).item() / ((update.numel() ** 0.5) * d)))
-            for a, update in zip(alphas, updates)
+            for a, update in zip(alphas, updates, strict=True)
         ]
         torch._foreach_mul_(updates, alphas)
         torch._foreach_add_(device_params, updates)

@@ -266,7 +266,7 @@ def _single_tensor_adadelta(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -276,7 +276,7 @@ def _single_tensor_adadelta(
     lr = _to_scalar(lr)

     for param, grad, square_avg, acc_delta, step in zip(
-        params, grads, square_avgs, acc_deltas, state_steps
+        params, grads, square_avgs, acc_deltas, state_steps, strict=True
     ):
         step += 1
         grad = grad if not maximize else -grad

@@ -329,7 +329,7 @@ def _multi_tensor_adadelta(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -343,7 +343,9 @@ def _single_tensor_adagrad(
     if not torch.jit.is_scripting():
         lr = _to_scalar(lr)

-    for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
+    for param, grad, state_sum, step_t in zip(
+        params, grads, state_sums, state_steps, strict=True
+    ):
         # update step
         step_t += 1
         step = _get_value(step_t)

@@ -608,7 +608,7 @@ def _multi_tensor_adam(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -334,7 +334,7 @@ def _multi_tensor_adamax(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -307,7 +307,7 @@ def _multi_tensor_asgd(
         if not all(
             p.device.type == mu.device.type == eta.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, mu, eta, step in zip(params, mus, etas, state_steps)
+            for p, mu, eta, step in zip(params, mus, etas, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params, mus, etas, and state_steps must be on "

@@ -314,7 +314,7 @@ class LBFGS(Optimizer):
         return [p.clone(memory_format=torch.contiguous_format) for p in self._params]

     def _set_param(self, params_data):
-        for p, pdata in zip(self._params, params_data):
+        for p, pdata in zip(self._params, params_data, strict=True):
             p.copy_(pdata)

     def _directional_evaluate(self, closure, x, t, d):

@@ -302,7 +302,7 @@ class LRScheduler:
         else:
             values = self.get_lr()

-        for param_group, lr in zip(self.optimizer.param_groups, values):
+        for param_group, lr in zip(self.optimizer.param_groups, values, strict=True):
             _update_param_group_val(param_group, "lr", lr)

         self._last_lr: list[float | Tensor] = _param_groups_val_list(

@@ -472,7 +472,7 @@ class LambdaLR(LRScheduler):

         return [
             base_lr * lmbda(self.last_epoch)
-            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs)
+            for lmbda, base_lr in zip(self.lr_lambdas, self.base_lrs, strict=True)
         ]

@@ -592,7 +592,9 @@ class MultiplicativeLR(LRScheduler):
         if not self._is_initial:
             return [
                 group["lr"] * lmbda(self.last_epoch)
-                for lmbda, group in zip(self.lr_lambdas, self.optimizer.param_groups)
+                for lmbda, group in zip(
+                    self.lr_lambdas, self.optimizer.param_groups, strict=True
+                )
             ]
         else:
             return _param_groups_val_list(self.optimizer, "lr")

@@ -1441,13 +1443,17 @@ class CosineAnnealingLR(LRScheduler):
                 + (base_lr - self.eta_min)
                 * (1 + math.cos((self.last_epoch) * math.pi / self.T_max))
                 / 2
-                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+                for base_lr, group in zip(
+                    self.base_lrs, self.optimizer.param_groups, strict=True
+                )
             ]
         elif (self.last_epoch - 1 - self.T_max) % (2 * self.T_max) == 0:
             return [
                 group["lr"]
                 + (base_lr - self.eta_min) * (1 - math.cos(math.pi / self.T_max)) / 2
-                for base_lr, group in zip(self.base_lrs, self.optimizer.param_groups)
+                for base_lr, group in zip(
+                    self.base_lrs, self.optimizer.param_groups, strict=True
+                )
             ]
         return [
             (1 + math.cos(math.pi * self.last_epoch / self.T_max))

@@ -1906,7 +1912,7 @@ class CyclicLR(LRScheduler):

         base_lrs = _format_param("base_lr", optimizer, base_lr)
         if last_epoch == -1:
-            for lr, group in zip(base_lrs, optimizer.param_groups):
+            for lr, group in zip(base_lrs, optimizer.param_groups, strict=True):
                 _update_param_group_val(group, "lr", lr)

         self.max_lrs = _format_param("max_lr", optimizer, max_lr)

@@ -1949,7 +1955,10 @@ class CyclicLR(LRScheduler):
             self.max_momentums = _format_param("max_momentum", optimizer, max_momentum)
             if last_epoch == -1:
                 for m_momentum, b_momentum, group in zip(
-                    self.max_momentums, self.base_momentums, optimizer.param_groups
+                    self.max_momentums,
+                    self.base_momentums,
+                    optimizer.param_groups,
+                    strict=True,
                 ):
                     if self.use_beta1:
                         group["betas"] = (m_momentum, *group["betas"][1:])

@@ -2033,7 +2042,7 @@ class CyclicLR(LRScheduler):
             scale_factor = (x - 1) / (self.step_ratio - 1)

         lrs = []
-        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs):
+        for base_lr, max_lr in zip(self.base_lrs, self.max_lrs, strict=True):
             base_height = (max_lr - base_lr) * scale_factor
             if self.scale_mode == "cycle":
                 lr = base_lr + base_height * self.scale_fn(cycle)

@@ -2044,7 +2053,7 @@ class CyclicLR(LRScheduler):
         if self.cycle_momentum:
             momentums = []
             for base_momentum, max_momentum in zip(
-                self.base_momentums, self.max_momentums
+                self.base_momentums, self.max_momentums, strict=True
             ):
                 base_height = (max_momentum - base_momentum) * scale_factor
                 if self.scale_mode == "cycle":

@@ -2054,7 +2063,9 @@ class CyclicLR(LRScheduler):
                     self.last_epoch
                 )
                 momentums.append(momentum)
-            for param_group, momentum in zip(self.optimizer.param_groups, momentums):
+            for param_group, momentum in zip(
+                self.optimizer.param_groups, momentums, strict=True
+            ):
                 if self.use_beta1:
                     param_group["betas"] = (momentum, *param_group["betas"][1:])
                 else:

@@ -2260,7 +2271,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
             self.last_epoch = math.floor(epoch)

         with _enable_get_lr_call(self):
-            for param_group, lr in zip(self.optimizer.param_groups, self.get_lr()):
+            for param_group, lr in zip(
+                self.optimizer.param_groups, self.get_lr(), strict=True
+            ):
                 _update_param_group_val(param_group, "lr", lr)

         self._last_lr = _param_groups_val_list(self.optimizer, "lr")

@@ -2500,7 +2513,7 @@ class OneCycleLR(LRScheduler):
             base_momentums = _format_param("base_momentum", optimizer, base_momentum)
             if last_epoch == -1:
                 for m_momentum, b_momentum, group in zip(
-                    max_momentums, base_momentums, optimizer.param_groups
+                    max_momentums, base_momentums, optimizer.param_groups, strict=True
                 ):
                     if self.use_beta1:
                         group["betas"] = (m_momentum, *group["betas"][1:])

@@ -412,7 +412,7 @@ def _multi_tensor_nadam(
         if not all(
             p.device.type == mp.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, mp, step in zip(params, mu_products, state_steps)
+            for p, mp, step in zip(params, mu_products, state_steps, strict=True)
         ):
             raise AssertionError(
                 "If capturable=True, "

@@ -570,7 +570,7 @@ def _multi_tensor_nadam(
         step_size_grads = _stack_if_compiling(
             [
                 (_get_value(lr) * (1.0 - mu) / (1.0 - _get_value(mu_product))) * -1
-                for mu_product, mu in zip(grouped_mu_products, mus)
+                for mu_product, mu in zip(grouped_mu_products, mus, strict=True)
             ]
         )
         step_size_expavg = _stack_if_compiling(

@@ -581,7 +581,9 @@ def _multi_tensor_nadam(
                     / (1.0 - _get_value(mu_product) * mu_next)
                 )
                 * -1
-                for mu_product, mu_next in zip(grouped_mu_products, mu_nexts)
+                for mu_product, mu_next in zip(
+                    grouped_mu_products, mu_nexts, strict=True
+                )
             ]
         )

@@ -941,7 +941,9 @@ class Optimizer:
             )
         param_lens = (len(g["params"]) for g in groups)
         saved_lens = (len(g["params"]) for g in saved_groups)
-        if any(p_len != s_len for p_len, s_len in zip(param_lens, saved_lens)):
+        if any(
+            p_len != s_len for p_len, s_len in zip(param_lens, saved_lens, strict=True)
+        ):
             raise ValueError(
                 "loaded state dict contains a parameter group "
                 "that doesn't match the size of optimizer's group"

@@ -952,6 +954,7 @@ class Optimizer:
             zip(
                 chain.from_iterable(g["params"] for g in saved_groups),
                 chain.from_iterable(g["params"] for g in groups),
+                strict=True,
             )
         )

@@ -1005,7 +1008,9 @@ class Optimizer:
                 new_group["param_names"] = group["param_names"]
             return new_group

-        param_groups = [update_group(g, ng) for g, ng in zip(groups, saved_groups)]
+        param_groups = [
+            update_group(g, ng) for g, ng in zip(groups, saved_groups, strict=True)
+        ]
         self.__setstate__({"state": state, "param_groups": param_groups})

         for post_hook in self._optimizer_load_state_dict_post_hooks.values():

@@ -392,7 +392,7 @@ def _multi_tensor_radam(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -501,7 +501,8 @@ def _multi_tensor_radam(

         # TODO(mlazos): we should try and get a foreach_where op https://github.com/pytorch/pytorch/issues/117884
         rect = [
-            torch.where(rho_t > 5.0, n, 0.0) for n, rho_t in zip(num, rho_t_list)
+            torch.where(rho_t > 5.0, n, 0.0)
+            for n, rho_t in zip(num, rho_t_list, strict=True)
         ]
         del num
         del rho_t_list

@@ -544,11 +545,14 @@ def _multi_tensor_radam(
             1 - beta1 ** _get_value(step) for step in grouped_state_steps
         ]
         unrect_step_size = [
-            (lr * rect / bc) * -1 for rect, bc in zip(unrectified, bias_correction1)
+            (lr * rect / bc) * -1
+            for rect, bc in zip(unrectified, bias_correction1, strict=True)
         ]
         bias_correction2 = [
             ((1 - beta2 ** _get_value(step)) ** 0.5) * (lr * rect / bc) * -1
-            for step, rect, bc in zip(grouped_state_steps, rect, bias_correction1)
+            for step, rect, bc in zip(
+                grouped_state_steps, rect, bias_correction1, strict=True
+            )
         ]

         buffer = torch._foreach_sqrt(grouped_exp_avg_sqs)

@@ -370,7 +370,7 @@ def _multi_tensor_rmsprop(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -319,7 +319,7 @@ def _multi_tensor_rprop(
         if not all(
             p.device.type == step.device.type
             and p.device.type in capturable_supported_devices
-            for p, step in zip(params, state_steps)
+            for p, step in zip(params, state_steps, strict=True)
         ):
             raise AssertionError(
                 f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."

@@ -143,7 +143,9 @@ class SGD(Optimizer):  # noqa: D101

             if group["momentum"] != 0:
                 # update momentum_buffers in state
-                for p, momentum_buffer in zip(params, momentum_buffer_list):
+                for p, momentum_buffer in zip(
+                    params, momentum_buffer_list, strict=True
+                ):
                     state = self.state[p]
                     state["momentum_buffer"] = momentum_buffer

@@ -50,7 +50,7 @@ def get_ema_multi_avg_fn(decay=0.999):
         ):
             torch._foreach_lerp_(ema_param_list, current_param_list, 1 - decay)
         else:
-            for p_ema, p_model in zip(ema_param_list, current_param_list):
+            for p_ema, p_model in zip(ema_param_list, current_param_list, strict=True):
                 p_ema.copy_(p_ema * decay + p_model * (1 - decay))

     return ema_update

@@ -264,7 +264,7 @@ class AveragedModel(Module):
         self_param_detached: list[Optional[Tensor]] = []
         model_param_detached: list[Optional[Tensor]] = []
        copy_param = bool(self.n_averaged == 0)
-        for p_averaged, p_model in zip(self_param, model_param):
+        for p_averaged, p_model in zip(self_param, model_param, strict=True):
             p_model_ = p_model.detach().to(p_averaged.device)
             self_param_detached.append(p_averaged.detach())
             model_param_detached.append(p_model_)

@@ -297,12 +297,14 @@ class AveragedModel(Module):
                 else:
                     avg_fn = get_swa_avg_fn()
                     n_averaged = self.n_averaged.to(device)
-                    for p_averaged, p_model in zip(self_params, model_params):  # type: ignore[assignment]
+                    for p_averaged, p_model in zip(  # type: ignore[assignment]
+                        self_params, model_params, strict=True
+                    ):
                         # pyrefly: ignore [missing-attribute]
                         p_averaged.copy_(avg_fn(p_averaged, p_model, n_averaged))
             else:
                 for p_averaged, p_model in zip(  # type: ignore[assignment]
-                    self_param_detached, model_param_detached
+                    self_param_detached, model_param_detached, strict=True
                 ):
                     # pyrefly: ignore [missing-attribute]
                     n_averaged = self.n_averaged.to(p_averaged.device)

@@ -315,7 +317,9 @@ class AveragedModel(Module):
         if not self.use_buffers:
             # If not apply running averages to the buffers,
             # keep the buffers in sync with the source model.
-            for b_swa, b_model in zip(self.module.buffers(), model.buffers()):
+            for b_swa, b_model in zip(
+                self.module.buffers(), model.buffers(), strict=True
+            ):
                 b_swa.detach().copy_(b_model.detach().to(b_swa.device))
         self.n_averaged += 1

@@ -432,7 +436,7 @@ class SWALR(LRScheduler):
         last_epoch=-1,
     ):  # noqa: D107
         swa_lrs = _format_param("swa_lr", optimizer, swa_lr)
-        for swa_lr, group in zip(swa_lrs, optimizer.param_groups):
+        for swa_lr, group in zip(swa_lrs, optimizer.param_groups, strict=True):
             group["swa_lr"] = swa_lr
         if anneal_strategy not in ["cos", "linear"]:
             raise ValueError(

@@ -509,7 +513,7 @@ class SWALR(LRScheduler):
         alpha = self.anneal_func(t)
         return [
             group["swa_lr"] * alpha + lr * (1 - alpha)
-            for group, lr in zip(self.optimizer.param_groups, prev_lrs)
+            for group, lr in zip(self.optimizer.param_groups, prev_lrs, strict=True)
         ]

     def _set_anneal_func(self, anneal_strategy: Literal["cos", "linear"]):

@@ -254,7 +254,9 @@ class SchemaMatcher:
         def matches(schema) -> bool:
             return len(schema.arguments) == len(signature) and all(
                 cls._types_match(observed, schema_arg.type)
-                for observed, schema_arg in zip(signature, schema.arguments)
+                for observed, schema_arg in zip(
+                    signature, schema.arguments, strict=True
+                )
             )

         return tuple(s for s in cls.lookup_schemas(t.name) or () if matches(s))

@@ -377,7 +379,9 @@ class SizeMap:
             key = TensorKey.from_tensor(t)
             if key is not None and t is not None and t.layout == torch.strided:
                 # Scalars are represented as zero dim Tensors
-                n = max(i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1]))
+                n = max(
+                    i[0] * i[1] for i in zip(t.sizes or [1], t.strides or [1], strict=True)
+                )

                 num_bytes = n * _element_size(t.dtype)
                 assert num_bytes >= 0, f"{num_bytes}"

@@ -430,7 +434,7 @@ class DataFlowNode:
         mutable_by_key: dict[Optional[TensorKey], set[Optional[bool]]] = {}
         for op in (i.typed[1] for i in subtree if i.typed[0] == _EventType.TorchOp):
             for op_input, mutable in zip(
-                op.inputs, SchemaMatcher.inputs_are_mutable(op)
+                op.inputs, SchemaMatcher.inputs_are_mutable(op), strict=True
             ):
                 # Tensor
                 if isinstance(op_input, _TensorMetadata):

@@ -276,7 +276,7 @@ class ForLoopIndexingPattern(Pattern):
         def same_ops(list1, list2) -> bool:
             if len(list1) != len(list2):
                 return False
-            for op1, op2 in zip(list1, list2):
+            for op1, op2 in zip(list1, list2, strict=True):
                 if op1.name != op2.name:
                     return False
             return True

@@ -336,7 +336,7 @@ class BasicEvaluation:
         event_list = [
             event
             for _, event in sorted(
-                zip(heuristic_score_list, event_list),
+                zip(heuristic_score_list, event_list, strict=True),
                 key=operator.itemgetter(0),
                 reverse=True,
             )

@@ -121,7 +121,7 @@ def slicer(dim, slice_range, *tensors):
 def multidim_slicer(dims, slices, *tensors):
     for t in tensors:
         s = [slice(None)] * t.dim()
-        for d, d_slice in zip(dims, slices):
+        for d, d_slice in zip(dims, slices, strict=False):
             if d is not None:
                 s[d] = d_slice
         yield t[tuple(s)]

@@ -140,7 +140,7 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):
     import itertools

     def generate_grid_points():
-        for fg, mg in zip(full_grid, grid_blocks):
+        for fg, mg in zip(full_grid, grid_blocks, strict=False):
             yield range(0, fg, mg)

     def generate_sliced_tensors(slices):

@@ -149,9 +149,10 @@ def grid_partitioner(full_grid, grid_blocks, tensor_dims_map):

     for grid_point in itertools.product(*generate_grid_points()):
         grid = [
-            min(fg - gp, mg) for fg, gp, mg in zip(full_grid, grid_point, grid_blocks)
+            min(fg - gp, mg)
+            for fg, gp, mg in zip(full_grid, grid_point, grid_blocks, strict=False)
         ]
-        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid)]
+        slices = [slice(gp, gp + g) for gp, g in zip(grid_point, grid, strict=False)]
         # grid_points are iterated in a "contiguous" order, i.e.
         # left dimensions traversed slower than right dimensions.
         # This order is reversed for CUDA grids.

@@ -173,7 +174,8 @@ def launch_kernel(kernel, tensor_dims_map, full_grid, grid_blocks=None):
         return max(1, min(g, mg))

     grid_blocks = tuple(
-        valid_grid_dim(g, mg) for g, mg in zip(grid_blocks, cuda_max_grid)
+        valid_grid_dim(g, mg)
+        for g, mg in zip(grid_blocks, cuda_max_grid, strict=False)
     )  # type: ignore[assignment]

     for grid, *sliced_tensors in grid_partitioner(

@@ -155,7 +155,11 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
     matching_data = {}
     if "*" in key:
         for op_key in op_data:
-            if [None for k1, k2 in zip(op_key, key) if k2 != "*" and k1 != k2]:
+            if [
+                None
+                for k1, k2 in zip(op_key, key, strict=True)
+                if k2 != "*" and k1 != k2
+            ]:
                 continue
             matching_data[op_key] = op_data[op_key]
     else:

@@ -173,10 +177,14 @@ def get_meta(op, key, device_name=None, version=(0, torch.float16, 0.5), exact=F
             "num_stages",
             "num_warps",
         )
-        meta = dict(zip(names, values))
+        meta = dict(zip(names, values, strict=True))
     elif op in {"bsr_dense_addmm", "_int_bsr_dense_addmm"}:
         meta = dict(
-            zip(("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"), values)
+            zip(
+                ("GROUP_SIZE_ROW", "SPLIT_N", "num_stages", "num_warps"),
+                values,
+                strict=True,
+            )
         )
     else:
         raise NotImplementedError(f"names for {op=}")

@@ -289,7 +297,7 @@ def minimize(
         return tuple(parameters[k] for k in sorted(parameters))

     def from_key(key, parameters):
-        return dict(zip(sorted(parameters), key))
+        return dict(zip(sorted(parameters), key, strict=True))

     if all_values is None:
         all_values = {}

@@ -347,7 +355,7 @@ def minimize(
         for i, (_, d_tuple) in enumerate(all_directions):
             pbar.update(1)
             next_parameters = parameters.copy()
-            for name, direction in zip(names, d_tuple):
+            for name, direction in zip(names, d_tuple, strict=True):
                 value = next_parameters[name]
                 if direction == 0:
                     continue

@@ -364,11 +364,15 @@ def register_dataclass(

     def _unflatten_fn(values: Iterable[Any], context: Context) -> Any:
         flat_names, none_names = context
-        return cls(**dict(zip(flat_names, values)), **dict.fromkeys(none_names))
+        return cls(
+            **dict(zip(flat_names, values, strict=True)), **dict.fromkeys(none_names)
+        )

     def _flatten_fn_with_keys(obj: Any) -> tuple[list[Any], Context]:
         flattened, (flat_names, _none_names) = _flatten_fn(obj)  # type: ignore[misc]
-        return [(GetAttrKey(k), v) for k, v in zip(flat_names, flattened)], flat_names
+        return [
+            (GetAttrKey(k), v) for k, v in zip(flat_names, flattened, strict=True)
+        ], flat_names

     _private_register_pytree_node(
         cls,

@@ -788,11 +792,11 @@ def _dict_flatten_with_keys(
 ) -> tuple[list[tuple[KeyEntry, T]], Context]:
     values, context = _dict_flatten(d)
     # pyrefly: ignore [bad-return]
-    return [(MappingKey(k), v) for k, v in zip(context, values)], context
+    return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context


 def _dict_unflatten(values: Iterable[T], context: Context) -> dict[Any, T]:
-    return dict(zip(context, values))
+    return dict(zip(context, values, strict=True))


 def _namedtuple_flatten(d: NamedTuple) -> tuple[list[Any], Context]:

@@ -805,7 +809,10 @@ def _namedtuple_flatten_with_keys(
     values, context = _namedtuple_flatten(d)
     # pyrefly: ignore [bad-return]
     return (
-        [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
+        [
+            (GetAttrKey(field), v)
+            for field, v in zip(context._fields, values, strict=True)
+        ],
         context,
     )

@@ -854,14 +861,14 @@ def _ordereddict_flatten_with_keys(
 ) -> tuple[list[tuple[KeyEntry, T]], Context]:
     values, context = _ordereddict_flatten(d)
     # pyrefly: ignore [bad-return]
-    return [(MappingKey(k), v) for k, v in zip(context, values)], context
+    return [(MappingKey(k), v) for k, v in zip(context, values, strict=True)], context


 def _ordereddict_unflatten(
     values: Iterable[T],
     context: Context,
 ) -> OrderedDict[Any, T]:
-    return OrderedDict((key, value) for key, value in zip(context, values))
+    return OrderedDict((key, value) for key, value in zip(context, values, strict=True))


 _odict_flatten = _ordereddict_flatten

@@ -879,7 +886,9 @@ def _defaultdict_flatten_with_keys(
     values, context = _defaultdict_flatten(d)
     _, dict_context = context
     # pyrefly: ignore [bad-return]
-    return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context
+    return [
+        (MappingKey(k), v) for k, v in zip(dict_context, values, strict=True)
+    ], context


 def _defaultdict_unflatten(

@@ -1197,7 +1206,7 @@ class TreeSpec:
                     f"expected {treespec.context!r}, but got {context!r}.",  # namedtuple type mismatch
                 )

-            for subtree, subspec in zip(children, treespec.children_specs):
+            for subtree, subspec in zip(children, treespec.children_specs, strict=True):
                 helper(subspec, subtree, subtrees)

         subtrees: list[PyTree] = []

@@ -1761,7 +1770,7 @@ def _broadcast_to_and_flatten(

     # Recursively flatten the children
     result: list[Any] = []
-    for child, child_spec in zip(child_pytrees, treespec.children_specs):
+    for child, child_spec in zip(child_pytrees, treespec.children_specs, strict=True):
         flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf)
         if flat is not None:
             result += flat

@@ -2063,9 +2072,9 @@ def tree_map_with_path(
         ``xs`` is the tuple of values at corresponding nodes in ``rests``.
     """
     keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf)
-    keypath_leaves = list(zip(*keypath_leaves))
+    keypath_leaves = list(zip(*keypath_leaves, strict=True))
     all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests]
-    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves))
+    return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves, strict=True))


 def keystr(kp: KeyPath) -> str:

@@ -1223,7 +1223,8 @@ class IsNonOverlappingAndDenseIndicator(sympy.Function):
             # When all strides are integral, we can sort, and the size for the
             # largest stride doesn't matter and can be arbitrarily symbolic
             s_sizes, s_strides = zip(
-                *sorted(zip(sizes, strides), key=operator.itemgetter(1))
+                *sorted(zip(sizes, strides, strict=True), key=operator.itemgetter(1)),
+                strict=True,
             )
             # Put something arbitrary in the max size spot, it'll be ignored
             if all(isinstance(a, sympy.Integer) for a in s_sizes[:-1]):

@@ -32,7 +32,7 @@ def run(n, stmt, fuzzer_cls):
     float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
     int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
     raw_results = []
-    for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)):
+    for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter, strict=True)):
         float_tensors, float_tensor_params, float_params = float_values
         int_tensors, int_tensor_params, int_params = int_values

@@ -89,7 +89,7 @@ def run(n, stmt, fuzzer_cls):
     for t_float, t_int, rel_diff, descriptions in results:
         time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"]
         time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
-        for t_str, (name, shape, order, steps) in zip(time_str, descriptions):
+        for t_str, (name, shape, order, steps) in zip(time_str, descriptions, strict=True):
             name = f"{name}:".ljust(name_len + 1)
             shape = shape.ljust(shape_len + 10)
             order = order.ljust(order_len)

@@ -29,7 +29,7 @@ def run(n, stmt, fuzzer_cls):
     float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
     double_iter = fuzzer_cls(seed=0, dtype=torch.float64).take(n)
     raw_results = []
-    for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter)):
+    for i, (float_values, int_values) in enumerate(zip(float_iter, double_iter, strict=True)):
         float_tensors, float_tensor_params, float_params = float_values
         int_tensors, int_tensor_params, int_params = int_values

@@ -84,7 +84,7 @@ def run(n, stmt, fuzzer_cls):
     for t_float, t_int, rel_diff, descriptions in results:
         time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"]
         time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
-        for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions):
+        for t_str, (name, shape, sparse_dim, is_coalesced) in zip(time_str, descriptions, strict=True):
             name = f"{name}:".ljust(name_len + 1)
             shape = shape.ljust(shape_len + 10)
             sparse_dim = sparse_dim.ljust(sparse_dim_len)

@@ -51,7 +51,7 @@ class _Column:
         unit_digits = max(d for d in leading_digits if d is not None)
         decimal_digits = min(
             max(m.significant_figures - digits, 0)
-            for digits, m in zip(leading_digits, self._flat_results)
+            for digits, m in zip(leading_digits, self._flat_results, strict=True)
             if (m is not None) and (digits is not None)
         ) if self._trim_significant_figures else 1
         length = unit_digits + decimal_digits + (1 if decimal_digits else 0)

@@ -99,7 +99,7 @@ class _Row:
         env = f"({concrete_results[0].env})" if self._render_env else ""
         env = env.ljust(self._env_str_len + 4)
         output = [" " + env + concrete_results[0].as_row_name]
-        for m, col in zip(self._results, self._columns or ()):
+        for m, col in zip(self._results, self._columns or (), strict=False):
             if m is None:
                 output.append(col.num_to_str(None, 1, None))
             else:

@@ -141,7 +141,7 @@ class _Row:
         ]

         row_contents = [column_strings[0].ljust(col_widths[0])]
-        for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values):
+        for col_str, width, result, best_value in zip(column_strings[1:], col_widths[1:], self._results, best_values, strict=False):
             col_str = col_str.center(width)
             if self._colorize != Colorize.NONE and result is not None and best_value is not None:
                 col_str = self.color_segment(col_str, result.median, best_value)

@@ -206,7 +206,7 @@ class Table:
         prior_env = ""
         row_group = -1
         rows_by_group: list[list[list[Optional[common.Measurement]]]] = []
-        for (num_threads, env, _), row in zip(self.row_keys, ordered_results):
+        for (num_threads, env, _), row in zip(self.row_keys, ordered_results, strict=True):
             thread_transition = (num_threads != prior_num_threads)
             if thread_transition:
                 prior_num_threads = num_threads

@@ -250,10 +250,10 @@ class Table:
         for sr in string_rows:
             sr.extend(["" for _ in range(num_cols - len(sr))])

-        col_widths = [max(len(j) for j in i) for i in zip(*string_rows)]
-        finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths))]
+        col_widths = [max(len(j) for j in i) for i in zip(*string_rows, strict=True)]
+        finalized_columns = [" | ".join(i.center(w) for i, w in zip(string_rows[0], col_widths, strict=True))]
         overall_width = len(finalized_columns[0])
-        for string_row, row in zip(string_rows[1:], self.rows):
+        for string_row, row in zip(string_rows[1:], self.rows, strict=True):
             finalized_columns.extend(row.row_separator(overall_width))
             finalized_columns.append(" | ".join(row.finalize_column_strings(string_row, col_widths)))

@@ -295,7 +295,7 @@ class FuzzedTensor:
             raw_tensor = raw_tensor.permute(tuple(order)).contiguous()
             raw_tensor = raw_tensor.permute(tuple(np.argsort(order)))

-        slices = [slice(0, size * step, step) for size, step in zip(size, steps)]
+        slices = [slice(0, size * step, step) for size, step in zip(size, steps, strict=True)]
         tensor = raw_tensor[tuple(slices)]

         properties = {

@@ -326,7 +326,7 @@ class FuzzedTensor:

         size = resolve(self._size, dim)
         steps = resolve(self._steps or (), dim)
-        allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps))
+        allocation_size = tuple(size_i * step_i for size_i, step_i in zip(size, steps, strict=True))
         return size, steps, allocation_size

     def satisfies_constraints(self, params):

@@ -196,7 +196,7 @@ def set_device_states(devices, states, *, device_type=None) -> None:
     if device_type == "meta":
         return
     device_module = _get_device_module(device_type)
-    for device, state in zip(devices, states):
+    for device, state in zip(devices, states, strict=False):
         with device_module.device(device):
             device_module.set_rng_state(state)

@@ -794,7 +794,7 @@ class _NoopSaveInputs(torch.autograd.Function):
         # Only tensors can be saved with ctx.save_for_backward, everything else
         # is captured by get_args, which is saved directly on ctx
         tensor_indices, tensors = zip(
-            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
+            *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)], strict=False
         )
         idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
         # args but with tensors replaced with None as placeholders

@@ -1020,7 +1020,7 @@ def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[Checkpoint
     def get_str_tb(label, capture_logs):
         out = ""
         total_len = len(capture_logs.logs)
-        for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
+        for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs, strict=False)):
            out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
            found_torch_dispatch = False
            for line in tb:

@@ -2947,7 +2947,7 @@ e.

     # Emit one build rule per source to enable incremental build.
     build = []
-    for source_file, object_file in zip(sources, objects):
+    for source_file, object_file in zip(sources, objects, strict=True):
         is_cuda_source = _is_cuda_file(source_file) and with_cuda
         is_sycl_source = _is_sycl_file(source_file) and with_sycl
         if is_cuda_source:

@@ -197,7 +197,7 @@ def collate(
             return elem_type(
                 *(
                     collate(samples, collate_fn_map=collate_fn_map)
-                    for samples in zip(*batch)
+                    for samples in zip(*batch, strict=False)
                 )
             )
         elif isinstance(elem, collections.abc.Sequence):

@@ -207,7 +207,9 @@ def collate(
         # pyrefly: ignore [not-iterable]
         if not all(len(elem) == elem_size for elem in it):
             raise RuntimeError("each element in list of batch should be of equal size")
-        transposed = list(zip(*batch))  # It may be accessed twice, so we use a list.
+        transposed = list(
+            zip(*batch, strict=False)
+        )  # It may be accessed twice, so we use a list.

         if isinstance(elem, tuple):
             return [

@@ -184,7 +184,7 @@ def _issubtype_with_constraints(variant, constraints, recursive=True):
                 and len(v_args) == len(c_args)
                 and all(
                     issubtype(v_arg, c_arg)
-                    for v_arg, c_arg in zip(v_args, c_args)
+                    for v_arg, c_arg in zip(v_args, c_args, strict=True)
                 )
             ):
                 return True

@@ -207,7 +207,7 @@ def issubinstance(data, data_type):
             return True
         if len(dt_args) != len(data):
             return False
-        return all(issubinstance(d, t) for d, t in zip(data, dt_args))
+        return all(issubinstance(d, t) for d, t in zip(data, dt_args, strict=True))
     elif isinstance(data, (list, set)):
         if dt_args is None or len(dt_args) == 0:
             return True

@@ -101,7 +101,7 @@ class FilterDataFramesPipe(DFIterDataPipe):
                 filter_res.append(self.filter_fn(df.iloc[i]))

             buffer = []
-            for df, res in zip(all_buffer, filter_res):
+            for df, res in zip(all_buffer, filter_res, strict=True):
                 if res:
                     buffer.append(df)
                     if len(buffer) == size:

@@ -705,7 +705,7 @@ class ZipperIterDataPipe(IterDataPipe[tuple[_T_co]]):

     def __iter__(self) -> Iterator[tuple[_T_co]]:
         iterators = [iter(datapipe) for datapipe in self.datapipes]
-        yield from zip(*iterators)
+        yield from zip(*iterators, strict=False)

     def __len__(self) -> int:
         if all(isinstance(dp, Sized) for dp in self.datapipes):

@@ -267,10 +267,10 @@ class StackDataset(Dataset[_T_stack]):
                         "Nested dataset's output size mismatch."
                         f" Expected {len(indices)}, got {len(items)}"
                     )
-                for data, d_sample in zip(items, dict_batch):
+                for data, d_sample in zip(items, dict_batch, strict=True):
                     d_sample[k] = data
             else:
-                for idx, d_sample in zip(indices, dict_batch):
+                for idx, d_sample in zip(indices, dict_batch, strict=True):
                     d_sample[k] = dataset[idx]
         return dict_batch

@@ -284,10 +284,10 @@ class StackDataset(Dataset[_T_stack]):
                         "Nested dataset's output size mismatch."
                         f" Expected {len(indices)}, got {len(items)}"
                     )
-                for data, t_sample in zip(items, list_batch):
+                for data, t_sample in zip(items, list_batch, strict=True):
                     t_sample.append(data)
             else:
-                for idx, t_sample in zip(indices, list_batch):
+                for idx, t_sample in zip(indices, list_batch, strict=True):
                     t_sample.append(dataset[idx])
         tuple_batch: list[_T_tuple] = [tuple(sample) for sample in list_batch]
         return tuple_batch

@@ -477,5 +477,5 @@ def random_split(
     lengths = cast(Sequence[int], lengths)
     return [
         Subset(dataset, indices[offset - length : offset])
-        for offset, length in zip(itertools.accumulate(lengths), lengths)
+        for offset, length in zip(itertools.accumulate(lengths), lengths, strict=True)
     ]

@@ -335,7 +335,7 @@ class BatchSampler(Sampler[list[int]]):
         if self.drop_last:
             # Create multiple references to the same iterator
             args = [sampler_iter] * self.batch_size
-            for batch_droplast in zip(*args):
+            for batch_droplast in zip(*args, strict=False):
                 yield [*batch_droplast]
         else:
             batch = [*itertools.islice(sampler_iter, self.batch_size)]

@@ -341,7 +341,7 @@ def _unpack_flash_attention_nested_shapes(
         raise AssertionError("sdpa_flop_count: cum_seq_q and cum_seq_k must have the same shape")
     seq_q_lengths = _offsets_to_lengths(cum_seq_q, max_q)
     seq_k_lengths = _offsets_to_lengths(cum_seq_k, max_k)
-    for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths):
+    for (seq_q_len, seq_k_len) in zip(seq_q_lengths, seq_k_lengths, strict=True):
         new_query_shape = (1, h_q, seq_q_len, d_q)
         new_key_shape = (1, h_k, seq_k_len, d_k)
         new_value_shape = (1, h_v, seq_k_len, d_v)

@@ -396,7 +396,7 @@ def _unpack_efficient_attention_nested_shapes(
             "cu_seqlens_q and cu_seqlens_k must have the same shape")
     seqlens_q = _offsets_to_lengths(cu_seqlens_q, max_seqlen_q)
     seqlens_k = _offsets_to_lengths(cu_seqlens_k, max_seqlen_k)
-    for len_q, len_k in zip(seqlens_q, seqlens_k):
+    for len_q, len_k in zip(seqlens_q, seqlens_k, strict=True):
         new_query_shape = (1, h_q, len_q, d_q)
         new_key_shape = (1, h_k, len_k, d_k)
         new_value_shape = (1, h_v, len_k, d_v)

@@ -114,7 +114,7 @@ class BackwardHook:

     def _pack_with_none(self, indices, values, size):
         res = [None] * size
-        for idx, val in zip(indices, values):
+        for idx, val in zip(indices, values, strict=True):
            res[idx] = val

         return tuple(res)

@@ -180,7 +180,7 @@ class BackwardHook:
             fn(grad_fns[0])

         arg_list = list(args)
-        for idx, val in zip(tensors_idx, new_tensors):
+        for idx, val in zip(tensors_idx, new_tensors, strict=True):
             arg_list[idx] = val

         if type(args) is tuple:

@@ -167,7 +167,7 @@ class GraphPy:

     def populate_namespace_from_OP_to_IO(self):
         for node in self.nodes_op:
-            for node_output, outputSize in zip(node.outputs, node.outputstensor_size):
+            for node_output, outputSize in zip(node.outputs, node.outputstensor_size, strict=True):
                 self.scope_name_appeared.append(node.scopeName)
                 self.nodes_io[node_output] = NodeBase(
                     node_output,

@@ -212,7 +212,7 @@ def object_annotation(obj):
     """

     def format_sequence(obj):
-        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for i, x in zip(range(8), obj))
+        body = ','.join(repr(x) if isinstance(x, BASE_TYPES) else type(x).__name__ for x in obj[:8])
         if len(obj) > 8:
             body = f'{body}, ...{len(obj) - 8}'
         return body