[Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)

I spotted 31 places that construct dictionaries from literal key-value pairs.

This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
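
For illustration, here is a minimal before/after sketch of the pattern (the keys mirror the optimizer `defaults` changes below; the values are made up). Both forms build the same mapping, but the literal needs no `dict` name lookup or call, and its keys are ordinary strings rather than identifiers:

```python
# Before: keys are passed as keyword arguments, so each one must be a
# valid Python identifier.
defaults = dict(lr=0.01, weight_decay=0.0, maximize=False)

# After: the equivalent dict literal; no dict() lookup-and-call, and the
# keys are plain strings.
defaults = {"lr": 0.01, "weight_decay": 0.0, "maximize": False}

# Both forms produce the same mapping.
assert defaults == dict(lr=0.01, weight_decay=0.0, maximize=False)
```

Dynamic constructions with no literal equivalent, such as `options.update(dict(zip(extra_names, extra_inputs)))` in the CUTLASS template below, are left unchanged.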

Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
Author: Zeina Migeed
Date: 2025-07-08 18:10:29 +00:00
Committer: PyTorch MergeBot
parent 0f31445139
commit 4f5be56612
21 changed files with 348 additions and 334 deletions

View File

@@ -211,12 +211,12 @@ inline void {{kernel_name}}(
     ) -> str:
         buffer_size = " * ".join(map(str, size_args))
         return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
-            dict(
-                buffer_name=buffer_name,
-                buffer_dtype=buffer_dtype,
-                buffer_size=buffer_size,
-                is_msvc_compiler=cpp_builder.is_msvc_cl(),
-            )
+            {
+                "buffer_name": buffer_name,
+                "buffer_dtype": buffer_dtype,
+                "buffer_size": buffer_size,
+                "is_msvc_compiler": cpp_builder.is_msvc_cl(),
+            }
         )

     def is_woq_int4(self):

View File

@@ -1179,27 +1179,27 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         instance_definition, instance_type = self._define_gemm_instance(op, evt_name)
-        options = dict(
-            alpha=self.alpha,
-            beta=self.beta,
-            X=X,
-            W=W,
-            Y=Y,
-            kernel_call_signature=kernel_call_signature,
-            Bias=Bias,
-            epilogue_template=epilogue_template,
-            argument_template=argument_template,
-            should_swap_xw=should_swap_xw,
-            template=self,
-            kernel=kernel,
-            instance_definition=instance_definition,
-            instance_type=instance_type,
-            input_reorder=self.input_reorder,
-            epilogue_args=evt_args,
-            test_call_statement=test_call_statement,
-            op_conf_name=op.configuration_name(),
-            epilogue_visitor_tree=evt_code,
-        )
+        options = {
+            "alpha": self.alpha,
+            "beta": self.beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "kernel_call_signature": kernel_call_signature,
+            "Bias": Bias,
+            "epilogue_template": epilogue_template,
+            "argument_template": argument_template,
+            "should_swap_xw": should_swap_xw,
+            "template": self,
+            "kernel": kernel,
+            "instance_definition": instance_definition,
+            "instance_type": instance_type,
+            "input_reorder": self.input_reorder,
+            "epilogue_args": evt_args,
+            "test_call_statement": test_call_statement,
+            "op_conf_name": op.configuration_name(),
+            "epilogue_visitor_tree": evt_code,
+        }
         options.update(dict(zip(extra_names, extra_inputs)))
         res = self._template_from_string(self._get_template()).render(**options)
         if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
@@ -1587,19 +1587,19 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }
         assert epilogue_template is not None

         if should_swap_xw:
@@ -1878,21 +1878,21 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            instance_type=instance_type,
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            Meta=Meta,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "instance_type": instance_type,
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "Meta": Meta,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }
         if epilogue_template is None:
             arguments = self._template_from_string(argument_template).render(

View File

@@ -86,12 +86,12 @@ def tma_options() -> dict[str, Any]:

 def persistent_mm_options(mat1, mat2):
-    res = dict(
-        A_ROW_MAJOR=not mat1.layout.is_transposed(),
-        B_ROW_MAJOR=not mat2.layout.is_transposed(),
-        NUM_SMS=get_num_sms(),
-        TMA_SIZE=TMA_DESCRIPTOR_SIZE,
-    )
+    res = {
+        "A_ROW_MAJOR": not mat1.layout.is_transposed(),
+        "B_ROW_MAJOR": not mat2.layout.is_transposed(),
+        "NUM_SMS": get_num_sms(),
+        "TMA_SIZE": TMA_DESCRIPTOR_SIZE,
+    }
     res.update(tma_options())
     return res

View File

@@ -147,12 +147,12 @@ def grouped_gemm_lowering(
     choices: list[ChoiceCaller] = []
     *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)
-    kwargs = dict(
-        has_bias=[bias is not None for bias in b],
-        trans_w=True,
-        epilogue_creator=None,
-        act_mapping=dict.fromkeys(range(num_gemm), x),
-    )
+    kwargs = {
+        "has_bias": [bias is not None for bias in b],
+        "trans_w": True,
+        "epilogue_creator": None,
+        "act_mapping": dict.fromkeys(range(num_gemm), x),
+    }
     input_nodes = [x, *w]
     input_nodes.extend([bias for bias in b if bias is not None])
@@ -353,11 +353,13 @@ def register_onednn_fusion_ops():
                         buf, attr, scalars=scalars, algorithm=algorithm
                     )
-                kwargs = dict(
-                    has_bias=b is not None,
-                    trans_w=True,
-                    epilogue_creator=None if attr == "none" else epilogue_creator,
-                )
+                kwargs = {
+                    "has_bias": b is not None,
+                    "trans_w": True,
+                    "epilogue_creator": (
+                        None if attr == "none" else epilogue_creator
+                    ),
+                }
                 if b is not None:
                     kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
                 CppGemmTemplate.add_choices(
@@ -416,11 +418,12 @@ def register_onednn_fusion_ops():
                 def epilogue_creator(buf):
                     return create_epilogue_with_attr(buf, attr, other=y)

-                kwargs = dict(
-                    has_bias=b is not None,
-                    trans_w=True,
-                    epilogue_creator=epilogue_creator,
-                )
+                kwargs = {
+                    "has_bias": b is not None,
+                    "trans_w": True,
+                    "epilogue_creator": epilogue_creator,
+                }
                 kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
                 CppGemmTemplate.add_choices(
                     choices,

View File

@@ -166,123 +166,123 @@ Example::
 """,
 )

-args_and_kwargs = dict(
+args_and_kwargs = {
     # argument name sufficies separated by double underscore will
     # be removed in the final documentation string.
-    sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
-    cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
-    amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    norm=(
+    "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "norm": (
         (
             "ord",
             "dim",
         ),
         ("keepdim=False", "dtype=None", "mask=None"),
     ),
-    var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-    std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-    logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-    log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-    softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
-    normalize=(
+    "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+    "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+    "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "normalize": (
         (
             "ord__required",
             "dim__as_int",
         ),
         ("eps=1e-12", "dtype=None", "mask=None"),
     ),
-)
+}

-argument_declarations = dict(
-    dim="""\
+argument_declarations = {
+    "dim": """\
 dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
 Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
-    dim__as_int="""\
+    "dim__as_int": """\
 dim (int): the dimension along which {operation name} is computed.""",
-    ord="""\
+    "ord": """\
 ord (int, float, optional): the order of vector norm. Default: 2.
 See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-    ord__required="""\
+    "ord__required": """\
 ord (int, float): the order of vector norm. Default: 2.
 See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-    unbiased="""\
+    "unbiased": """\
 unbiased (bool): when True, use Bessel's correction, otherwise, compute
 the uncorrected sample variance.""",
-    eps="""\
+    "eps": """\
 eps (float, optional): small value to avoid division by zero. Default: {default}.""",
-    keepdim="""\
+    "keepdim": """\
 keepdim (bool, optional): whether the output tensor has
 :attr:`dim` retained or not. Default: {default}.""",
-    dtype="""\
+    "dtype": """\
 dtype (:class:`torch.dtype`, optional): the desired data type
 of returned tensor. If specified, the input tensor is
 casted to :attr:`dtype` before the operation is
 performed. Default: {default}.""",
-    mask="""\
+    "mask": """\
 mask (:class:`torch.Tensor`, optional): the boolean tensor
 containing the binary mask of validity of input tensor
 elements.
 Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
-)
+}

-definitions = dict(
-    softmax="""\
+definitions = {
+    "softmax": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
 defined as ``exp(x[i])/sum(exp(x))``.""",
-    log_softmax="""\
+    "log_softmax": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
 defined as ``log(exp(x[i])/sum(exp(x)))``.""",
-    softmin="""\
+    "softmin": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
 defined as ``exp(-x[i])/sum(exp(-x))``.""",
-    normalize="""\
+    "normalize": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
 defined as ``x[i]/max(norm(x, p), eps)``.""",
-    cumsum="""\
+    "cumsum": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
 defined as ``sum(x[:i])``.""",
-    cumprod="""\
+    "cumprod": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
 defined as ``prod(x[:i])``.""",
-)
+}

-reduction_names = dict(
-    sum="sum",
-    prod="product",
-    amax="maximum",
-    amin="minimum",
-    argmax="argmax",
-    argmin="argmin",
-    mean="mean",
-    median="median",
-    norm="norm",
-    var="variance",
-    std="standard_deviation",
-    logsumexp="logsumexp",
-)
+reduction_names = {
+    "sum": "sum",
+    "prod": "product",
+    "amax": "maximum",
+    "amin": "minimum",
+    "argmax": "argmax",
+    "argmin": "argmin",
+    "mean": "mean",
+    "median": "median",
+    "norm": "norm",
+    "var": "variance",
+    "std": "standard_deviation",
+    "logsumexp": "logsumexp",
+}

-normalization_names = dict(
-    softmax="softmax",
-    log_softmax="log_softmax",
-    softmin="softmin",
-    normalize="normalize",
-    cumsum="cumulative_sum",
-    cumprod="cumulative_prod",
-)
+normalization_names = {
+    "softmax": "softmax",
+    "log_softmax": "log_softmax",
+    "softmin": "softmin",
+    "normalize": "normalize",
+    "cumsum": "cumulative_sum",
+    "cumprod": "cumulative_prod",
+}

 operation_names = {}
 operation_names.update(reduction_names)

View File

@@ -47,15 +47,15 @@ class Adafactor(Optimizer):
             raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            beta2_decay=beta2_decay,
-            eps=eps,
-            d=d,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-        )
+        defaults = {
+            "lr": lr,
+            "beta2_decay": beta2_decay,
+            "eps": eps,
+            "d": d,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):

View File

@@ -50,16 +50,16 @@ class Adadelta(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            rho=rho,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            capturable=capturable,
-            foreach=foreach,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "rho": rho,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "capturable": capturable,
+            "foreach": foreach,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):

View File

@@ -54,17 +54,17 @@ class Adagrad(Optimizer):
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")
-        defaults = dict(
-            lr=lr,
-            lr_decay=lr_decay,
-            eps=eps,
-            weight_decay=weight_decay,
-            initial_accumulator_value=initial_accumulator_value,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "lr_decay": lr_decay,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "initial_accumulator_value": initial_accumulator_value,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         super().__init__(params, defaults)

         if fused:

View File

@@ -85,19 +85,19 @@ class Adam(Optimizer):
             if betas[1].numel() != 1:
                 raise ValueError("Tensor betas[1] must be 1-element")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsgrad=amsgrad,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-            fused=fused,
-            decoupled_weight_decay=decoupled_weight_decay,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "amsgrad": amsgrad,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+            "fused": fused,
+            "decoupled_weight_decay": decoupled_weight_decay,
+        }
         super().__init__(params, defaults)

         if fused:

View File

@@ -53,16 +53,16 @@ class Adamax(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):

View File

@@ -47,17 +47,17 @@ class ASGD(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            lambd=lambd,
-            alpha=alpha,
-            t0=t0,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "lambd": lambd,
+            "alpha": alpha,
+            "t0": t0,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):

View File

@@ -231,15 +231,15 @@ class LBFGS(Optimizer):
             raise ValueError(f"Invalid learning rate: {lr}")
         if max_eval is None:
             max_eval = max_iter * 5 // 4
-        defaults = dict(
-            lr=lr,
-            max_iter=max_iter,
-            max_eval=max_eval,
-            tolerance_grad=tolerance_grad,
-            tolerance_change=tolerance_change,
-            history_size=history_size,
-            line_search_fn=line_search_fn,
-        )
+        defaults = {
+            "lr": lr,
+            "max_iter": max_iter,
+            "max_eval": max_eval,
+            "tolerance_grad": tolerance_grad,
+            "tolerance_change": tolerance_change,
+            "history_size": history_size,
+            "line_search_fn": line_search_fn,
+        }
         super().__init__(params, defaults)

         if len(self.param_groups) != 1:

View File

@@ -59,18 +59,18 @@ class NAdam(Optimizer):  # noqa: D101
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= momentum_decay:
             raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            momentum_decay=momentum_decay,
-            decoupled_weight_decay=decoupled_weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "momentum_decay": momentum_decay,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105

View File

@@ -56,17 +56,17 @@ class RAdam(Optimizer):  # noqa: D101
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            decoupled_weight_decay=decoupled_weight_decay,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105

View File

@@ -55,18 +55,18 @@ class RMSprop(Optimizer):  # noqa: D101
         if not 0.0 <= alpha:
             raise ValueError(f"Invalid alpha value: {alpha}")
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            alpha=alpha,
-            eps=eps,
-            centered=centered,
-            weight_decay=weight_decay,
-            capturable=capturable,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "alpha": alpha,
+            "eps": eps,
+            "centered": centered,
+            "weight_decay": weight_decay,
+            "capturable": capturable,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105

View File

@@ -47,15 +47,15 @@ class Rprop(Optimizer):  # noqa: D101
         if not 0.0 < etas[0] < 1.0 < etas[1]:
             raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")
-        defaults = dict(
-            lr=lr,
-            etas=etas,
-            step_sizes=step_sizes,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "etas": etas,
+            "step_sizes": step_sizes,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105

View File

@@ -49,17 +49,17 @@ class SGD(Optimizer):  # noqa: D101
         if weight_decay < 0.0:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            maximize=maximize,
-            foreach=foreach,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "dampening": dampening,
+            "weight_decay": weight_decay,
+            "nesterov": nesterov,
+            "maximize": maximize,
+            "foreach": foreach,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)

View File

@@ -31,7 +31,12 @@ class SparseAdam(Optimizer):
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
-        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

         sparse_params = []

View File

@@ -1109,15 +1109,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
                 return res
         return None

-    sys_info = dict(
-        protocol_version=PROTOCOL_VERSION,
-        little_endian=sys.byteorder == "little",
-        type_sizes=dict(
-            short=SHORT_SIZE,
-            int=INT_SIZE,
-            long=LONG_SIZE,
-        ),
-    )
+    sys_info = {
+        "protocol_version": PROTOCOL_VERSION,
+        "little_endian": sys.byteorder == "little",
+        "type_sizes": {
+            "short": SHORT_SIZE,
+            "int": INT_SIZE,
+            "long": LONG_SIZE,
+        },
+    }
     pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
     pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)

View File

@@ -598,7 +598,10 @@ def as_sparse_gradcheck(gradcheck):
                     and obj.requires_grad
                     and obj.layout in sparse_layouts
                 ):
-                    d = dict(layout=obj.layout, shape=obj.shape)
+                    d = {
+                        "layout": obj.layout,
+                        "shape": obj.shape,
+                    }
                     if not masked:
                         # Materialize unspecified elements with zero values
                         batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()

View File

@@ -213,13 +213,14 @@ def get_model_info(
                     path_prefix = prefix
                 elif prefix != path_prefix:
                     raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
-                zip_files.append(dict(
-                    filename=zi.filename,
-                    compression=zi.compress_type,
-                    compressed_size=zi.compress_size,
-                    file_size=zi.file_size,
-                ))
+                zip_files.append(
+                    {
+                        "filename": zi.filename,
+                        "compression": zi.compress_type,
+                        "compressed_size": zi.compress_size,
+                        "file_size": zi.file_size,
+                    }
+                )

         assert path_prefix is not None
         version = zf.read(path_prefix + "/version").decode("utf-8").strip()
@@ -332,18 +333,20 @@ def get_model_info(
                     continue
                 extra_pickles[zi.filename] = contents

-    return {"model": dict(
-        title=title,
-        file_size=file_size,
-        version=version,
-        zip_files=zip_files,
-        interned_strings=list(interned_strings),
-        code_files=code_files,
-        model_data=model_data,
-        constants=constants,
-        extra_files_jsons=extra_files_jsons,
-        extra_pickles=extra_pickles,
-    )}
+    return {
+        "model": {
+            "title": title,
+            "file_size": file_size,
+            "version": version,
+            "zip_files": zip_files,
+            "interned_strings": list(interned_strings),
+            "code_files": code_files,
+            "model_data": model_data,
+            "constants": constants,
+            "extra_files_jsons": extra_files_jsons,
+            "extra_pickles": extra_pickles,
+        }
+    }

 def get_inline_skeleton():