[Pyrefly][Refactor] Replace dict() calls with literal dict syntax for improved readability (#157735)
I spotted 31 places that construct dictionaries via `dict(...)` calls.
This PR refactors dictionary construction by replacing `dict(...)` calls with literal `{...}` syntax where applicable.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/157735
Approved by: https://github.com/ezyang, https://github.com/Skylion007
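For illustration, here is a minimal before/after sketch of the transformation this PR applies (the names below are hypothetical, not taken from the changed files):

```python
# Before: a dict() call with keyword arguments; this is a global-name
# lookup plus a function call, and keys are limited to valid identifiers
options = dict(
    alpha=1.0,
    beta=0.0,
)

# After: literal syntax; keys are explicit strings, and CPython builds
# the mapping with a dedicated BUILD_MAP-style bytecode instead of a call
options = {
    "alpha": 1.0,
    "beta": 0.0,
}
```

Calls whose keys are computed at runtime, such as `dict(zip(extra_names, extra_inputs))`, are intentionally left as `dict(...)` since they cannot be written as literals.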
parent 0f31445139
commit 4f5be56612
@@ -211,12 +211,12 @@ inline void {{kernel_name}}(
     ) -> str:
         buffer_size = " * ".join(map(str, size_args))
         return KernelTemplate._template_from_string(self.ALLOCATE_WEIGHT_BUFFER).render(
-            dict(
-                buffer_name=buffer_name,
-                buffer_dtype=buffer_dtype,
-                buffer_size=buffer_size,
-                is_msvc_compiler=cpp_builder.is_msvc_cl(),
-            )
+            {
+                "buffer_name": buffer_name,
+                "buffer_dtype": buffer_dtype,
+                "buffer_size": buffer_size,
+                "is_msvc_compiler": cpp_builder.is_msvc_cl(),
+            }
         )

     def is_woq_int4(self):
@@ -1179,27 +1179,27 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):

         instance_definition, instance_type = self._define_gemm_instance(op, evt_name)

-        options = dict(
-            alpha=self.alpha,
-            beta=self.beta,
-            X=X,
-            W=W,
-            Y=Y,
-            kernel_call_signature=kernel_call_signature,
-            Bias=Bias,
-            epilogue_template=epilogue_template,
-            argument_template=argument_template,
-            should_swap_xw=should_swap_xw,
-            template=self,
-            kernel=kernel,
-            instance_definition=instance_definition,
-            instance_type=instance_type,
-            input_reorder=self.input_reorder,
-            epilogue_args=evt_args,
-            test_call_statement=test_call_statement,
-            op_conf_name=op.configuration_name(),
-            epilogue_visitor_tree=evt_code,
-        )
+        options = {
+            "alpha": self.alpha,
+            "beta": self.beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "kernel_call_signature": kernel_call_signature,
+            "Bias": Bias,
+            "epilogue_template": epilogue_template,
+            "argument_template": argument_template,
+            "should_swap_xw": should_swap_xw,
+            "template": self,
+            "kernel": kernel,
+            "instance_definition": instance_definition,
+            "instance_type": instance_type,
+            "input_reorder": self.input_reorder,
+            "epilogue_args": evt_args,
+            "test_call_statement": test_call_statement,
+            "op_conf_name": op.configuration_name(),
+            "epilogue_visitor_tree": evt_code,
+        }
         options.update(dict(zip(extra_names, extra_inputs)))
         res = self._template_from_string(self._get_template()).render(**options)
         if inductor_cuda_config.generate_test_runner and not is_dynamic(X, W, Y, Bias):
@@ -1587,19 +1587,19 @@ class CUTLASS3xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }
         assert epilogue_template is not None

         if should_swap_xw:
@@ -1878,21 +1878,21 @@ class CUTLASS2xGemmTemplate(CUTLASSGemmTemplate):
         tensors. This operation also implies the M and N dimensions of Bias and GEMM output to be swapped
         before the function call.
         """
-        options = dict(
-            instance_type=instance_type,
-            alpha=alpha,
-            beta=beta,
-            X=X,
-            W=W,
-            Y=Y,
-            Bias=Bias,
-            Meta=Meta,
-            template=self,
-            kernel=kernel,
-            M="M",
-            N="N",
-            epilogue_args=epilogue_args,
-        )
+        options = {
+            "instance_type": instance_type,
+            "alpha": alpha,
+            "beta": beta,
+            "X": X,
+            "W": W,
+            "Y": Y,
+            "Bias": Bias,
+            "Meta": Meta,
+            "template": self,
+            "kernel": kernel,
+            "M": "M",
+            "N": "N",
+            "epilogue_args": epilogue_args,
+        }

         if epilogue_template is None:
             arguments = self._template_from_string(argument_template).render(
@@ -86,12 +86,12 @@ def tma_options() -> dict[str, Any]:


 def persistent_mm_options(mat1, mat2):
-    res = dict(
-        A_ROW_MAJOR=not mat1.layout.is_transposed(),
-        B_ROW_MAJOR=not mat2.layout.is_transposed(),
-        NUM_SMS=get_num_sms(),
-        TMA_SIZE=TMA_DESCRIPTOR_SIZE,
-    )
+    res = {
+        "A_ROW_MAJOR": not mat1.layout.is_transposed(),
+        "B_ROW_MAJOR": not mat2.layout.is_transposed(),
+        "NUM_SMS": get_num_sms(),
+        "TMA_SIZE": TMA_DESCRIPTOR_SIZE,
+    }
     res.update(tma_options())
     return res

@@ -147,12 +147,12 @@ def grouped_gemm_lowering(
     choices: list[ChoiceCaller] = []
     *_, layout, x, _ = mm_args(x, permute(w[0], [1, 0]), layout=layout)

-    kwargs = dict(
-        has_bias=[bias is not None for bias in b],
-        trans_w=True,
-        epilogue_creator=None,
-        act_mapping=dict.fromkeys(range(num_gemm), x),
-    )
+    kwargs = {
+        "has_bias": [bias is not None for bias in b],
+        "trans_w": True,
+        "epilogue_creator": None,
+        "act_mapping": dict.fromkeys(range(num_gemm), x),
+    }

     input_nodes = [x, *w]
     input_nodes.extend([bias for bias in b if bias is not None])
@@ -353,11 +353,13 @@ def register_onednn_fusion_ops():
                     buf, attr, scalars=scalars, algorithm=algorithm
                 )

-            kwargs = dict(
-                has_bias=b is not None,
-                trans_w=True,
-                epilogue_creator=None if attr == "none" else epilogue_creator,
-            )
+            kwargs = {
+                "has_bias": b is not None,
+                "trans_w": True,
+                "epilogue_creator": (
+                    None if attr == "none" else epilogue_creator
+                ),
+            }
             if b is not None:
                 kwargs["input_indices"] = [2, 0, 1]  # type: ignore[assignment]
             CppGemmTemplate.add_choices(
@@ -416,11 +418,12 @@ def register_onednn_fusion_ops():
             def epilogue_creator(buf):
                 return create_epilogue_with_attr(buf, attr, other=y)

-            kwargs = dict(
-                has_bias=b is not None,
-                trans_w=True,
-                epilogue_creator=epilogue_creator,
-            )
+            kwargs = {
+                "has_bias": b is not None,
+                "trans_w": True,
+                "epilogue_creator": epilogue_creator,
+            }

             kwargs["input_indices"] = [0, 2, 1] if b is None else [3, 0, 2, 1]
             CppGemmTemplate.add_choices(
                 choices,
@@ -166,123 +166,123 @@ Example::
 """,
 )

-args_and_kwargs = dict(
+args_and_kwargs = {
     # argument name sufficies separated by double underscore will
     # be removed in the final documentation string.
-    sum=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    prod=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    cumsum=(("dim__as_int",), ("dtype=None", "mask=None")),
-    cumprod=(("dim__as_int",), ("dtype=None", "mask=None")),
-    amin=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    amax=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    argmin=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    argmax=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    mean=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    median=(("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
-    norm=(
+    "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
+    "norm": (
         (
             "ord",
             "dim",
         ),
         ("keepdim=False", "dtype=None", "mask=None"),
     ),
-    var=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-    std=(("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
-    logsumexp=(("dim",), ("keepdim=False", "dtype=None", "mask=None")),
-    softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-    log_softmax=(("dim__as_int",), ("dtype=None", "mask=None")),
-    softmin=(("dim__as_int",), ("dtype=None", "mask=None")),
-    normalize=(
+    "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+    "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
+    "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
+    "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
+    "normalize": (
         (
             "ord__required",
             "dim__as_int",
         ),
         ("eps=1e-12", "dtype=None", "mask=None"),
     ),
-)
+}

-argument_declarations = dict(
-    dim="""\
+argument_declarations = {
+    "dim": """\
 dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
   Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
-    dim__as_int="""\
+    "dim__as_int": """\
 dim (int): the dimension along which {operation name} is computed.""",
-    ord="""\
+    "ord": """\
 ord (int, float, optional): the order of vector norm. Default: 2.
   See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-    ord__required="""\
+    "ord__required": """\
 ord (int, float): the order of vector norm. Default: 2.
   See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
-    unbiased="""\
+    "unbiased": """\
 unbiased (bool): when True, use Bessel's correction, otherwise, compute
   the uncorrected sample variance.""",
-    eps="""\
+    "eps": """\
 eps (float, optional): small value to avoid division by zero. Default: {default}.""",
-    keepdim="""\
+    "keepdim": """\
 keepdim (bool, optional): whether the output tensor has
   :attr:`dim` retained or not. Default: {default}.""",
-    dtype="""\
+    "dtype": """\
 dtype (:class:`torch.dtype`, optional): the desired data type
   of returned tensor. If specified, the input tensor is
   casted to :attr:`dtype` before the operation is
   performed. Default: {default}.""",
-    mask="""\
+    "mask": """\
 mask (:class:`torch.Tensor`, optional): the boolean tensor
   containing the binary mask of validity of input tensor
   elements.
   Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
-)
+}

-definitions = dict(
-    softmax="""\
+definitions = {
+    "softmax": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
 defined as ``exp(x[i])/sum(exp(x))``.""",
-    log_softmax="""\
+    "log_softmax": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
 defined as ``log(exp(x[i])/sum(exp(x)))``.""",
-    softmin="""\
+    "softmin": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
 defined as ``exp(-x[i])/sum(exp(-x))``.""",
-    normalize="""\
+    "normalize": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
 defined as ``x[i]/max(norm(x, p), eps)``.""",
-    cumsum="""\
+    "cumsum": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
 defined as ``sum(x[:i])``.""",
-    cumprod="""\
+    "cumprod": """\
 Let ``x`` be a sequence of unmasked elements of one-dimensional slice
 of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
 defined as ``prod(x[:i])``.""",
-)
+}

-reduction_names = dict(
-    sum="sum",
-    prod="product",
-    amax="maximum",
-    amin="minimum",
-    argmax="argmax",
-    argmin="argmin",
-    mean="mean",
-    median="median",
-    norm="norm",
-    var="variance",
-    std="standard_deviation",
-    logsumexp="logsumexp",
-)
+reduction_names = {
+    "sum": "sum",
+    "prod": "product",
+    "amax": "maximum",
+    "amin": "minimum",
+    "argmax": "argmax",
+    "argmin": "argmin",
+    "mean": "mean",
+    "median": "median",
+    "norm": "norm",
+    "var": "variance",
+    "std": "standard_deviation",
+    "logsumexp": "logsumexp",
+}

-normalization_names = dict(
-    softmax="softmax",
-    log_softmax="log_softmax",
-    softmin="softmin",
-    normalize="normalize",
-    cumsum="cumulative_sum",
-    cumprod="cumulative_prod",
-)
+normalization_names = {
+    "softmax": "softmax",
+    "log_softmax": "log_softmax",
+    "softmin": "softmin",
+    "normalize": "normalize",
+    "cumsum": "cumulative_sum",
+    "cumprod": "cumulative_prod",
+}

 operation_names = {}
 operation_names.update(reduction_names)
@@ -47,15 +47,15 @@ class Adafactor(Optimizer):
             raise ValueError(f"Clipping threshold d should be >= 1 but is: {d}")
         if not 0.0 <= weight_decay:
             raise ValueError(f"weight_decay should be >= 0 but is: {weight_decay}")
-        defaults = dict(
-            lr=lr,
-            beta2_decay=beta2_decay,
-            eps=eps,
-            d=d,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-        )
+        defaults = {
+            "lr": lr,
+            "beta2_decay": beta2_decay,
+            "eps": eps,
+            "d": d,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -50,16 +50,16 @@ class Adadelta(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            rho=rho,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            capturable=capturable,
-            foreach=foreach,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "rho": rho,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "capturable": capturable,
+            "foreach": foreach,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -54,17 +54,17 @@ class Adagrad(Optimizer):
         if not 0.0 <= eps:
             raise ValueError(f"Invalid epsilon value: {eps}")

-        defaults = dict(
-            lr=lr,
-            lr_decay=lr_decay,
-            eps=eps,
-            weight_decay=weight_decay,
-            initial_accumulator_value=initial_accumulator_value,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "lr_decay": lr_decay,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "initial_accumulator_value": initial_accumulator_value,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         super().__init__(params, defaults)

         if fused:
@@ -85,19 +85,19 @@ class Adam(Optimizer):
             if betas[1].numel() != 1:
                 raise ValueError("Tensor betas[1] must be 1-element")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            amsgrad=amsgrad,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-            fused=fused,
-            decoupled_weight_decay=decoupled_weight_decay,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "amsgrad": amsgrad,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+            "fused": fused,
+            "decoupled_weight_decay": decoupled_weight_decay,
+        }
         super().__init__(params, defaults)

         if fused:
@@ -53,16 +53,16 @@ class Adamax(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -47,17 +47,17 @@ class ASGD(Optimizer):
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            lambd=lambd,
-            alpha=alpha,
-            t0=t0,
-            weight_decay=weight_decay,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "lambd": lambd,
+            "alpha": alpha,
+            "t0": t0,
+            "weight_decay": weight_decay,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):
@@ -231,15 +231,15 @@ class LBFGS(Optimizer):
             raise ValueError(f"Invalid learning rate: {lr}")
         if max_eval is None:
             max_eval = max_iter * 5 // 4
-        defaults = dict(
-            lr=lr,
-            max_iter=max_iter,
-            max_eval=max_eval,
-            tolerance_grad=tolerance_grad,
-            tolerance_change=tolerance_change,
-            history_size=history_size,
-            line_search_fn=line_search_fn,
-        )
+        defaults = {
+            "lr": lr,
+            "max_iter": max_iter,
+            "max_eval": max_eval,
+            "tolerance_grad": tolerance_grad,
+            "tolerance_change": tolerance_change,
+            "history_size": history_size,
+            "line_search_fn": line_search_fn,
+        }
         super().__init__(params, defaults)

         if len(self.param_groups) != 1:
@@ -59,18 +59,18 @@ class NAdam(Optimizer):  # noqa: D101
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")
         if not 0.0 <= momentum_decay:
             raise ValueError(f"Invalid momentum_decay value: {momentum_decay}")
-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            momentum_decay=momentum_decay,
-            decoupled_weight_decay=decoupled_weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "momentum_decay": momentum_decay,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -56,17 +56,17 @@ class RAdam(Optimizer):  # noqa: D101
         if not 0.0 <= weight_decay:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            betas=betas,
-            eps=eps,
-            weight_decay=weight_decay,
-            maximize=maximize,
-            foreach=foreach,
-            capturable=capturable,
-            decoupled_weight_decay=decoupled_weight_decay,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "weight_decay": weight_decay,
+            "maximize": maximize,
+            "foreach": foreach,
+            "capturable": capturable,
+            "decoupled_weight_decay": decoupled_weight_decay,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -55,18 +55,18 @@ class RMSprop(Optimizer):  # noqa: D101
         if not 0.0 <= alpha:
             raise ValueError(f"Invalid alpha value: {alpha}")

-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            alpha=alpha,
-            eps=eps,
-            centered=centered,
-            weight_decay=weight_decay,
-            capturable=capturable,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "alpha": alpha,
+            "eps": eps,
+            "centered": centered,
+            "weight_decay": weight_decay,
+            "capturable": capturable,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -47,15 +47,15 @@ class Rprop(Optimizer):  # noqa: D101
         if not 0.0 < etas[0] < 1.0 < etas[1]:
             raise ValueError(f"Invalid eta values: {etas[0]}, {etas[1]}")

-        defaults = dict(
-            lr=lr,
-            etas=etas,
-            step_sizes=step_sizes,
-            foreach=foreach,
-            maximize=maximize,
-            differentiable=differentiable,
-            capturable=capturable,
-        )
+        defaults = {
+            "lr": lr,
+            "etas": etas,
+            "step_sizes": step_sizes,
+            "foreach": foreach,
+            "maximize": maximize,
+            "differentiable": differentiable,
+            "capturable": capturable,
+        }
         super().__init__(params, defaults)

     def __setstate__(self, state):  # noqa: D105
@@ -49,17 +49,17 @@ class SGD(Optimizer):  # noqa: D101
         if weight_decay < 0.0:
             raise ValueError(f"Invalid weight_decay value: {weight_decay}")

-        defaults = dict(
-            lr=lr,
-            momentum=momentum,
-            dampening=dampening,
-            weight_decay=weight_decay,
-            nesterov=nesterov,
-            maximize=maximize,
-            foreach=foreach,
-            differentiable=differentiable,
-            fused=fused,
-        )
+        defaults = {
+            "lr": lr,
+            "momentum": momentum,
+            "dampening": dampening,
+            "weight_decay": weight_decay,
+            "nesterov": nesterov,
+            "maximize": maximize,
+            "foreach": foreach,
+            "differentiable": differentiable,
+            "fused": fused,
+        }
         if nesterov and (momentum <= 0 or dampening != 0):
             raise ValueError("Nesterov momentum requires a momentum and zero dampening")
         super().__init__(params, defaults)
@@ -31,7 +31,12 @@ class SparseAdam(Optimizer):
         if not 0.0 <= betas[1] < 1.0:
             raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")

-        defaults = dict(lr=lr, betas=betas, eps=eps, maximize=maximize)
+        defaults = {
+            "lr": lr,
+            "betas": betas,
+            "eps": eps,
+            "maximize": maximize,
+        }
         super().__init__(params, defaults)

         sparse_params = []
@@ -1109,15 +1109,15 @@ def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
             return res
         return None

-    sys_info = dict(
-        protocol_version=PROTOCOL_VERSION,
-        little_endian=sys.byteorder == "little",
-        type_sizes=dict(
-            short=SHORT_SIZE,
-            int=INT_SIZE,
-            long=LONG_SIZE,
-        ),
-    )
+    sys_info = {
+        "protocol_version": PROTOCOL_VERSION,
+        "little_endian": sys.byteorder == "little",
+        "type_sizes": {
+            "short": SHORT_SIZE,
+            "int": INT_SIZE,
+            "long": LONG_SIZE,
+        },
+    }

     pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)
     pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)
@@ -598,7 +598,10 @@ def as_sparse_gradcheck(gradcheck):
                 and obj.requires_grad
                 and obj.layout in sparse_layouts
             ):
-                d = dict(layout=obj.layout, shape=obj.shape)
+                d = {
+                    "layout": obj.layout,
+                    "shape": obj.shape,
+                }
                 if not masked:
                     # Materialize unspecified elements with zero values
                     batch_dim = obj.ndim - obj.dense_dim() - obj.sparse_dim()
@@ -213,13 +213,14 @@ def get_model_info(
                 path_prefix = prefix
             elif prefix != path_prefix:
                 raise Exception(f"Mismatched prefixes: {path_prefix} != {prefix}")  # noqa: TRY002
-            zip_files.append(dict(
-                filename=zi.filename,
-                compression=zi.compress_type,
-                compressed_size=zi.compress_size,
-                file_size=zi.file_size,
-            ))
+            zip_files.append(
+                {
+                    "filename": zi.filename,
+                    "compression": zi.compress_type,
+                    "compressed_size": zi.compress_size,
+                    "file_size": zi.file_size,
+                }
+            )

         assert path_prefix is not None
         version = zf.read(path_prefix + "/version").decode("utf-8").strip()
@@ -332,18 +333,20 @@ def get_model_info(
                 continue
             extra_pickles[zi.filename] = contents

-    return {"model": dict(
-        title=title,
-        file_size=file_size,
-        version=version,
-        zip_files=zip_files,
-        interned_strings=list(interned_strings),
-        code_files=code_files,
-        model_data=model_data,
-        constants=constants,
-        extra_files_jsons=extra_files_jsons,
-        extra_pickles=extra_pickles,
-    )}
+    return {
+        "model": {
+            "title": title,
+            "file_size": file_size,
+            "version": version,
+            "zip_files": zip_files,
+            "interned_strings": list(interned_strings),
+            "code_files": code_files,
+            "model_data": model_data,
+            "constants": constants,
+            "extra_files_jsons": extra_files_jsons,
+            "extra_pickles": extra_pickles,
+        }
+    }


 def get_inline_skeleton():