Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00

Fix pyrefly ignores 1/n (#166239)

First diff adjusting the syntax of the pyrefly: ignore suppressions so that each one hides only a single class of type error.
Test: lintrunner, pyrefly check
Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen

parent 621ba05107
commit c7eee49525
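For context, the hunks below all apply the same mechanical change: the older free-form suppression comment, which the commit message says hides more than the intended error, is rewritten into the bracketed form so that only the named error class is silenced. A minimal sketch of the two styles, using the bad-return code that appears in several hunks below (any other pyrefly error code would work the same way):

    # before: the trailing "# bad-return" is only prose, so the ignore is effectively unscoped
    return NotImplemented  # pyrefly: ignore # bad-return

    # after: the bracketed code scopes the suppression to that single error class
    return NotImplemented  # pyrefly: ignore [bad-return]

Some hunks also add new bracketed suppressions (for example [not-iterable] or [bad-argument-type, bad-argument-count]) where pyrefly flags code that was previously only suppressed for mypy.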
@@ -130,6 +130,7 @@ errors.bad-param-name-override = false
# Mypy doesn't require that imports are explicitly imported, so be compatible with that.
# Might be a good idea to turn this on in future.
errors.implicit-import = false
errors.deprecated = false # re-enable after we've fix import formatting
permissive-ignores = true
replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
search-path = ["tools/experimental"]
@@ -2,7 +2,8 @@ from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional

from dataclasses_json import DataClassJsonMixin
# pyrefly: ignore [missing-import]
from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found]


_DATA_MODEL_VERSION = 1.5

@@ -17,7 +18,7 @@ class UtilizationStats:


@dataclass
class UtilizationMetadata(DataClassJsonMixin):
class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
workflow_id: str
job_id: str

@@ -33,7 +34,7 @@ class UtilizationMetadata(DataClassJsonMixin):


@dataclass
class GpuUsage(DataClassJsonMixin):
class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
uuid: Optional[str] = None
util_percent: Optional[UtilizationStats] = None
mem_util_percent: Optional[UtilizationStats] = None

@@ -43,14 +44,14 @@ class GpuUsage(DataClassJsonMixin):


@dataclass
class RecordData(DataClassJsonMixin):
class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
cpu: Optional[UtilizationStats] = None
memory: Optional[UtilizationStats] = None
gpu_usage: Optional[list[GpuUsage]] = None


@dataclass
class UtilizationRecord(DataClassJsonMixin):
class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
timestamp: int
data: Optional[RecordData] = None

@@ -63,7 +64,7 @@ class UtilizationRecord(DataClassJsonMixin):
# the db schema related to this is:
# https://github.com/pytorch/test-infra/blob/main/clickhouse_db_schema/oss_ci_utilization/oss_ci_utilization_metadata_schema.sql
@dataclass
class OssCiSegmentV1(DataClassJsonMixin):
class OssCiSegmentV1(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
level: str
name: str
start_at: int
@@ -1703,7 +1703,7 @@ def _check(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]


# TODO add deprecation annotation

@@ -1753,7 +1753,7 @@ def _check_index(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]


def _check_value(cond, message=None): # noqa: F811

@@ -1771,7 +1771,7 @@ def _check_value(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]


def _check_type(cond, message=None): # noqa: F811

@@ -1789,7 +1789,7 @@ def _check_type(cond, message=None): # noqa: F811
an object that has a ``__str__()`` method to be used as the error
message. Default: ``None``
"""
_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
_check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]


def _check_not_implemented(cond, message=None): # noqa: F811
@@ -101,7 +101,7 @@ def custom_op(
lib, ns, function_schema, name, ophandle, _private_access=True
)

result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment]
result.__module__ = func.__module__
result.__doc__ = func.__doc__

@@ -154,7 +154,7 @@ def make_crossref_functionalize(
maybe_detach, (f_args, f_kwargs)
)
with fake_mode:
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec]
r = op._op_dk(final_key, *args, **kwargs)

def desc():
@@ -1029,7 +1029,7 @@ class BuiltinVariable(VariableTracker):

def call_self_handler(tx: "InstructionTranslator", args, kwargs):
try:
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
result = self_handler(tx, *args, **kwargs)
if result is not None:
return result

@@ -1037,7 +1037,7 @@ class BuiltinVariable(VariableTracker):
# Check if binding is bad. inspect signature bind is expensive.
# So check only when handler call fails.
try:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
inspect.signature(self_handler).bind(tx, *args, **kwargs)
except TypeError as e:
has_constant_handler = obj.has_constant_handler(args, kwargs)

@@ -1090,7 +1090,7 @@ class BuiltinVariable(VariableTracker):
hints=[*graph_break_hints.DYNAMO_BUG],
from_exc=exc,
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, res)

else:

@@ -1119,7 +1119,7 @@ class BuiltinVariable(VariableTracker):
tx,
args=list(map(ConstantVariable.create, exc.args)),
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, res)

handlers.append(constant_fold_handler)

@@ -1442,7 +1442,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in dict_methods:
if isinstance(args[0], variables.UserDefinedDictVariable):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.ConstDictVariable):
return args[0].call_method(tx, name, args[1:], kwargs)

@@ -1451,7 +1451,7 @@ class BuiltinVariable(VariableTracker):
resolved_fn = getattr(self.fn, name)
if resolved_fn in set_methods:
if isinstance(args[0], variables.UserDefinedSetVariable):
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
elif isinstance(args[0], variables.SetVariable):
return args[0].call_method(tx, name, args[1:], kwargs)

@@ -1540,12 +1540,12 @@ class BuiltinVariable(VariableTracker):
if type(arg.value).__str__ is object.__str__:
# Rely on the object str method
try:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return variables.ConstantVariable.create(value=str_method())
except AttributeError:
# Graph break
return
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
elif is_wrapper_or_member_descriptor(str_method):
unimplemented_v2(
gb_type="Attempted to a str() method implemented in C/C++",

@@ -1662,10 +1662,10 @@ class BuiltinVariable(VariableTracker):
else:
raw_b = b.raw_value
if self.fn is max:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
raw_res = max(a.raw_value, raw_b)
else:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
raw_res = min(a.raw_value, raw_b)

need_unwrap = any(

@@ -1980,12 +1980,16 @@ class BuiltinVariable(VariableTracker):
if isinstance(arg, dict):
arg = [ConstantVariable.create(k) for k in arg.keys()]
return DictVariableType(
dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew()
# pyrefly: ignore [bad-argument-type]
dict.fromkeys(arg, value),
user_cls,
mutation_type=ValueMutationNew(),
)
elif arg.has_force_unpack_var_sequence(tx):
keys = arg.force_unpack_var_sequence(tx)
if all(is_hashable(v) for v in keys):
return DictVariableType(
# pyrefly: ignore [bad-argument-type]
dict.fromkeys(keys, value),
user_cls,
mutation_type=ValueMutationNew(),

@@ -2152,7 +2156,7 @@ class BuiltinVariable(VariableTracker):
)

if isinstance(arg, variables.UserDefinedExceptionClassVariable):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return ConstantVariable.create(isinstance(arg_type, isinstance_type))

isinstance_type_tuple: tuple[type, ...]

@@ -2185,10 +2189,10 @@ class BuiltinVariable(VariableTracker):
# through it. This is a limitation of the current implementation.
# Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
# might not be a big issue and we trade off it for performance.
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
val = issubclass(arg_type, isinstance_type_tuple)
except TypeError:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
val = arg_type in isinstance_type_tuple
return variables.ConstantVariable.create(val)

@@ -2210,7 +2214,7 @@ class BuiltinVariable(VariableTracker):

# WARNING: This might run arbitrary user code `__subclasscheck__`.
# See the comment in call_isinstance above.
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))

def call_super(self, tx: "InstructionTranslator", a, b):

@@ -2256,9 +2260,9 @@ class BuiltinVariable(VariableTracker):
value = getattr(self.fn, name)
except AttributeError:
raise_observed_exception(AttributeError, tx)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if not callable(value):
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return VariableTracker.build(tx, value, source)
return variables.GetAttrVariable(self, name, source=source)

@@ -34,6 +34,7 @@ class LazyCache:
self.vt = builder.VariableBuilder(tx, self.source)(self.value)

if self.name_hint is not None:
# pyrefly: ignore [missing-attribute]
self.vt.set_name_hint(self.name_hint)

del self.value
@@ -1138,11 +1138,13 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:

for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
if isinstance(m, FakeTensorMode):
# pyrefly: ignore [bad-argument-type]
fake_modes.append((m, "active fake mode", i))

flat_inputs = pytree.tree_leaves(inputs)
for i, flat_input in enumerate(flat_inputs):
if isinstance(flat_input, FakeTensor):
# pyrefly: ignore [bad-argument-type]
fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
if is_traceable_wrapper_subclass(flat_input):
out: list[Union[torch.Tensor, int, torch.SymInt]] = []

@@ -1151,6 +1153,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
x for x in out if isinstance(x, FakeTensor)
]
fake_modes.extend(
# pyrefly: ignore [bad-argument-type]
[
(tensor.fake_mode, f"subclass input {i}", ix)
for ix, tensor in enumerate(fake_tensors)

@@ -1162,9 +1165,12 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
for m, desc2, i2 in fake_modes[1:]:
assert fake_mode is m, (
f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
# pyrefly: ignore [missing-attribute]
f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
# pyrefly: ignore [missing-attribute]
f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
)
# pyrefly: ignore [bad-return]
return fake_mode
else:
return None
@@ -114,6 +114,7 @@ def _cancel_all_tasks(
for task in to_cancel:
task.cancel()

# pyrefly: ignore [bad-argument-type]
loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

for task in to_cancel:

@@ -149,7 +150,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[
task_factory = task_factories[0]
if task_factory is None:
if sys.version_info >= (3, 11):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
task = asyncio.Task(coro, loop=loop, context=context)
else:
task = asyncio.Task(coro, loop=loop)
@@ -590,11 +590,11 @@ class CKGemmTemplate(CKTemplate):
arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
else: # tile shape
arg = f"/* {field_name} */ S<{tuple_elements}>"
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
template_params.append(arg)
else:
if field_value is not None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
template_params.append(f"/* {field_name} */ {field_value}")
operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
return self._template_from_string(template_definition).render(

@@ -939,6 +939,7 @@ class CKGemmTemplate(CKTemplate):
for o in rops:
kBatches = self._get_kBatch(o)
for kBatch in kBatches:
# pyrefly: ignore [bad-argument-type]
ops.append(InductorROCmOp(op=o, kBatch=kBatch))

filtered_instances = list(filter(lambda op: self.filter_op(op), ops))
@@ -273,7 +273,7 @@ def record_original_output_strides(gm: GraphModule) -> None:
):
output_strides.append(val.stride())
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
output_strides.append(None)
output_node.meta["original_output_strides"] = output_strides

@@ -1110,6 +1110,7 @@ def _compile_fx_inner(
)
log.info("-" * 130)
for row in mm_table_data:
# pyrefly: ignore [not-iterable]
log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
log.info("-" * 130)

@@ -1551,7 +1552,7 @@ class _InProcessFxCompile(FxCompile):
node_runtimes = None
if inductor_metrics_log.isEnabledFor(logging.INFO):
num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
metrics.num_bytes_accessed += num_bytes
metrics.node_runtimes += node_runtimes
metrics.nodes_num_elem += nodes_num_elem

@@ -1595,10 +1596,10 @@ class _InProcessFxCompile(FxCompile):
disable = f"{disable} Found from {stack_trace}\n"
else:
disable = f"{disable}\n"
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = disable

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if cudagraphs and not V.graph.disable_cudagraphs_reason:
maybe_incompat_node = get_first_incompatible_cudagraph_node(gm)
if maybe_incompat_node:

@@ -1607,29 +1608,29 @@ class _InProcessFxCompile(FxCompile):
"stack_trace", None
):
disable = f"{disable} Found from {stack_trace}\n"
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = disable

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if V.aot_compilation:
assert isinstance(
compiled_fn,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
(str, list, torch.fx.GraphModule),
), type(compiled_fn)
return CompiledAOTI(compiled_fn)

# TODO: Hoist this above V.aot_compilation
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
if cudagraphs and not V.graph.disable_cudagraphs_reason:
from torch._inductor.cudagraph_utils import (
check_lowering_disable_cudagraph,
)

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason = (
check_lowering_disable_cudagraph(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.device_node_mapping
)
)

@@ -1637,29 +1638,29 @@ class _InProcessFxCompile(FxCompile):
self._compile_stats[type(self)].codegen_and_compile += 1

if (
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._inductor.debug.RECORD_GRAPH_EXECUTION
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
and torch._inductor.debug.GRAPH_COMPILE_IDS is not None
):
compile_id = str(
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._guards.CompileContext.current_compile_id()
)
graph_id = graph_kwargs.get("graph_id")
if graph_id is not None:
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = (
compile_id
)

return CompiledFxGraph(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
compiled_fn,
graph,
gm,
output_strides,
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
V.graph.disable_cudagraphs_reason,
metrics_helper.get_deltas(),
counters["inductor"] - inductor_counters,

@@ -1701,18 +1702,18 @@ def fx_codegen_and_compile(
from .compile_fx_async import _AsyncFxCompile
from .compile_fx_ext import _OutOfProcessFxCompile

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert isinstance(scheme, _OutOfProcessFxCompile), (
"async is only valid with an out-of-process compile mode"
)
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
scheme = _AsyncFxCompile(scheme)

if fx_compile_progressive:
from .compile_fx_async import _ProgressiveFxCompile
from .compile_fx_ext import _OutOfProcessFxCompile

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
assert isinstance(scheme, _OutOfProcessFxCompile), (
"progressive is only valid with an out-of-process compile mode"
)

@@ -1722,10 +1723,10 @@ def fx_codegen_and_compile(
# Use in-process compile for the fast version
fast_scheme = _InProcessFxCompile()

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)

# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)


@@ -1835,7 +1836,7 @@ def cudagraphify_impl(
Assumes inputs[static_input_idxs[i]] are always the same memory address
"""
check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type]
# pyrefly: ignore # annotation-mismatch
# pyrefly: ignore [annotation-mismatch]
static_input_idxs: OrderedSet[int] = OrderedSet(
remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type]
)

@@ -1902,7 +1903,7 @@ def cudagraphify_impl(
index_expanded_dims_and_copy_(dst, src, expanded_dims)
new_inputs.clear()
graph.replay()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return static_outputs

else:

@@ -1918,7 +1919,7 @@ def cudagraphify_impl(
index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
new_inputs.clear()
graph.replay()
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return static_outputs

return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())

@@ -1935,7 +1936,7 @@ def compile_fx_aot(
# [See NOTE] Unwrapping subclasses AOT
unwrap_tensor_subclass_parameters(model_)

# pyrefly: ignore # annotation-mismatch
# pyrefly: ignore [annotation-mismatch]
config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})

if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):

@@ -2878,7 +2879,7 @@ def _aoti_flatten_inputs(
Flatten the inputs to the graph module and return the flat inputs and options.
Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
"""
# pyrefly: ignore # missing-module-attribute
# pyrefly: ignore [missing-module-attribute]
from .compile_fx import graph_returns_tuple

assert graph_returns_tuple(gm), (
@@ -291,7 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
log.debug("example value absent for node: %s", input)
return
ndim = input.meta["example_value"].ndim
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if dim < 0: # Normalize unbind dim
dim += ndim
with graph.inserting_after(node):

@@ -341,7 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
)

# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if cat_dim < 0: # Normalize cat dim
cat_dim += ndim

@@ -949,7 +949,7 @@ class SplitCatSimplifier:
if isinstance(user_input, tuple):
# Find the correct new getitem (present in split_items)
new_user_inputs.append(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
split_items[
split_ranges.index(
(

@@ -1000,7 +1000,7 @@ class SplitCatSimplifier:
for user_input_new, transform_param in zip(
user_inputs_new, transform_params
):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not is_node_meta_valid(user_input_new):
log.debug("example value absent for node: %s", user_input_new)
return

@@ -1015,7 +1015,7 @@ class SplitCatSimplifier:
stack_dim is None or stack_dim == unsqueeze_params[0]
):
to_stack.append(user_input_new)
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
to_stack_meta.append(user_input_new.meta["example_value"])
stack_dim = unsqueeze_params[0]
continue

@@ -1036,12 +1036,12 @@ class SplitCatSimplifier:
if unsqueeze_params:
to_stack.append(user_input_new)
stack_dim = unsqueeze_params[0]
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
to_stack_meta.append(user_input_new.meta["example_value"])
continue

if unflatten_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.unflatten, args=(user_input_new, *unflatten_params)

@@ -1051,7 +1051,7 @@ class SplitCatSimplifier:
*unflatten_params, # type: ignore[arg-type]
)
if movedim_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.movedim, args=(user_input_new, *movedim_params)

@@ -1061,7 +1061,7 @@ class SplitCatSimplifier:
*movedim_params, # type: ignore[arg-type]
)
if flatten_params:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.flatten, args=(user_input_new, *flatten_params)

@@ -1072,7 +1072,7 @@ class SplitCatSimplifier:
)
user_inputs_new_transformed.append(user_input_new)
user_inputs_new_transformed_meta.append(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
user_input_new.meta["example_value"]
)
if to_stack:

@@ -1432,7 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
if not isinstance(split_sections, (list, tuple)): # Unnormalized split
return
split_node = next(node for node in match.nodes if node.target == torch.split)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
SplitCatSimplifier().simplify(match.graph, split_node, split_sections)


@@ -1501,7 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -
for i in range(len(split_node.args[1])): # type: ignore[arg-type]
if i in indices:
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# pyrefly: ignore # bad-return
# pyrefly: ignore [bad-return]
return fused_tensor_size


@@ -1978,7 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs):

assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)

# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
if cat_dim < 0: # Normalize cat dim
cat_dim += ndim

@@ -2512,7 +2512,8 @@ def reshape_cat_node_to_stack(
args=(cat_node, tuple(reshape_list)),
)
reshape_node.meta["example_value"] = torch.reshape(
cat_node.meta["example_value"], tuple(reshape_list)
cat_node.meta["example_value"],
tuple(reshape_list), # pyrefly: ignore [bad-argument-type]
)
permute_list = list(range(len(stack_shape)))
permute_list[stack_dim], permute_list[split_or_unbind_dim] = (

@@ -3044,6 +3045,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
einsum_node = match.nodes[0]
input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
if should_replace_einsum(einsum_node):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
match.replace_by_example(repl, [input, weights])
counters[backend]["einsum_to_pointwise_pass"] += 1
@@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:

# If the module is actually a torchbind module, then we should short circuit
if module_name == "torch._classes":
return obj.qualified_name # pyrefly: ignore # missing-attribute
return obj.qualified_name # pyrefly: ignore [missing-attribute]

# The Python docs are very clear that `__module__` can be None, but I can't
# figure out when it actually would be.

@@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
)

return prop # pyrefly: ignore # bad-return
return prop # pyrefly: ignore [bad-return]

fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
return fn
@@ -65,8 +65,8 @@ class AsyncClosureHandler(ClosureHandler):

self._closure_event_loop = threading.Thread(
target=event_loop
) # pyrefly: ignore # bad-assignment
self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
) # pyrefly: ignore [bad-assignment]
self._closure_event_loop.start() # pyrefly: ignore [missing-attribute]

def run(self, closure):
with self._closure_lock:
@@ -515,8 +515,11 @@ def meta_copy_(self, src, non_blocking=False):
def inferUnsqueezeGeometry(tensor, dim):
result_sizes = list(tensor.size())
result_strides = list(tensor.stride())
# pyrefly: ignore [unsupported-operation]
new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
# pyrefly: ignore [bad-argument-type]
result_sizes.insert(dim, 1)
# pyrefly: ignore [bad-argument-type]
result_strides.insert(dim, new_stride)
return result_sizes, result_strides

@@ -2341,19 +2344,19 @@ def calc_conv_nd_return_shape(

ret_shape = [input_tensor.shape[0], out_channels]
if isinstance(stride, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
stride = [stride] * len(dims)
elif len(stride) == 1:
stride = [stride[0]] * len(dims)

if isinstance(padding, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
padding = [padding] * len(dims)
elif len(padding) == 1:
padding = [padding[0]] * len(dims)

if isinstance(dilation, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
dilation = [dilation] * len(dims)
elif len(dilation) == 1:
dilation = [dilation[0]] * len(dims)

@@ -2361,7 +2364,7 @@ def calc_conv_nd_return_shape(
output_padding_list: Optional[list[int]] = None
if output_padding:
if isinstance(output_padding, IntLike):
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
output_padding_list = [output_padding] * len(dims)
elif len(output_padding) == 1:
output_padding_list = [output_padding[0]] * len(dims)

@@ -2374,19 +2377,19 @@ def calc_conv_nd_return_shape(
ret_shape.append(
_formula_transposed(
dims[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
padding[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
dilation[i],
kernel_size[i],
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
stride[i],
output_padding_list[i],
)
)
else:
ret_shape.append(
# pyrefly: ignore # index-error
# pyrefly: ignore [index-error]
_formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
)
from torch.fx.experimental.symbolic_shapes import sym_or

@@ -3454,7 +3457,7 @@ def meta_index_Tensor(self, indices):
"""
shape = before_shape + replacement_shape + after_shape
strides = list(self.stride())
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
replacement_shape
)

@@ -5311,6 +5314,7 @@ def full(size, fill_value, *args, **kwargs):
if not dtype:
dtype = utils.get_dtype(fill_value)
kwargs["dtype"] = dtype
# pyrefly: ignore [not-iterable]
return torch.empty(size, *args, **kwargs)


@@ -6668,7 +6672,7 @@ def rnn_cell_checkSizes(
)
torch._check(
all(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
x.device == input_gates.device
for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
),
@@ -880,7 +880,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
return self._op_dk(dk, *args, **kwargs)
else:
return NotImplemented # pyrefly: ignore # bad-return
return NotImplemented # pyrefly: ignore [bad-return]

# Remove a dispatch key from the dispatch cache. This will force it to get
# recomputed the next time. Does nothing

@@ -985,9 +985,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):

r = self.py_kernels.get(final_key, final_key)
if cache_result:
self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation]
add_cached_op(self)
return r # pyrefly: ignore # bad-return
return r # pyrefly: ignore [bad-return]

def name(self):
return self._name

@@ -1117,7 +1117,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
)

assert isinstance(handler, Callable) # type: ignore[arg-type]
return handler(*args, **kwargs) # pyrefly: ignore # bad-return
return handler(*args, **kwargs) # pyrefly: ignore [bad-return]


def _must_dispatch_in_python(args, kwargs):
@@ -267,6 +267,7 @@ class FunctionalTensor(torch.Tensor):
device=self.device,
layout=self.layout,
)
# pyrefly: ignore [not-iterable]
return super().to(*args, **kwargs)

def cuda(self, device=None, *args, **kwargs):
@@ -551,6 +551,7 @@ class Tensor(torch._C.TensorBase):
raise RuntimeError("__setstate__ can be only called on leaf Tensors")
if len(state) == 4:
# legacy serialization of Tensor
# pyrefly: ignore [not-iterable]
self.set_(*state)
return
elif len(state) == 5:

@@ -758,7 +759,7 @@ class Tensor(torch._C.TensorBase):
)
if self._post_accumulate_grad_hooks is None:
self._post_accumulate_grad_hooks: dict[Any, Any] = (
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
OrderedDict()
)

@@ -1062,7 +1063,7 @@ class Tensor(torch._C.TensorBase):
else:
return torch._VF.split_with_sizes(
self,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
split_size,
dim,
)

@@ -1119,7 +1120,7 @@ class Tensor(torch._C.TensorBase):
__rtruediv__ = __rdiv__
__itruediv__ = _C.TensorBase.__idiv__

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
__pow__ = cast(
Callable[
["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],
@@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
yield buf_and_size[0]
buf_and_size = buf_dict[t] = [[], 0]
buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute]
buf_and_size[1] += size # pyrefly: ignore [unsupported-operation]
for buf, _ in buf_dict.values():
if len(buf) > 0:
yield buf

@@ -744,6 +744,7 @@ class ExceptionWrapper:
if exc_info is None:
exc_info = sys.exc_info()
self.exc_type = exc_info[0]
# pyrefly: ignore [not-iterable]
self.exc_msg = "".join(traceback.format_exception(*exc_info))
self.where = where

@@ -751,7 +752,7 @@ class ExceptionWrapper:
r"""Reraises the wrapped exception in the current thread"""
# Format a message such as: "Caught ValueError in DataLoader worker
# process 2. Original Traceback:", followed by the traceback.
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore [missing-attribute]
if self.exc_type is KeyError:
# KeyError calls repr() on its argument (usually a dict key). This
# makes stack traces unreadable. It will not be changed in Python

@@ -760,13 +761,13 @@ class ExceptionWrapper:
elif getattr(self.exc_type, "message", None):
# Some exceptions have first argument as non-str but explicitly
# have message field
# pyrefly: ignore # not-callable
# pyrefly: ignore [not-callable]
raise self.exc_type(
# pyrefly: ignore # unexpected-keyword
# pyrefly: ignore [unexpected-keyword]
message=msg
)
try:
exception = self.exc_type(msg) # pyrefly: ignore # not-callable
exception = self.exc_type(msg) # pyrefly: ignore [not-callable]
except Exception:
# If the exception takes multiple arguments or otherwise can't
# be constructed, don't try to instantiate since we don't know how to

@@ -1018,12 +1019,12 @@ class _LazySeedTracker:
self.call_order = []

def queue_seed_all(self, cb, traceback):
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
# update seed_all to be latest
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]

def queue_seed(self, cb, traceback):
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
# update seed to be latest
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]

@@ -419,6 +419,7 @@ class Unpickler:
inst = self.stack[-1]
if type(inst) is torch.Tensor:
# Legacy unpickling
# pyrefly: ignore [not-iterable]
inst.set_(*state)
elif type(inst) is torch.nn.Parameter:
inst.__setstate__(state)
@@ -104,6 +104,7 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:

flatten("", stats)
flat_stats.sort()
# pyrefly: ignore [no-matching-overload]
return OrderedDict(flat_stats)

@@ -525,7 +525,7 @@ def custom_fwd(
args[0]._dtype = torch.get_autocast_dtype(device_type)
if cast_inputs is None:
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
else:
autocast_context = torch.is_autocast_enabled(device_type)
args[0]._fwd_used_autocast = False

@@ -536,7 +536,7 @@ def custom_fwd(
**_cast(kwargs, device_type, cast_inputs),
)
else:
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]

return decorate_fwd

@@ -571,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
enabled=args[0]._fwd_used_autocast,
dtype=args[0]._dtype,
):
return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
return bwd(*args, **kwargs) # pyrefly: ignore [not-callable]

return decorate_bwd
@@ -84,7 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
if is_match:
# navigate to the base node
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.seen_nodes.add(cur_start_node)
# for now, assume that there are no other nodes
# which need to be added to the stack

@@ -95,10 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
cur_base_op_node = cur_start_node
break

# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.seen_nodes.add(cur_start_node)
# add args of previous nodes to stack
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
for arg in cur_start_node.all_input_nodes:
self._recursively_add_node_arg_to_stack(arg)

@@ -106,7 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
# note: this check is done on the start_node, i.e.
# if we are matching linear-relu in reverse, this would do the matchable
# check on the linear
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not self._is_matchable(cur_base_op_node):
continue

@@ -120,10 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
continue

return NSSubgraph(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
start_node=cur_start_node,
end_node=cur_end_node,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
base_op_node=cur_base_op_node,
)

@@ -481,4 +481,5 @@ of subgraphs."""
# subgraphs in their order of execution.
results = collections.OrderedDict(reversed(results.items()))

# pyrefly: ignore [bad-return]
return results
@@ -30,6 +30,7 @@ class EventList(list):
use_device = kwargs.pop("use_device", None)
profile_memory = kwargs.pop("profile_memory", False)
with_flops = kwargs.pop("with_flops", False)
# pyrefly: ignore [not-iterable]
super().__init__(*args, **kwargs)
self._use_device = use_device
self._profile_memory = profile_memory

@@ -505,9 +506,9 @@ class FunctionEvent(FormattedTimesMixin):
self.id: int = id
self.node_id: int = node_id
self.name: str = name
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.overload_name: str = overload_name
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.trace_name: str = trace_name
self.time_range: Interval = Interval(start_us, end_us)
self.thread: int = thread

@@ -516,13 +517,13 @@ class FunctionEvent(FormattedTimesMixin):
self.count: int = 1
self.cpu_children: list[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.input_shapes: tuple[int, ...] = input_shapes
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.concrete_inputs: list[Any] = concrete_inputs
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.kwinputs: dict[str, Any] = kwinputs
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.stack: list = stack
self.scope: int = scope
self.use_device: Optional[str] = use_device

@@ -766,7 +767,7 @@ class FunctionEventAvg(FormattedTimesMixin):
self.self_device_memory_usage += other.self_device_memory_usage
self.count += other.count
if self.flops is None:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self.flops = other.flops
elif other.flops is not None:
self.flops += other.flops

@@ -1003,7 +1004,7 @@ def _build_table(
]
if flops <= 0:
raise AssertionError(f"Expected flops to be positive, but got {flops}")
# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
if not (log_flops >= 0 and log_flops < len(flop_headers)):
raise AssertionError(
@@ -50,6 +50,7 @@ def compile(*args, **kwargs):
"""
See :func:`torch.compile` for details on the arguments for this function.
"""
# pyrefly: ignore [not-iterable]
return torch.compile(*args, **kwargs)

@@ -198,6 +198,7 @@ def _for_each_rank_run_func(
rr_val = flat_rank_rets[rr_key]

if isinstance(rr_val, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
elif isinstance(rr_val, (list, tuple)):
ret_list = []

@@ -206,6 +207,7 @@ def _for_each_rank_run_func(
v_it = iter(rets.values())
v = next(v_it)
if isinstance(v, Tensor):
# pyrefly: ignore [bad-argument-type, bad-argument-count]
ret_list.append(LocalTensor(rets))
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
ret_list.append(torch.SymInt(LocalIntNode(rets)))

@@ -468,7 +470,7 @@ class LocalTensor(torch.Tensor):
def __repr__(self) -> str: # type: ignore[override]
parts = []
for k, v in self._local_tensors.items():
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
parts.append(f" {k}: {v}")
tensors_str = ",\n".join(parts)
return f"LocalTensor(\n{tensors_str}\n)"

@@ -491,6 +493,7 @@ class LocalTensor(torch.Tensor):
"Expecting spec to be not None from `__tensor_flatten__` return value!"
)
local_tensors = inner_tensors["_local_tensors"]
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor(local_tensors)

@classmethod

@@ -751,6 +754,7 @@ class LocalTensorMode(TorchDispatchMode):
"""

with self.disable():
# pyrefly: ignore [bad-argument-type, bad-argument-count]
return LocalTensor({r: cb(r) for r in self.ranks})

def _patch_device_mesh(self) -> None:

@@ -761,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
def _unpatch_device_mesh(self) -> None:
assert self._old_get_coordinate is not None
DeviceMesh.get_coordinate = self._old_get_coordinate
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
self._old_get_coordinate = None

@@ -79,6 +79,7 @@ def _flatten_tensor_size(size) -> torch.Size:
Checks if tensor size is valid, then flatten/return a torch.Size object.
"""
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
# pyrefly: ignore [not-iterable]
dims = list(*size)
else:
dims = list(size)

@@ -208,7 +209,7 @@ def build_global_metadata(
global_sharded_tensor_metadata = None
global_metadata_rank = 0

# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for rank, rank_metadata in enumerate(gathered_metadatas):
if rank_metadata is None:
continue
@@ -227,7 +227,7 @@ class PGTransport:
self._work: list[Work] = []
self._pg = pg
self._timeout = timeout
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self._device = device
self._state_dict = state_dict

@@ -345,6 +345,7 @@ class PGTransport:
values.append(recv(path, v))
elif isinstance(v, _DTensorMeta):
tensor = recv(path, v.local)
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
values.append(DTensor(tensor, v.spec, requires_grad=False))
elif isinstance(v, _ShardedTensorMeta):
# Receive all local shards that were sent to us
@@ -565,7 +565,7 @@ class FlatParamHandle:
# Only align addresses for `use_orig_params=True` (for now)
align_addresses = use_orig_params
self._init_get_unflat_views_fn(align_addresses)
# pyrefly: ignore # read-only
# pyrefly: ignore [read-only]
self.device = device
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
self.process_group = process_group

@@ -2495,6 +2495,7 @@ class FlatParamHandle:
###########
def flat_param_to(self, *args, **kwargs):
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
# pyrefly: ignore [not-iterable]
self.flat_param.data = self.flat_param.to(*args, **kwargs)
if self._use_orig_params:
# Refresh the views because their storage may have changed
@@ -139,11 +139,14 @@ def _from_local_no_grad(
"""

if not compiled_autograd_enabled():
# pyrefly: ignore [bad-argument-type]
return DTensor(
# Use the local tensor directly instead of constructing a new tensor
# variable, e.g. with `view_as()`, since this is not differentiable
# pyrefly: ignore [bad-argument-count]
local_tensor,
sharding_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=local_tensor.requires_grad,
)
else:
@@ -107,9 +107,12 @@ class _ToTorchTensor(torch.autograd.Function):
)

return (
# pyrefly: ignore [bad-argument-type]
DTensor(
# pyrefly: ignore [bad-argument-count]
grad_output,
grad_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad,
),
None,

@@ -175,11 +178,14 @@ class _FromTorchTensor(torch.autograd.Function):
)

# We want a fresh Tensor object that shares memory with the input tensor
# pyrefly: ignore [bad-argument-type]
dist_tensor = DTensor(
# pyrefly: ignore [bad-argument-count]
input.view_as(input),
dist_spec,
# requires_grad of the dist tensor depends on if input
# requires_grad or not
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad,
)
return dist_tensor

@@ -304,9 +310,12 @@ class DTensor(torch.Tensor):
spec.placements,
tensor_meta=unflatten_tensor_meta,
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor,
unflatten_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=requires_grad,
)

@@ -820,9 +829,12 @@ def distribute_tensor(
dtype=tensor.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor.requires_grad_(tensor.requires_grad),
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=tensor.requires_grad,
)

@@ -1077,9 +1089,12 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
),
)

# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
local_tensor,
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=kwargs["requires_grad"],
)

@@ -78,8 +78,11 @@ def found_inf_reduce_handler(
dtype=target_tensor.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
found_inf_dtensor = dtensor.DTensor(
local_tensor=target_tensor, spec=spec, requires_grad=False
local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword]
spec=spec, # pyrefly: ignore [unexpected-keyword]
requires_grad=False, # pyrefly: ignore [unexpected-keyword]
)
found_inf = found_inf_dtensor.full_tensor()
target_tensor.copy_(found_inf)

@@ -189,7 +192,7 @@ class OpDispatcher:
local_tensor_args = (
pytree.tree_unflatten(
cast(list[object], op_info.local_args),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
op_info.args_tree_spec,
)
if op_info.args_tree_spec

@@ -366,7 +369,7 @@ class OpDispatcher:
resharded_local_tensor = redistribute_local_tensor(
local_tensor,
arg_spec,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
reshard_arg_spec,
)
new_local_args.append(resharded_local_tensor)

@@ -439,7 +442,7 @@ class OpDispatcher:
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
op_call,
v,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
compute_mesh,
)
local_kwargs[k] = v

@@ -456,7 +459,7 @@ class OpDispatcher:
OpSchema(
op_call,
(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
pytree.tree_unflatten(args_schema, args_spec)
if args_spec
else tuple(args_schema)

@@ -478,6 +481,7 @@ class OpDispatcher:
assert isinstance(spec, DTensorSpec), (
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
)
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
else:
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
@@ -883,9 +883,12 @@ class Redistribute(torch.autograd.Function):
output = local_tensor
target_spec = current_spec

# pyrefly: ignore [bad-argument-type]
return dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output,
target_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=input.requires_grad,
)

@@ -944,9 +947,12 @@ class Redistribute(torch.autograd.Function):
dtype=output.dtype,
),
)
# pyrefly: ignore [bad-argument-type]
output_dtensor = dtensor.DTensor(
# pyrefly: ignore [bad-argument-count]
output,
spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=grad_output.requires_grad,
)

@@ -174,9 +174,12 @@ def _log_softmax_handler(
tensor_meta=output_tensor_meta,
)

# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
res,
res_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=res.requires_grad,
)

@@ -251,7 +254,7 @@ def _nll_loss_forward(
if weight is not None:
new_shape = list(x.shape)
new_shape[channel_dim] = -1
# pyrefly: ignore # unbound-name
# pyrefly: ignore [unbound-name]
w = w.expand(new_shape)
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
wsum = torch.where(target != ignore_index, wsum, 0)

@@ -309,9 +312,9 @@ def _nll_loss_forward_handler(
output_placements = all_replicate_placements

# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args = list(args)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[1], args[2] = target, weight
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

@@ -330,9 +333,12 @@ def _nll_loss_forward_handler(
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)

return (
# pyrefly: ignore [bad-argument-type]
DTensor(
# pyrefly: ignore [bad-argument-count]
result,
out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad,
),
total_weight,

@@ -442,11 +448,11 @@ def _nll_loss_backward_handler(
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)

# tensor inputs to _propagate_tensor_meta need to be DTensors
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
args = list(args)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[2], args[3] = target, weight
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)

@@ -470,9 +476,12 @@ def _nll_loss_backward_handler(
tensor_meta=output_tensor_meta,
)

# pyrefly: ignore [bad-argument-type]
return DTensor(
# pyrefly: ignore [bad-argument-count]
result,
out_spec,
# pyrefly: ignore [unexpected-keyword]
requires_grad=result.requires_grad,
)


@@ -949,7 +949,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
]
target = node.target if node.op in ("call_function", "get_attr") else ""
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
nodes_idx[id(node)] = i
return "\n".join(ret)

@@ -1206,6 +1206,7 @@ class _ModuleFrame:
for k in kwargs_spec.context
}
assert self.parent_call_module is not None
# pyrefly: ignore [bad-assignment]
self.parent_call_module.args = tuple(arg_nodes)
self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]

@@ -1393,6 +1394,7 @@ class _ModuleFrame:

def print(self, *args, **kwargs):
if self.verbose:
# pyrefly: ignore [not-iterable]
print(*args, **kwargs)

def run_from(self, node_idx):

@@ -1486,7 +1488,7 @@ class _ModuleFrame:
self.seen_attrs[self.child_fqn].add(node.target)

self.copy_node(node)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
node_idx += 1

@@ -1952,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
)

if isinstance(shape, (int, torch.SymInt)):
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type
shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type]
else:
for dim in shape:
torch._check_type(

@@ -87,6 +87,7 @@ def _reify_object_slots(o, s):
@dispatch(slice, dict)
def _reify(o, s):
"""Reify a Python ``slice`` object"""
# pyrefly: ignore [not-iterable]
return slice(*reify((o.start, o.stop, o.step), s))

@@ -59,7 +59,7 @@ Argument = Optional[
BaseArgumentTypes,
]
]
# pyrefly: ignore # invalid-annotation
# pyrefly: ignore [invalid-annotation]
ArgumentT = TypeVar("ArgumentT", bound=Argument)
_P = ParamSpec("_P")
_R = TypeVar("_R")

@@ -385,7 +385,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put before this node. Must be a member of the same graph.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._prepend(x)

@compatibility(is_backward_compatible=True)

@@ -397,7 +397,7 @@ class Node(_NodeBase):
Args:
x (Node): The node to put after this node. Must be a member of the same graph.
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self._next._prepend(x)

@property

@@ -698,7 +698,8 @@ class Node(_NodeBase):
if replace_hooks:
for replace_hook in replace_hooks:
replace_hook(old=self, new=replace_with.name, user=use_node)
use_node._replace_input_with(self, replace_with)
# pyrefly: ignore [missing-attribute]
use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined]
return result

@compatibility(is_backward_compatible=False)

@@ -835,7 +836,8 @@ class Node(_NodeBase):
for replace_hook in m._replace_hooks:
replace_hook(old=old_input, new=new_input.name, user=self)

self._replace_input_with(old_input, new_input)
# pyrefly: ignore [missing-attribute]
self._replace_input_with(old_input, new_input) # type: ignore[attr-defined]

def _rename(self, candidate: str) -> None:
if candidate == self.name:

@@ -303,7 +303,7 @@ class StreamContext:
self.idx = _get_device_index(None, True)
if not torch.jit.is_scripting():
if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment
self.idx = -1 # pyrefly: ignore [bad-assignment]

self.src_prev_stream = (
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)

@@ -46,7 +46,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
if canonicalize:
dim = canonicalize_dims(ndim, dim)

assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation
assert dim >= 0 and dim < ndim # pyrefly: ignore [unsupported-operation]

# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
# For other dims, subtract 1 to convert to inner space.

@@ -72,7 +72,7 @@ class _NormBase(Module):
torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
),
)

@@ -222,7 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
dtype=None,
) -> None:
factory_kwargs = {"device": device, "dtype": dtype}
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
super().__init__(
# affine and track_running_stats are hardcoded to False to
# avoid creating tensors that will soon be overwritten.

@@ -236,29 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
self.affine = affine
self.track_running_stats = track_running_stats
if self.affine:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.weight = UninitializedParameter(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.bias = UninitializedParameter(**factory_kwargs)
if self.track_running_stats:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.running_mean = UninitializedBuffer(**factory_kwargs)
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
self.running_var = UninitializedBuffer(**factory_kwargs)
self.num_batches_tracked = torch.tensor(
0,
dtype=torch.long,
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
)

def reset_parameters(self) -> None:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if not self.has_uninitialized_params() and self.num_features != 0:
super().reset_parameters()

def initialize_parameters(self, input) -> None: # type: ignore[override]
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
if self.has_uninitialized_params():
self.num_features = input.shape[1]
if self.affine:

@@ -352,6 +352,7 @@ class BatchNorm1d(_BatchNorm):
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")


# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.

@@ -463,6 +464,7 @@ class BatchNorm2d(_BatchNorm):
raise ValueError(f"expected 4D input (got {input.dim()}D input)")


# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.

@@ -574,6 +576,7 @@ class BatchNorm3d(_BatchNorm):
raise ValueError(f"expected 5D input (got {input.dim()}D input)")


# pyrefly: ignore [inconsistent-inheritance]
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.

@@ -38,13 +38,13 @@ T = TypeVar("T", bound="Module")


class _IncompatibleKeys(
# pyrefly: ignore # invalid-inheritance
# pyrefly: ignore [invalid-inheritance]
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
):
__slots__ = ()

def __repr__(self) -> str:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if not self.missing_keys and not self.unexpected_keys:
return "<All keys matched successfully>"
return super().__repr__()

@@ -93,7 +93,7 @@ class _WrappedHook:
def __getstate__(self) -> dict:
result = {"hook": self.hook, "with_module": self.with_module}
if self.with_module:
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
result["module"] = self.module()

return result

@@ -979,7 +979,7 @@ class Module:
# Decrement use count of the gradient by setting to None
param.grad = None
param_applied = torch.nn.Parameter(
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
param_applied,
requires_grad=param.requires_grad,
)

@@ -992,13 +992,13 @@ class Module:
) from e
out_param = param
elif p_should_use_set_data:
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
param.data = param_applied
out_param = param
else:
assert isinstance(param, Parameter)
assert param.is_leaf
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
out_param = Parameter(param_applied, param.requires_grad)
self._parameters[key] = out_param

@@ -1337,7 +1337,9 @@ class Module:

"""
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
*args, **kwargs
# pyrefly: ignore [not-iterable]
*args,
**kwargs,
)

if dtype is not None:

@@ -2256,7 +2258,7 @@ class Module:

if destination is None:
destination = OrderedDict()
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
destination._metadata = OrderedDict()

local_metadata = dict(version=self._version)

@@ -2407,7 +2409,7 @@ class Module:
}
local_name_params = itertools.chain(
self._parameters.items(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
persistent_buffers.items(),
)
local_state = {k: v for k, v in local_name_params if v is not None}

@@ -27,6 +27,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`."""
if verbose is False:
return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)

@@ -47,7 +48,7 @@ def _patch_dynamo_unsupported_functions():

# Replace torch.jit.isinstance with isinstance
jit_isinstance = torch.jit.isinstance
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
torch.jit.isinstance = isinstance
logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing")
try:

@@ -132,10 +132,10 @@ class TorchTensor(ir.Tensor):
# view the tensor as that dtype so that it is convertible to NumPy,
# and then view it back to the proper dtype (using ml_dtypes obtained by
# calling dtype.numpy()).
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
if self.dtype == ir.DataType.BFLOAT16:
return (
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
)
if self.dtype in {

@@ -144,11 +144,11 @@ class TorchTensor(ir.Tensor):
ir.DataType.FLOAT8E5M2,
ir.DataType.FLOAT8E5M2FNUZ,
}:
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
if self.dtype == ir.DataType.FLOAT4E2M1:
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
self.dtype.numpy()
)

@@ -170,7 +170,7 @@ class TorchTensor(ir.Tensor):

if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
raise TypeError(
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
"with a tensor backed by real data using ONNXProgram.apply_weights() "
"or save the model without initializers by setting include_initializers=False."

@@ -251,7 +251,7 @@ def _set_shape_type(
if isinstance(dim, int):
dims.append(dim)
else:
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
dims.append(str(dim.node))

# If the dtype is set already (e.g. by the onnx_symbolic ops),

@@ -1232,7 +1232,7 @@ def _exported_program_to_onnx_program(
# so we need to get them from the name_* apis.
for name, torch_tensor in itertools.chain(
exported_program.named_parameters(),
# pyrefly: ignore # bad-argument-type
# pyrefly: ignore [bad-argument-type]
exported_program.named_buffers(),
exported_program.constants.items(),
):

@@ -1265,6 +1265,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
"""Prints messages based on `verbose`."""
if verbose is False:
return lambda *_, **__: None
# pyrefly: ignore [not-iterable]
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)

@@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np(
if acceptable_error_percentage:
error_percentage = 1 - np.sum(
np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
) / np.prod(ort_out.shape) # pyrefly: ignore # missing-attribute
) / np.prod(ort_out.shape) # pyrefly: ignore [missing-attribute]
if error_percentage <= acceptable_error_percentage:
warnings.warn(
f"Suppressed AssertionError:\n{e}.\n"

@@ -13,6 +13,7 @@ from torch import optim

def partialclass(cls, *args, **kwargs): # noqa: D103
class NewCls(cls):
# pyrefly: ignore [not-iterable]
__init__ = partialmethod(cls.__init__, *args, **kwargs)

return NewCls

@@ -326,7 +326,7 @@ def gaussian(
requires_grad=requires_grad,
)

return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation
return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation]


@_add_docstr(

@@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):

def _isint(x):
if HAS_NUMPY:
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute
return isinstance(x, (int, np.integer)) # pyrefly: ignore [missing-attribute]
else:
return isinstance(x, int)

@@ -77,6 +77,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def pop(self) -> T:
if not self:
raise KeyError("pop from an empty set")
# pyrefly: ignore [bad-return]
return self._dict.popitem()[0]

def copy(self) -> OrderedSet[T]:

@@ -158,7 +159,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
# MutableSet impl will iterate over other, iter over smaller of two sets
if isinstance(other, OrderedSet) and len(self) < len(other):
# pyrefly: ignore # unsupported-operation, bad-return
# pyrefly: ignore [unsupported-operation, bad-return]
return other & self
return cast(OrderedSet[T], super().__and__(other))

@@ -202,7 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
Args:
device (int, optional): if specified, all parameters will be copied to that device
"""
# pyrefly: ignore # missing-attribute
# pyrefly: ignore [missing-attribute]
return self._apply(lambda t: getattr(t, custom_backend_name)(device))

_check_register_once(torch.nn.Module, custom_backend_name)

@@ -252,11 +252,15 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(
device (int, optional): if specified, all parameters will be copied to that device
"""
ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
*args, **kwargs
# pyrefly: ignore [not-iterable]
*args,
**kwargs,
)
if ex.device.type == custom_backend_name:
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs)
kwargs.update({"device": custom_backend_name})
# pyrefly: ignore [not-iterable]
return self.to(*args, **kwargs)

_check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)

@@ -48,7 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
]
)

# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
CUDA_TYPE_NAME_MAP = collections.OrderedDict(
[
("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),

@@ -587,6 +587,7 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
]
)

# pyrefly: ignore [no-matching-overload]
CUDA_INCLUDE_MAP = collections.OrderedDict(
[
# since pytorch uses "\b{pattern}\b" as the actual re pattern,

@@ -676,7 +677,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
]
)

# pyrefly: ignore # no-matching-overload
# pyrefly: ignore [no-matching-overload]
CUDA_IDENTIFIER_MAP = collections.OrderedDict(
[
("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),

@@ -8370,6 +8371,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
]
)

# pyrefly: ignore [no-matching-overload]
CUDA_SPECIAL_MAP = collections.OrderedDict(
[
# SPARSE

@@ -8852,6 +8854,7 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
]
)

# pyrefly: ignore [no-matching-overload]
PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
[
("USE_CUDA", ("USE_ROCM", API_PYTORCH)),

@@ -9316,6 +9319,7 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
]
)

# pyrefly: ignore [no-matching-overload]
CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
[
("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)),

@@ -9401,6 +9405,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
#
# NB: if you want a transformation to ONLY apply to the c10/ directory,
# put it as API_CAFFE2
# pyrefly: ignore [no-matching-overload]
C10_MAPPINGS = collections.OrderedDict(
[
("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),

@@ -120,6 +120,7 @@ class GeneratedFileCleaner:
def open(self, fn, *args, **kwargs):
if not os.path.exists(fn):
self.files_to_clean.add(os.path.abspath(fn))
# pyrefly: ignore [not-iterable]
return open(fn, *args, **kwargs)

def makedirs(self, dn, exist_ok=False):

@@ -669,7 +670,7 @@ def is_caffe2_gpu_file(rel_filepath):
return True
filename = os.path.basename(rel_filepath)
_, ext = os.path.splitext(filename)
# pyrefly: ignore # unsupported-operation
# pyrefly: ignore [unsupported-operation]
return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)

class TrieNode:

@@ -1145,7 +1146,7 @@ def hipify(
out_of_place_only=out_of_place_only,
is_pytorch_extension=is_pytorch_extension))
all_files_set = set(all_files)
# pyrefly: ignore # bad-assignment
# pyrefly: ignore [bad-assignment]
for f in extra_files:
if not os.path.isabs(f):
f = os.path.join(output_directory, f)

@@ -292,9 +292,10 @@ class WeakIdKeyDictionary(MutableMapping):
if o is not None:
return o, value

# pyrefly: ignore # bad-override
# pyrefly: ignore [bad-override]
def pop(self, key, *args):
self._dirty_len = True
# pyrefly: ignore [not-iterable]
return self.data.pop(self.ref_type(key), *args) # CHANGED

def setdefault(self, key, default=None):

@@ -328,7 +328,7 @@ class StreamContext:
self.stream = stream
self.idx = _get_device_index(None, True)
if self.idx is None:
self.idx = -1 # pyrefly: ignore # bad-assignment
self.idx = -1 # pyrefly: ignore [bad-assignment]

def __enter__(self):
cur_stream = self.stream

@@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
"""
if stream is None:
stream = torch.xpu.current_stream()
super().record(stream) # pyrefly: ignore # bad-argument-type
super().record(stream) # pyrefly: ignore [bad-argument-type]

def wait(self, stream=None) -> None:
r"""Make all future work submitted to the given stream wait for this event.
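
Every hunk above makes the same mechanical edit: a trailing "# pyrefly: ignore # some-error" comment becomes the bracketed "# pyrefly: ignore [some-error]" form, and multi-argument constructor calls such as dtensor.DTensor(...) get one bracketed suppression per flagged argument. As a minimal sketch of the two spellings on an ordinary line (count_chars and its arguments are invented for illustration and are not code from this diff; the error name bad-argument-type is one of the codes that appears in the hunks):

def count_chars(text: str) -> int:
    # Runs fine on a list at runtime, but the annotation says str,
    # so a type checker flags the calls below.
    return len(text)


# Older spelling: the error code follows a second "#" after the directive.
a = count_chars([1, 2, 3])  # pyrefly: ignore # bad-argument-type

# Bracketed spelling used throughout these hunks: the error class being
# suppressed is named inside [...] on the same line.
b = count_chars([1, 2, 3])  # pyrefly: ignore [bad-argument-type]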