Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 00:20:18 +01:00
Fix pyrefly ignores 1/n (#166239)
First diff adjusting the syntax for pyrefly: ignore suppressions so they only hide one class of type error.

Test: lintrunner, pyrefly check

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166239
Approved by: https://github.com/oulgen
This commit is contained in:
parent 621ba05107
commit c7eee49525
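What the syntax change does, in a nutshell (a minimal sketch; the "x" assignment below is illustrative, not taken from this diff): in the old comment-style form, everything after "ignore" is itself just a comment, so the suppression silences every pyrefly error on the line. The bracketed form scopes the suppression to a single error class, so any other error on the same line still gets reported.

    # Before: suppresses ALL pyrefly errors on this line;
    # "# bad-assignment" is only a trailing comment.
    x: int = "oops"  # pyrefly: ignore # bad-assignment

    # After: suppresses only bad-assignment errors;
    # other error classes on this line still surface.
    x: int = "oops"  # pyrefly: ignore [bad-assignment]

The hunks below apply this rewrite throughout the tree (the original side-by-side diff view is reconstructed here as unified diffs; file names were not preserved by the page).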
@@ -130,6 +130,7 @@ errors.bad-param-name-override = false
 # Mypy doesn't require that imports are explicitly imported, so be compatible with that.
 # Might be a good idea to turn this on in future.
 errors.implicit-import = false
+errors.deprecated = false # re-enable after we've fix import formatting
 permissive-ignores = true
 replace-imports-with-any = ["!sympy.printing.*", "sympy.*", "onnxscript.onnx_opset.*"]
 search-path = ["tools/experimental"]

@@ -2,7 +2,8 @@ from dataclasses import dataclass, field
 from datetime import datetime
 from typing import Optional

-from dataclasses_json import DataClassJsonMixin
+# pyrefly: ignore [missing-import]
+from dataclasses_json import DataClassJsonMixin # type: ignore[import-not-found]


 _DATA_MODEL_VERSION = 1.5

@@ -17,7 +18,7 @@ class UtilizationStats:


 @dataclass
-class UtilizationMetadata(DataClassJsonMixin):
+class UtilizationMetadata(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 workflow_id: str
 job_id: str

@@ -33,7 +34,7 @@ class UtilizationMetadata(DataClassJsonMixin):


 @dataclass
-class GpuUsage(DataClassJsonMixin):
+class GpuUsage(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 uuid: Optional[str] = None
 util_percent: Optional[UtilizationStats] = None
 mem_util_percent: Optional[UtilizationStats] = None

@@ -43,14 +44,14 @@ class GpuUsage(DataClassJsonMixin):


 @dataclass
-class RecordData(DataClassJsonMixin):
+class RecordData(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 cpu: Optional[UtilizationStats] = None
 memory: Optional[UtilizationStats] = None
 gpu_usage: Optional[list[GpuUsage]] = None


 @dataclass
-class UtilizationRecord(DataClassJsonMixin):
+class UtilizationRecord(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 timestamp: int
 data: Optional[RecordData] = None

@@ -63,7 +64,7 @@ class UtilizationRecord(DataClassJsonMixin):
 # the db schema related to this is:
 # https://github.com/pytorch/test-infra/blob/main/clickhouse_db_schema/oss_ci_utilization/oss_ci_utilization_metadata_schema.sql
 @dataclass
-class OssCiSegmentV1(DataClassJsonMixin):
+class OssCiSegmentV1(DataClassJsonMixin): # type: ignore[misc, no-any-unimported]
 level: str
 name: str
 start_at: int

@@ -1703,7 +1703,7 @@ def _check(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(RuntimeError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(RuntimeError, cond, message) # pyrefly: ignore [bad-argument-type]


 # TODO add deprecation annotation

@@ -1753,7 +1753,7 @@ def _check_index(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(IndexError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(IndexError, cond, message) # pyrefly: ignore [bad-argument-type]


 def _check_value(cond, message=None): # noqa: F811

@@ -1771,7 +1771,7 @@ def _check_value(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(ValueError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(ValueError, cond, message) # pyrefly: ignore [bad-argument-type]


 def _check_type(cond, message=None): # noqa: F811

@@ -1789,7 +1789,7 @@ def _check_type(cond, message=None): # noqa: F811
 an object that has a ``__str__()`` method to be used as the error
 message. Default: ``None``
 """
-_check_with(TypeError, cond, message) # pyrefly: ignore # bad-argument-type
+_check_with(TypeError, cond, message) # pyrefly: ignore [bad-argument-type]


 def _check_not_implemented(cond, message=None): # noqa: F811

@@ -101,7 +101,7 @@ def custom_op(
 lib, ns, function_schema, name, ophandle, _private_access=True
 )

-result.__name__ = func.__name__ # pyrefly: ignore # bad-assignment
+result.__name__ = func.__name__ # pyrefly: ignore [bad-assignment]
 result.__module__ = func.__module__
 result.__doc__ = func.__doc__


@@ -154,7 +154,7 @@ def make_crossref_functionalize(
 maybe_detach, (f_args, f_kwargs)
 )
 with fake_mode:
-f_r = op(*f_args, **f_kwargs) # pyrefly: ignore # invalid-param-spec
+f_r = op(*f_args, **f_kwargs) # pyrefly: ignore [invalid-param-spec]
 r = op._op_dk(final_key, *args, **kwargs)

 def desc():

@@ -1029,7 +1029,7 @@ class BuiltinVariable(VariableTracker):

 def call_self_handler(tx: "InstructionTranslator", args, kwargs):
 try:
-# pyrefly: ignore # not-callable
+# pyrefly: ignore [not-callable]
 result = self_handler(tx, *args, **kwargs)
 if result is not None:
 return result

@@ -1037,7 +1037,7 @@ class BuiltinVariable(VariableTracker):
 # Check if binding is bad. inspect signature bind is expensive.
 # So check only when handler call fails.
 try:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 inspect.signature(self_handler).bind(tx, *args, **kwargs)
 except TypeError as e:
 has_constant_handler = obj.has_constant_handler(args, kwargs)

@@ -1090,7 +1090,7 @@ class BuiltinVariable(VariableTracker):
 hints=[*graph_break_hints.DYNAMO_BUG],
 from_exc=exc,
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, res)

 else:

@@ -1119,7 +1119,7 @@ class BuiltinVariable(VariableTracker):
 tx,
 args=list(map(ConstantVariable.create, exc.args)),
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, res)

 handlers.append(constant_fold_handler)

@@ -1442,7 +1442,7 @@ class BuiltinVariable(VariableTracker):
 resolved_fn = getattr(self.fn, name)
 if resolved_fn in dict_methods:
 if isinstance(args[0], variables.UserDefinedDictVariable):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return args[0]._dict_vt.call_method(tx, name, args[1:], kwargs)
 elif isinstance(args[0], variables.ConstDictVariable):
 return args[0].call_method(tx, name, args[1:], kwargs)

@@ -1451,7 +1451,7 @@ class BuiltinVariable(VariableTracker):
 resolved_fn = getattr(self.fn, name)
 if resolved_fn in set_methods:
 if isinstance(args[0], variables.UserDefinedSetVariable):
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 return args[0]._set_vt.call_method(tx, name, args[1:], kwargs)
 elif isinstance(args[0], variables.SetVariable):
 return args[0].call_method(tx, name, args[1:], kwargs)

@@ -1540,12 +1540,12 @@ class BuiltinVariable(VariableTracker):
 if type(arg.value).__str__ is object.__str__:
 # Rely on the object str method
 try:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return variables.ConstantVariable.create(value=str_method())
 except AttributeError:
 # Graph break
 return
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 elif is_wrapper_or_member_descriptor(str_method):
 unimplemented_v2(
 gb_type="Attempted to a str() method implemented in C/C++",

@@ -1662,10 +1662,10 @@ class BuiltinVariable(VariableTracker):
 else:
 raw_b = b.raw_value
 if self.fn is max:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 raw_res = max(a.raw_value, raw_b)
 else:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 raw_res = min(a.raw_value, raw_b)

 need_unwrap = any(

@@ -1980,12 +1980,16 @@ class BuiltinVariable(VariableTracker):
 if isinstance(arg, dict):
 arg = [ConstantVariable.create(k) for k in arg.keys()]
 return DictVariableType(
-dict.fromkeys(arg, value), user_cls, mutation_type=ValueMutationNew()
+# pyrefly: ignore [bad-argument-type]
+dict.fromkeys(arg, value),
+user_cls,
+mutation_type=ValueMutationNew(),
 )
 elif arg.has_force_unpack_var_sequence(tx):
 keys = arg.force_unpack_var_sequence(tx)
 if all(is_hashable(v) for v in keys):
 return DictVariableType(
+# pyrefly: ignore [bad-argument-type]
 dict.fromkeys(keys, value),
 user_cls,
 mutation_type=ValueMutationNew(),

@@ -2152,7 +2156,7 @@ class BuiltinVariable(VariableTracker):
 )

 if isinstance(arg, variables.UserDefinedExceptionClassVariable):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return ConstantVariable.create(isinstance(arg_type, isinstance_type))

 isinstance_type_tuple: tuple[type, ...]

@@ -2185,10 +2189,10 @@ class BuiltinVariable(VariableTracker):
 # through it. This is a limitation of the current implementation.
 # Usually `__subclasscheck__` and `__instancecheck__` can be constant fold through, it
 # might not be a big issue and we trade off it for performance.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 val = issubclass(arg_type, isinstance_type_tuple)
 except TypeError:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 val = arg_type in isinstance_type_tuple
 return variables.ConstantVariable.create(val)


@@ -2210,7 +2214,7 @@ class BuiltinVariable(VariableTracker):

 # WARNING: This might run arbitrary user code `__subclasscheck__`.
 # See the comment in call_isinstance above.
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return variables.ConstantVariable(issubclass(left_ty_py, right_ty_py))

 def call_super(self, tx: "InstructionTranslator", a, b):

@@ -2256,9 +2260,9 @@ class BuiltinVariable(VariableTracker):
 value = getattr(self.fn, name)
 except AttributeError:
 raise_observed_exception(AttributeError, tx)
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if not callable(value):
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return VariableTracker.build(tx, value, source)
 return variables.GetAttrVariable(self, name, source=source)


@@ -34,6 +34,7 @@ class LazyCache:
 self.vt = builder.VariableBuilder(tx, self.source)(self.value)

 if self.name_hint is not None:
+# pyrefly: ignore [missing-attribute]
 self.vt.set_name_hint(self.name_hint)

 del self.value

@@ -1138,11 +1138,13 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:

 for i, m in enumerate(reversed(_get_current_dispatch_mode_stack())):
 if isinstance(m, FakeTensorMode):
+# pyrefly: ignore [bad-argument-type]
 fake_modes.append((m, "active fake mode", i))

 flat_inputs = pytree.tree_leaves(inputs)
 for i, flat_input in enumerate(flat_inputs):
 if isinstance(flat_input, FakeTensor):
+# pyrefly: ignore [bad-argument-type]
 fake_modes.append((flat_input.fake_mode, "fake tensor input", i))
 if is_traceable_wrapper_subclass(flat_input):
 out: list[Union[torch.Tensor, int, torch.SymInt]] = []

@@ -1151,6 +1153,7 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
 x for x in out if isinstance(x, FakeTensor)
 ]
 fake_modes.extend(
+# pyrefly: ignore [bad-argument-type]
 [
 (tensor.fake_mode, f"subclass input {i}", ix)
 for ix, tensor in enumerate(fake_tensors)

@@ -1162,9 +1165,12 @@ def detect_fake_mode(inputs: Any = None) -> Optional[FakeTensorMode]:
 for m, desc2, i2 in fake_modes[1:]:
 assert fake_mode is m, (
 f"fake mode ({fake_mode}) from {desc1} {i1} doesn't match mode ({m}) from {desc2} {i2}\n\n"
+# pyrefly: ignore [missing-attribute]
 f"fake mode from {desc1} {i1} allocated at:\n{fake_mode.stack}\n"
+# pyrefly: ignore [missing-attribute]
 f"fake mode from {desc2} {i2} allocated at:\n{m.stack}"
 )
+# pyrefly: ignore [bad-return]
 return fake_mode
 else:
 return None

@@ -114,6 +114,7 @@ def _cancel_all_tasks(
 for task in to_cancel:
 task.cancel()

+# pyrefly: ignore [bad-argument-type]
 loop.run_until_complete(asyncio.gather(*to_cancel, return_exceptions=True))

 for task in to_cancel:

@@ -149,7 +150,7 @@ def _patch_loop(loop: AbstractEventLoop) -> OrderedSet[Future]: # type: ignore[
 task_factory = task_factories[0]
 if task_factory is None:
 if sys.version_info >= (3, 11):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 task = asyncio.Task(coro, loop=loop, context=context)
 else:
 task = asyncio.Task(coro, loop=loop)

@@ -590,11 +590,11 @@ class CKGemmTemplate(CKTemplate):
 arg = f"/* {field_name} */ Tuple<{tuple_elements}>"
 else: # tile shape
 arg = f"/* {field_name} */ S<{tuple_elements}>"
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 template_params.append(arg)
 else:
 if field_value is not None:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 template_params.append(f"/* {field_name} */ {field_value}")
 operation_name = op.name().replace("(", "").replace(",", "").replace(")", "")
 return self._template_from_string(template_definition).render(

@@ -939,6 +939,7 @@ class CKGemmTemplate(CKTemplate):
 for o in rops:
 kBatches = self._get_kBatch(o)
 for kBatch in kBatches:
+# pyrefly: ignore [bad-argument-type]
 ops.append(InductorROCmOp(op=o, kBatch=kBatch))

 filtered_instances = list(filter(lambda op: self.filter_op(op), ops))

@@ -273,7 +273,7 @@ def record_original_output_strides(gm: GraphModule) -> None:
 ):
 output_strides.append(val.stride())
 else:
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 output_strides.append(None)
 output_node.meta["original_output_strides"] = output_strides


@@ -1110,6 +1110,7 @@ def _compile_fx_inner(
 )
 log.info("-" * 130)
 for row in mm_table_data:
+# pyrefly: ignore [not-iterable]
 log.info("{:<30} | {:<20} | {:<20} | {:<20} | {:<20} | {:<20}".format(*row)) # noqa: G001
 log.info("-" * 130)


@@ -1551,7 +1552,7 @@ class _InProcessFxCompile(FxCompile):
 node_runtimes = None
 if inductor_metrics_log.isEnabledFor(logging.INFO):
 num_bytes, nodes_num_elem, node_runtimes = graph.count_bytes()
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 metrics.num_bytes_accessed += num_bytes
 metrics.node_runtimes += node_runtimes
 metrics.nodes_num_elem += nodes_num_elem

@@ -1595,10 +1596,10 @@ class _InProcessFxCompile(FxCompile):
 disable = f"{disable} Found from {stack_trace}\n"
 else:
 disable = f"{disable}\n"
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = disable

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if cudagraphs and not V.graph.disable_cudagraphs_reason:
 maybe_incompat_node = get_first_incompatible_cudagraph_node(gm)
 if maybe_incompat_node:

@@ -1607,29 +1608,29 @@ class _InProcessFxCompile(FxCompile):
 "stack_trace", None
 ):
 disable = f"{disable} Found from {stack_trace}\n"
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = disable

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if V.aot_compilation:
 assert isinstance(
 compiled_fn,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 (str, list, torch.fx.GraphModule),
 ), type(compiled_fn)
 return CompiledAOTI(compiled_fn)

 # TODO: Hoist this above V.aot_compilation
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 if cudagraphs and not V.graph.disable_cudagraphs_reason:
 from torch._inductor.cudagraph_utils import (
 check_lowering_disable_cudagraph,
 )

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason = (
 check_lowering_disable_cudagraph(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.device_node_mapping
 )
 )

@@ -1637,29 +1638,29 @@ class _InProcessFxCompile(FxCompile):
 self._compile_stats[type(self)].codegen_and_compile += 1

 if (
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._inductor.debug.RECORD_GRAPH_EXECUTION
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 and torch._inductor.debug.GRAPH_COMPILE_IDS is not None
 ):
 compile_id = str(
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._guards.CompileContext.current_compile_id()
 )
 graph_id = graph_kwargs.get("graph_id")
 if graph_id is not None:
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 torch._inductor.debug.GRAPH_COMPILE_IDS[graph_id] = (
 compile_id
 )

 return CompiledFxGraph(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 compiled_fn,
 graph,
 gm,
 output_strides,
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 V.graph.disable_cudagraphs_reason,
 metrics_helper.get_deltas(),
 counters["inductor"] - inductor_counters,

@@ -1701,18 +1702,18 @@ def fx_codegen_and_compile(
 from .compile_fx_async import _AsyncFxCompile
 from .compile_fx_ext import _OutOfProcessFxCompile

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 assert isinstance(scheme, _OutOfProcessFxCompile), (
 "async is only valid with an out-of-process compile mode"
 )
-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 scheme = _AsyncFxCompile(scheme)

 if fx_compile_progressive:
 from .compile_fx_async import _ProgressiveFxCompile
 from .compile_fx_ext import _OutOfProcessFxCompile

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 assert isinstance(scheme, _OutOfProcessFxCompile), (
 "progressive is only valid with an out-of-process compile mode"
 )

@@ -1722,10 +1723,10 @@ def fx_codegen_and_compile(
 # Use in-process compile for the fast version
 fast_scheme = _InProcessFxCompile()

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 scheme = _ProgressiveFxCompile(fast_scheme, scheme, progression_configs)

-# pyrefly: ignore # unbound-name
+# pyrefly: ignore [unbound-name]
 return scheme.codegen_and_compile(gm, example_inputs, inputs_to_check, graph_kwargs)


@@ -1835,7 +1836,7 @@ def cudagraphify_impl(
 Assumes inputs[static_input_idxs[i]] are always the same memory address
 """
 check_input_idxs = get_input_idxs_to_check(inputs, static_input_idxs) # type: ignore[arg-type]
-# pyrefly: ignore # annotation-mismatch
+# pyrefly: ignore [annotation-mismatch]
 static_input_idxs: OrderedSet[int] = OrderedSet(
 remove_unaligned_input_idxs(inputs, static_input_idxs) # type: ignore[arg-type]
 )

@@ -1902,7 +1903,7 @@ def cudagraphify_impl(
 index_expanded_dims_and_copy_(dst, src, expanded_dims)
 new_inputs.clear()
 graph.replay()
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return static_outputs

 else:

@@ -1918,7 +1919,7 @@ def cudagraphify_impl(
 index_expanded_dims_and_copy_(static_inputs[idx], src, expanded_dims)
 new_inputs.clear()
 graph.replay()
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return static_outputs

 return align_inputs_from_check_idxs(run, check_input_idxs, OrderedSet())

@@ -1935,7 +1936,7 @@ def compile_fx_aot(
 # [See NOTE] Unwrapping subclasses AOT
 unwrap_tensor_subclass_parameters(model_)

-# pyrefly: ignore # annotation-mismatch
+# pyrefly: ignore [annotation-mismatch]
 config_patches: dict[str, Any] = copy.deepcopy(config_patches or {})

 if not (config_patches.get("fx_wrapper", False) or config.fx_wrapper):

@@ -2878,7 +2879,7 @@ def _aoti_flatten_inputs(
 Flatten the inputs to the graph module and return the flat inputs and options.
 Add "aot_inductor.serialized_in_spec" and "aot_inductor.serialized_out_spec" to the options.
 """
-# pyrefly: ignore # missing-module-attribute
+# pyrefly: ignore [missing-module-attribute]
 from .compile_fx import graph_returns_tuple

 assert graph_returns_tuple(gm), (

@@ -291,7 +291,7 @@ def normalize_unbind_default(match: Match, *args, **kwargs):
 log.debug("example value absent for node: %s", input)
 return
 ndim = input.meta["example_value"].ndim
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if dim < 0: # Normalize unbind dim
 dim += ndim
 with graph.inserting_after(node):

@@ -341,7 +341,7 @@ def normalize_cat_default(match: Match, *args, **kwargs):
 ndim == x.meta["example_value"].dim() or is_empty_tensor(x) for x in tensors
 )

-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if cat_dim < 0: # Normalize cat dim
 cat_dim += ndim


@@ -949,7 +949,7 @@ class SplitCatSimplifier:
 if isinstance(user_input, tuple):
 # Find the correct new getitem (present in split_items)
 new_user_inputs.append(
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 split_items[
 split_ranges.index(
 (

@@ -1000,7 +1000,7 @@ class SplitCatSimplifier:
 for user_input_new, transform_param in zip(
 user_inputs_new, transform_params
 ):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 if not is_node_meta_valid(user_input_new):
 log.debug("example value absent for node: %s", user_input_new)
 return

@@ -1015,7 +1015,7 @@ class SplitCatSimplifier:
 stack_dim is None or stack_dim == unsqueeze_params[0]
 ):
 to_stack.append(user_input_new)
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 to_stack_meta.append(user_input_new.meta["example_value"])
 stack_dim = unsqueeze_params[0]
 continue

@@ -1036,12 +1036,12 @@ class SplitCatSimplifier:
 if unsqueeze_params:
 to_stack.append(user_input_new)
 stack_dim = unsqueeze_params[0]
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 to_stack_meta.append(user_input_new.meta["example_value"])
 continue

 if unflatten_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.unflatten, args=(user_input_new, *unflatten_params)

@@ -1051,7 +1051,7 @@ class SplitCatSimplifier:
 *unflatten_params, # type: ignore[arg-type]
 )
 if movedim_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.movedim, args=(user_input_new, *movedim_params)

@@ -1061,7 +1061,7 @@ class SplitCatSimplifier:
 *movedim_params, # type: ignore[arg-type]
 )
 if flatten_params:
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new_meta = user_input_new.meta["example_value"]
 user_input_new = graph.call_function(
 torch.flatten, args=(user_input_new, *flatten_params)

@@ -1072,7 +1072,7 @@ class SplitCatSimplifier:
 )
 user_inputs_new_transformed.append(user_input_new)
 user_inputs_new_transformed_meta.append(
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 user_input_new.meta["example_value"]
 )
 if to_stack:

@@ -1432,7 +1432,7 @@ def simplify_split_cat(match: Match, split_sections: list[int], dim: int):
 if not isinstance(split_sections, (list, tuple)): # Unnormalized split
 return
 split_node = next(node for node in match.nodes if node.target == torch.split)
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 SplitCatSimplifier().simplify(match.graph, split_node, split_sections)


@@ -1501,7 +1501,7 @@ def calculate_fused_tensor_size(split_node: torch.fx.Node, indices: list[int]) -
 for i in range(len(split_node.args[1])): # type: ignore[arg-type]
 if i in indices:
 fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
-# pyrefly: ignore # bad-return
+# pyrefly: ignore [bad-return]
 return fused_tensor_size


@@ -1978,7 +1978,7 @@ def normalize_cat_default_aten(match: Match, *args, **kwargs):

 assert all(ndim == x.meta["val"].dim() or is_empty_tensor(x) for x in tensors)

-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 if cat_dim < 0: # Normalize cat dim
 cat_dim += ndim


@@ -2512,7 +2512,8 @@ def reshape_cat_node_to_stack(
 args=(cat_node, tuple(reshape_list)),
 )
 reshape_node.meta["example_value"] = torch.reshape(
-cat_node.meta["example_value"], tuple(reshape_list)
+cat_node.meta["example_value"],
+tuple(reshape_list), # pyrefly: ignore [bad-argument-type]
 )
 permute_list = list(range(len(stack_shape)))
 permute_list[stack_dim], permute_list[split_or_unbind_dim] = (

@@ -3044,6 +3045,6 @@ def replace_einsum_to_pointwise(match: Match, *args, **kwargs):
 einsum_node = match.nodes[0]
 input, weights = get_arg_value(einsum_node, 1), get_arg_value(einsum_node, 2)
 if should_replace_einsum(einsum_node):
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 match.replace_by_example(repl, [input, weights])
 counters[backend]["einsum_to_pointwise_pass"] += 1

@@ -147,7 +147,7 @@ def _qualified_name(obj, mangle_name=True) -> str:

 # If the module is actually a torchbind module, then we should short circuit
 if module_name == "torch._classes":
-return obj.qualified_name # pyrefly: ignore # missing-attribute
+return obj.qualified_name # pyrefly: ignore [missing-attribute]

 # The Python docs are very clear that __module__ can be None, but I can't
 # figure out when it actually would be.

@@ -759,7 +759,7 @@ def unused(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 prop.fset, "_torchscript_modifier", FunctionModifiers.UNUSED
 )

-return prop # pyrefly: ignore # bad-return
+return prop # pyrefly: ignore [bad-return]

 fn._torchscript_modifier = FunctionModifiers.UNUSED # type: ignore[attr-defined]
 return fn

@@ -65,8 +65,8 @@ class AsyncClosureHandler(ClosureHandler):

 self._closure_event_loop = threading.Thread(
 target=event_loop
-) # pyrefly: ignore # bad-assignment
-self._closure_event_loop.start() # pyrefly: ignore # missing-attribute
+) # pyrefly: ignore [bad-assignment]
+self._closure_event_loop.start() # pyrefly: ignore [missing-attribute]

 def run(self, closure):
 with self._closure_lock:

@@ -515,8 +515,11 @@ def meta_copy_(self, src, non_blocking=False):
 def inferUnsqueezeGeometry(tensor, dim):
 result_sizes = list(tensor.size())
 result_strides = list(tensor.stride())
+# pyrefly: ignore [unsupported-operation]
 new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
+# pyrefly: ignore [bad-argument-type]
 result_sizes.insert(dim, 1)
+# pyrefly: ignore [bad-argument-type]
 result_strides.insert(dim, new_stride)
 return result_sizes, result_strides


@@ -2341,19 +2344,19 @@ def calc_conv_nd_return_shape(

 ret_shape = [input_tensor.shape[0], out_channels]
 if isinstance(stride, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 stride = [stride] * len(dims)
 elif len(stride) == 1:
 stride = [stride[0]] * len(dims)

 if isinstance(padding, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 padding = [padding] * len(dims)
 elif len(padding) == 1:
 padding = [padding[0]] * len(dims)

 if isinstance(dilation, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 dilation = [dilation] * len(dims)
 elif len(dilation) == 1:
 dilation = [dilation[0]] * len(dims)

@@ -2361,7 +2364,7 @@ def calc_conv_nd_return_shape(
 output_padding_list: Optional[list[int]] = None
 if output_padding:
 if isinstance(output_padding, IntLike):
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 output_padding_list = [output_padding] * len(dims)
 elif len(output_padding) == 1:
 output_padding_list = [output_padding[0]] * len(dims)

@@ -2374,19 +2377,19 @@ def calc_conv_nd_return_shape(
 ret_shape.append(
 _formula_transposed(
 dims[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 padding[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 dilation[i],
 kernel_size[i],
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 stride[i],
 output_padding_list[i],
 )
 )
 else:
 ret_shape.append(
-# pyrefly: ignore # index-error
+# pyrefly: ignore [index-error]
 _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
 )
 from torch.fx.experimental.symbolic_shapes import sym_or

@@ -3454,7 +3457,7 @@ def meta_index_Tensor(self, indices):
 """
 shape = before_shape + replacement_shape + after_shape
 strides = list(self.stride())
-# pyrefly: ignore # unsupported-operation
+# pyrefly: ignore [unsupported-operation]
 strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
 replacement_shape
 )

@@ -5311,6 +5314,7 @@ def full(size, fill_value, *args, **kwargs):
 if not dtype:
 dtype = utils.get_dtype(fill_value)
 kwargs["dtype"] = dtype
+# pyrefly: ignore [not-iterable]
 return torch.empty(size, *args, **kwargs)


@@ -6668,7 +6672,7 @@ def rnn_cell_checkSizes(
 )
 torch._check(
 all(
-# pyrefly: ignore # missing-attribute
+# pyrefly: ignore [missing-attribute]
 x.device == input_gates.device
 for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
 ),

@@ -880,7 +880,7 @@ class OpOverload(OperatorBase, Generic[_P, _T]):
 elif torch._C._dispatch_has_kernel_for_dispatch_key(self.name(), dk):
 return self._op_dk(dk, *args, **kwargs)
 else:
-return NotImplemented # pyrefly: ignore # bad-return
+return NotImplemented # pyrefly: ignore [bad-return]

 # Remove a dispatch key from the dispatch cache. This will force it to get
 # recomputed the next time. Does nothing

@@ -985,9 +985,9 @@ class OpOverload(OperatorBase, Generic[_P, _T]):

 r = self.py_kernels.get(final_key, final_key)
 if cache_result:
-self._dispatch_cache[key] = r # pyrefly: ignore # unsupported-operation
+self._dispatch_cache[key] = r # pyrefly: ignore [unsupported-operation]
 add_cached_op(self)
-return r # pyrefly: ignore # bad-return
+return r # pyrefly: ignore [bad-return]

 def name(self):
 return self._name

@@ -1117,7 +1117,7 @@ class TorchBindOpOverload(OpOverload[_P, _T]):
 )

 assert isinstance(handler, Callable) # type: ignore[arg-type]
-return handler(*args, **kwargs) # pyrefly: ignore # bad-return
+return handler(*args, **kwargs) # pyrefly: ignore [bad-return]


 def _must_dispatch_in_python(args, kwargs):

@@ -267,6 +267,7 @@ class FunctionalTensor(torch.Tensor):
 device=self.device,
 layout=self.layout,
 )
+# pyrefly: ignore [not-iterable]
 return super().to(*args, **kwargs)

 def cuda(self, device=None, *args, **kwargs):

@@ -551,6 +551,7 @@ class Tensor(torch._C.TensorBase):
 raise RuntimeError("__setstate__ can be only called on leaf Tensors")
 if len(state) == 4:
 # legacy serialization of Tensor
+# pyrefly: ignore [not-iterable]
 self.set_(*state)
 return
 elif len(state) == 5:

@@ -758,7 +759,7 @@ class Tensor(torch._C.TensorBase):
 )
 if self._post_accumulate_grad_hooks is None:
 self._post_accumulate_grad_hooks: dict[Any, Any] = (
-# pyrefly: ignore # bad-assignment
+# pyrefly: ignore [bad-assignment]
 OrderedDict()
 )


@@ -1062,7 +1063,7 @@ class Tensor(torch._C.TensorBase):
 else:
 return torch._VF.split_with_sizes(
 self,
-# pyrefly: ignore # bad-argument-type
+# pyrefly: ignore [bad-argument-type]
 split_size,
 dim,
 )

@@ -1119,7 +1120,7 @@ class Tensor(torch._C.TensorBase):
 __rtruediv__ = __rdiv__
 __itruediv__ = _C.TensorBase.__idiv__

-# pyrefly: ignore # bad-override
+# pyrefly: ignore [bad-override]
 __pow__ = cast(
 Callable[
 ["torch._C.TensorBase", Union["Tensor", int, float, bool, complex]],

@ -686,8 +686,8 @@ def _take_tensors(tensors, size_limit):
|
||||||
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
|
if buf_and_size[1] + size > size_limit and buf_and_size[1] > 0:
|
||||||
yield buf_and_size[0]
|
yield buf_and_size[0]
|
||||||
buf_and_size = buf_dict[t] = [[], 0]
|
buf_and_size = buf_dict[t] = [[], 0]
|
||||||
buf_and_size[0].append(tensor) # pyrefly: ignore # missing-attribute
|
buf_and_size[0].append(tensor) # pyrefly: ignore [missing-attribute]
|
||||||
buf_and_size[1] += size # pyrefly: ignore # unsupported-operation
|
buf_and_size[1] += size # pyrefly: ignore [unsupported-operation]
|
||||||
for buf, _ in buf_dict.values():
|
for buf, _ in buf_dict.values():
|
||||||
if len(buf) > 0:
|
if len(buf) > 0:
|
||||||
yield buf
|
yield buf
|
||||||
|
|
@ -744,6 +744,7 @@ class ExceptionWrapper:
|
||||||
if exc_info is None:
|
if exc_info is None:
|
||||||
exc_info = sys.exc_info()
|
exc_info = sys.exc_info()
|
||||||
self.exc_type = exc_info[0]
|
self.exc_type = exc_info[0]
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
self.exc_msg = "".join(traceback.format_exception(*exc_info))
|
self.exc_msg = "".join(traceback.format_exception(*exc_info))
|
||||||
self.where = where
|
self.where = where
|
||||||
|
|
||||||
|
|
@ -751,7 +752,7 @@ class ExceptionWrapper:
|
||||||
r"""Reraises the wrapped exception in the current thread"""
|
r"""Reraises the wrapped exception in the current thread"""
|
||||||
# Format a message such as: "Caught ValueError in DataLoader worker
|
# Format a message such as: "Caught ValueError in DataLoader worker
|
||||||
# process 2. Original Traceback:", followed by the traceback.
|
# process 2. Original Traceback:", followed by the traceback.
|
||||||
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore # missing-attribute
|
msg = f"Caught {self.exc_type.__name__} {self.where}.\nOriginal {self.exc_msg}" # pyrefly: ignore [missing-attribute]
|
||||||
if self.exc_type is KeyError:
|
if self.exc_type is KeyError:
|
||||||
# KeyError calls repr() on its argument (usually a dict key). This
|
# KeyError calls repr() on its argument (usually a dict key). This
|
||||||
# makes stack traces unreadable. It will not be changed in Python
|
# makes stack traces unreadable. It will not be changed in Python
|
||||||
|
|
@ -760,13 +761,13 @@ class ExceptionWrapper:
|
||||||
elif getattr(self.exc_type, "message", None):
|
elif getattr(self.exc_type, "message", None):
|
||||||
# Some exceptions have first argument as non-str but explicitly
|
# Some exceptions have first argument as non-str but explicitly
|
||||||
# have message field
|
# have message field
|
||||||
# pyrefly: ignore # not-callable
|
# pyrefly: ignore [not-callable]
|
||||||
raise self.exc_type(
|
raise self.exc_type(
|
||||||
# pyrefly: ignore # unexpected-keyword
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
message=msg
|
message=msg
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
exception = self.exc_type(msg) # pyrefly: ignore # not-callable
|
exception = self.exc_type(msg) # pyrefly: ignore [not-callable]
|
||||||
except Exception:
|
except Exception:
|
||||||
# If the exception takes multiple arguments or otherwise can't
|
# If the exception takes multiple arguments or otherwise can't
|
||||||
# be constructed, don't try to instantiate since we don't know how to
|
# be constructed, don't try to instantiate since we don't know how to
|
||||||
|
|
@ -1018,12 +1019,12 @@ class _LazySeedTracker:
|
||||||
self.call_order = []
|
self.call_order = []
|
||||||
|
|
||||||
def queue_seed_all(self, cb, traceback):
|
def queue_seed_all(self, cb, traceback):
|
||||||
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
|
self.manual_seed_all_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
|
||||||
# update seed_all to be latest
|
# update seed_all to be latest
|
||||||
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
self.call_order = [self.manual_seed_cb, self.manual_seed_all_cb]
|
||||||
|
|
||||||
def queue_seed(self, cb, traceback):
|
def queue_seed(self, cb, traceback):
|
||||||
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore # bad-assignment
|
self.manual_seed_cb = (cb, traceback) # pyrefly: ignore [bad-assignment]
|
||||||
# update seed to be latest
|
# update seed to be latest
|
||||||
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
self.call_order = [self.manual_seed_all_cb, self.manual_seed_cb]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
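Every hunk in this change applies the same mechanical conversion, so one note on the two spellings may help: in the old form the second `#` starts an ordinary comment, so `# pyrefly: ignore # bad-assignment` effectively acts as a bare, unscoped ignore, while the bracketed form limits the suppression to the named error class. A minimal sketch of the difference, assuming pyrefly's suppression-comment behavior; the assignments are hypothetical, not code from the tree:

# unscoped: the text after the second "#" is plain comment, so any
# pyrefly error on this line is silenced
x: int = "oops"  # pyrefly: ignore # bad-assignment

# scoped: only the bad-assignment class is silenced; other error
# classes on the line would still be reported
y: int = "oops"  # pyrefly: ignore [bad-assignment]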
@ -419,6 +419,7 @@ class Unpickler:
|
||||||
inst = self.stack[-1]
|
inst = self.stack[-1]
|
||||||
if type(inst) is torch.Tensor:
|
if type(inst) is torch.Tensor:
|
||||||
# Legacy unpickling
|
# Legacy unpickling
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
inst.set_(*state)
|
inst.set_(*state)
|
||||||
elif type(inst) is torch.nn.Parameter:
|
elif type(inst) is torch.nn.Parameter:
|
||||||
inst.__setstate__(state)
|
inst.__setstate__(state)
|
||||||
|
|
|
||||||
|
|
@ -104,6 +104,7 @@ def memory_stats(device_index: _device_t = None, /) -> OrderedDict[str, Any]:
|
||||||
|
|
||||||
flatten("", stats)
|
flatten("", stats)
|
||||||
flat_stats.sort()
|
flat_stats.sort()
|
||||||
|
# pyrefly: ignore [no-matching-overload]
|
||||||
return OrderedDict(flat_stats)
|
return OrderedDict(flat_stats)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -525,7 +525,7 @@ def custom_fwd(
|
||||||
args[0]._dtype = torch.get_autocast_dtype(device_type)
|
args[0]._dtype = torch.get_autocast_dtype(device_type)
|
||||||
if cast_inputs is None:
|
if cast_inputs is None:
|
||||||
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
|
args[0]._fwd_used_autocast = torch.is_autocast_enabled(device_type)
|
||||||
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
|
||||||
else:
|
else:
|
||||||
autocast_context = torch.is_autocast_enabled(device_type)
|
autocast_context = torch.is_autocast_enabled(device_type)
|
||||||
args[0]._fwd_used_autocast = False
|
args[0]._fwd_used_autocast = False
|
||||||
|
|
@ -536,7 +536,7 @@ def custom_fwd(
|
||||||
**_cast(kwargs, device_type, cast_inputs),
|
**_cast(kwargs, device_type, cast_inputs),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return fwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
return fwd(*args, **kwargs) # pyrefly: ignore [not-callable]
|
||||||
|
|
||||||
return decorate_fwd
|
return decorate_fwd
|
||||||
|
|
||||||
|
|
@ -571,6 +571,6 @@ def custom_bwd(bwd=None, *, device_type: str):
|
||||||
enabled=args[0]._fwd_used_autocast,
|
enabled=args[0]._fwd_used_autocast,
|
||||||
dtype=args[0]._dtype,
|
dtype=args[0]._dtype,
|
||||||
):
|
):
|
||||||
return bwd(*args, **kwargs) # pyrefly: ignore # not-callable
|
return bwd(*args, **kwargs) # pyrefly: ignore [not-callable]
|
||||||
|
|
||||||
return decorate_bwd
|
return decorate_bwd
|
||||||
|
|
|
||||||
|
|
@ -84,7 +84,7 @@ class _NSGraphMatchableSubgraphsIterator:
|
||||||
if is_match:
|
if is_match:
|
||||||
# navigate to the base node
|
# navigate to the base node
|
||||||
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
for rev_fusion_idx in range(len(_reverse_fusion_ops) - 1):
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.seen_nodes.add(cur_start_node)
|
self.seen_nodes.add(cur_start_node)
|
||||||
# for now, assume that there are no other nodes
|
# for now, assume that there are no other nodes
|
||||||
# which need to be added to the stack
|
# which need to be added to the stack
|
||||||
|
|
@ -95,10 +95,10 @@ class _NSGraphMatchableSubgraphsIterator:
|
||||||
cur_base_op_node = cur_start_node
|
cur_base_op_node = cur_start_node
|
||||||
break
|
break
|
||||||
|
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.seen_nodes.add(cur_start_node)
|
self.seen_nodes.add(cur_start_node)
|
||||||
# add args of previous nodes to stack
|
# add args of previous nodes to stack
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
for arg in cur_start_node.all_input_nodes:
|
for arg in cur_start_node.all_input_nodes:
|
||||||
self._recursively_add_node_arg_to_stack(arg)
|
self._recursively_add_node_arg_to_stack(arg)
|
||||||
|
|
||||||
|
|
@ -106,7 +106,7 @@ class _NSGraphMatchableSubgraphsIterator:
|
||||||
# note: this check is done on the start_node, i.e.
|
# note: this check is done on the start_node, i.e.
|
||||||
# if we are matching linear-relu in reverse, this would do the matchable
|
# if we are matching linear-relu in reverse, this would do the matchable
|
||||||
# check on the linear
|
# check on the linear
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if not self._is_matchable(cur_base_op_node):
|
if not self._is_matchable(cur_base_op_node):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|
@ -120,10 +120,10 @@ class _NSGraphMatchableSubgraphsIterator:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
return NSSubgraph(
|
return NSSubgraph(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
start_node=cur_start_node,
|
start_node=cur_start_node,
|
||||||
end_node=cur_end_node,
|
end_node=cur_end_node,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
base_op_node=cur_base_op_node,
|
base_op_node=cur_base_op_node,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -481,4 +481,5 @@ of subgraphs."""
|
||||||
# subgraphs in their order of execution.
|
# subgraphs in their order of execution.
|
||||||
results = collections.OrderedDict(reversed(results.items()))
|
results = collections.OrderedDict(reversed(results.items()))
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-return]
|
||||||
return results
|
return results
|
||||||
|
|
|
||||||
|
|
@ -30,6 +30,7 @@ class EventList(list):
|
||||||
use_device = kwargs.pop("use_device", None)
|
use_device = kwargs.pop("use_device", None)
|
||||||
profile_memory = kwargs.pop("profile_memory", False)
|
profile_memory = kwargs.pop("profile_memory", False)
|
||||||
with_flops = kwargs.pop("with_flops", False)
|
with_flops = kwargs.pop("with_flops", False)
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
super().__init__(*args, **kwargs)
|
super().__init__(*args, **kwargs)
|
||||||
self._use_device = use_device
|
self._use_device = use_device
|
||||||
self._profile_memory = profile_memory
|
self._profile_memory = profile_memory
|
||||||
|
|
@ -505,9 +506,9 @@ class FunctionEvent(FormattedTimesMixin):
|
||||||
self.id: int = id
|
self.id: int = id
|
||||||
self.node_id: int = node_id
|
self.node_id: int = node_id
|
||||||
self.name: str = name
|
self.name: str = name
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.overload_name: str = overload_name
|
self.overload_name: str = overload_name
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.trace_name: str = trace_name
|
self.trace_name: str = trace_name
|
||||||
self.time_range: Interval = Interval(start_us, end_us)
|
self.time_range: Interval = Interval(start_us, end_us)
|
||||||
self.thread: int = thread
|
self.thread: int = thread
|
||||||
|
|
@ -516,13 +517,13 @@ class FunctionEvent(FormattedTimesMixin):
|
||||||
self.count: int = 1
|
self.count: int = 1
|
||||||
self.cpu_children: list[FunctionEvent] = []
|
self.cpu_children: list[FunctionEvent] = []
|
||||||
self.cpu_parent: Optional[FunctionEvent] = None
|
self.cpu_parent: Optional[FunctionEvent] = None
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.input_shapes: tuple[int, ...] = input_shapes
|
self.input_shapes: tuple[int, ...] = input_shapes
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.concrete_inputs: list[Any] = concrete_inputs
|
self.concrete_inputs: list[Any] = concrete_inputs
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.kwinputs: dict[str, Any] = kwinputs
|
self.kwinputs: dict[str, Any] = kwinputs
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.stack: list = stack
|
self.stack: list = stack
|
||||||
self.scope: int = scope
|
self.scope: int = scope
|
||||||
self.use_device: Optional[str] = use_device
|
self.use_device: Optional[str] = use_device
|
||||||
|
|
@ -766,7 +767,7 @@ class FunctionEventAvg(FormattedTimesMixin):
|
||||||
self.self_device_memory_usage += other.self_device_memory_usage
|
self.self_device_memory_usage += other.self_device_memory_usage
|
||||||
self.count += other.count
|
self.count += other.count
|
||||||
if self.flops is None:
|
if self.flops is None:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.flops = other.flops
|
self.flops = other.flops
|
||||||
elif other.flops is not None:
|
elif other.flops is not None:
|
||||||
self.flops += other.flops
|
self.flops += other.flops
|
||||||
|
|
@ -1003,7 +1004,7 @@ def _build_table(
|
||||||
]
|
]
|
||||||
if flops <= 0:
|
if flops <= 0:
|
||||||
raise AssertionError(f"Expected flops to be positive, but got {flops}")
|
raise AssertionError(f"Expected flops to be positive, but got {flops}")
|
||||||
# pyrefly: ignore # no-matching-overload
|
# pyrefly: ignore [no-matching-overload]
|
||||||
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
|
log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
|
||||||
if not (log_flops >= 0 and log_flops < len(flop_headers)):
|
if not (log_flops >= 0 and log_flops < len(flop_headers)):
|
||||||
raise AssertionError(
|
raise AssertionError(
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ def compile(*args, **kwargs):
|
||||||
"""
|
"""
|
||||||
See :func:`torch.compile` for details on the arguments for this function.
|
See :func:`torch.compile` for details on the arguments for this function.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
return torch.compile(*args, **kwargs)
|
return torch.compile(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -198,6 +198,7 @@ def _for_each_rank_run_func(
|
||||||
rr_val = flat_rank_rets[rr_key]
|
rr_val = flat_rank_rets[rr_key]
|
||||||
|
|
||||||
if isinstance(rr_val, Tensor):
|
if isinstance(rr_val, Tensor):
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count]
|
||||||
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
|
ret = LocalTensor({r: flat_rank_rets[r] for r in sorted(ranks)})
|
||||||
elif isinstance(rr_val, (list, tuple)):
|
elif isinstance(rr_val, (list, tuple)):
|
||||||
ret_list = []
|
ret_list = []
|
||||||
|
|
@ -206,6 +207,7 @@ def _for_each_rank_run_func(
|
||||||
v_it = iter(rets.values())
|
v_it = iter(rets.values())
|
||||||
v = next(v_it)
|
v = next(v_it)
|
||||||
if isinstance(v, Tensor):
|
if isinstance(v, Tensor):
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count]
|
||||||
ret_list.append(LocalTensor(rets))
|
ret_list.append(LocalTensor(rets))
|
||||||
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
|
elif isinstance(v, int) and not all(v == v2 for v2 in v_it):
|
||||||
ret_list.append(torch.SymInt(LocalIntNode(rets)))
|
ret_list.append(torch.SymInt(LocalIntNode(rets)))
|
||||||
|
|
@ -468,7 +470,7 @@ class LocalTensor(torch.Tensor):
|
||||||
def __repr__(self) -> str: # type: ignore[override]
|
def __repr__(self) -> str: # type: ignore[override]
|
||||||
parts = []
|
parts = []
|
||||||
for k, v in self._local_tensors.items():
|
for k, v in self._local_tensors.items():
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
parts.append(f" {k}: {v}")
|
parts.append(f" {k}: {v}")
|
||||||
tensors_str = ",\n".join(parts)
|
tensors_str = ",\n".join(parts)
|
||||||
return f"LocalTensor(\n{tensors_str}\n)"
|
return f"LocalTensor(\n{tensors_str}\n)"
|
||||||
|
|
@ -491,6 +493,7 @@ class LocalTensor(torch.Tensor):
|
||||||
"Expecting spec to be not None from `__tensor_flatten__` return value!"
|
"Expecting spec to be not None from `__tensor_flatten__` return value!"
|
||||||
)
|
)
|
||||||
local_tensors = inner_tensors["_local_tensors"]
|
local_tensors = inner_tensors["_local_tensors"]
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count]
|
||||||
return LocalTensor(local_tensors)
|
return LocalTensor(local_tensors)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|
@ -751,6 +754,7 @@ class LocalTensorMode(TorchDispatchMode):
|
||||||
"""
|
"""
|
||||||
|
|
||||||
with self.disable():
|
with self.disable():
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count]
|
||||||
return LocalTensor({r: cb(r) for r in self.ranks})
|
return LocalTensor({r: cb(r) for r in self.ranks})
|
||||||
|
|
||||||
def _patch_device_mesh(self) -> None:
|
def _patch_device_mesh(self) -> None:
|
||||||
|
|
@ -761,7 +765,7 @@ class LocalTensorMode(TorchDispatchMode):
|
||||||
def _unpatch_device_mesh(self) -> None:
|
def _unpatch_device_mesh(self) -> None:
|
||||||
assert self._old_get_coordinate is not None
|
assert self._old_get_coordinate is not None
|
||||||
DeviceMesh.get_coordinate = self._old_get_coordinate
|
DeviceMesh.get_coordinate = self._old_get_coordinate
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
self._old_get_coordinate = None
|
self._old_get_coordinate = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,7 @@ def _flatten_tensor_size(size) -> torch.Size:
|
||||||
Checks if tensor size is valid, then flatten/return a torch.Size object.
|
Checks if tensor size is valid, then flatten/return a torch.Size object.
|
||||||
"""
|
"""
|
||||||
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
|
if len(size) == 1 and isinstance(size[0], collections.abc.Sequence):
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
dims = list(*size)
|
dims = list(*size)
|
||||||
else:
|
else:
|
||||||
dims = list(size)
|
dims = list(size)
|
||||||
|
|
@ -208,7 +209,7 @@ def build_global_metadata(
|
||||||
global_sharded_tensor_metadata = None
|
global_sharded_tensor_metadata = None
|
||||||
global_metadata_rank = 0
|
global_metadata_rank = 0
|
||||||
|
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
for rank, rank_metadata in enumerate(gathered_metadatas):
|
for rank, rank_metadata in enumerate(gathered_metadatas):
|
||||||
if rank_metadata is None:
|
if rank_metadata is None:
|
||||||
continue
|
continue
|
||||||
|
|
|
||||||
|
|
@ -227,7 +227,7 @@ class PGTransport:
|
||||||
self._work: list[Work] = []
|
self._work: list[Work] = []
|
||||||
self._pg = pg
|
self._pg = pg
|
||||||
self._timeout = timeout
|
self._timeout = timeout
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self._device = device
|
self._device = device
|
||||||
self._state_dict = state_dict
|
self._state_dict = state_dict
|
||||||
|
|
||||||
|
|
@ -345,6 +345,7 @@ class PGTransport:
|
||||||
values.append(recv(path, v))
|
values.append(recv(path, v))
|
||||||
elif isinstance(v, _DTensorMeta):
|
elif isinstance(v, _DTensorMeta):
|
||||||
tensor = recv(path, v.local)
|
tensor = recv(path, v.local)
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
|
||||||
values.append(DTensor(tensor, v.spec, requires_grad=False))
|
values.append(DTensor(tensor, v.spec, requires_grad=False))
|
||||||
elif isinstance(v, _ShardedTensorMeta):
|
elif isinstance(v, _ShardedTensorMeta):
|
||||||
# Receive all local shards that were sent to us
|
# Receive all local shards that were sent to us
|
||||||
|
|
|
||||||
|
|
@ -565,7 +565,7 @@ class FlatParamHandle:
|
||||||
# Only align addresses for `use_orig_params=True` (for now)
|
# Only align addresses for `use_orig_params=True` (for now)
|
||||||
align_addresses = use_orig_params
|
align_addresses = use_orig_params
|
||||||
self._init_get_unflat_views_fn(align_addresses)
|
self._init_get_unflat_views_fn(align_addresses)
|
||||||
# pyrefly: ignore # read-only
|
# pyrefly: ignore [read-only]
|
||||||
self.device = device
|
self.device = device
|
||||||
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
|
self._device_handle = _FSDPDeviceHandle.from_device(self.device)
|
||||||
self.process_group = process_group
|
self.process_group = process_group
|
||||||
|
|
@ -2495,6 +2495,7 @@ class FlatParamHandle:
|
||||||
###########
|
###########
|
||||||
def flat_param_to(self, *args, **kwargs):
|
def flat_param_to(self, *args, **kwargs):
|
||||||
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
|
"""Wrap an in-place call to ``.to()`` for ``self.flat_param``."""
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
self.flat_param.data = self.flat_param.to(*args, **kwargs)
|
self.flat_param.data = self.flat_param.to(*args, **kwargs)
|
||||||
if self._use_orig_params:
|
if self._use_orig_params:
|
||||||
# Refresh the views because their storage may have changed
|
# Refresh the views because their storage may have changed
|
||||||
|
|
|
||||||
|
|
@ -139,11 +139,14 @@ def _from_local_no_grad(
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if not compiled_autograd_enabled():
|
if not compiled_autograd_enabled():
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
# Use the local tensor directly instead of constructing a new tensor
|
# Use the local tensor directly instead of constructing a new tensor
|
||||||
# variable, e.g. with `view_as()`, since this is not differentiable
|
# variable, e.g. with `view_as()`, since this is not differentiable
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
local_tensor,
|
local_tensor,
|
||||||
sharding_spec,
|
sharding_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=local_tensor.requires_grad,
|
requires_grad=local_tensor.requires_grad,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -107,9 +107,12 @@ class _ToTorchTensor(torch.autograd.Function):
|
||||||
)
|
)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
DTensor(
|
DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
grad_output,
|
grad_output,
|
||||||
grad_spec,
|
grad_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=grad_output.requires_grad,
|
requires_grad=grad_output.requires_grad,
|
||||||
),
|
),
|
||||||
None,
|
None,
|
||||||
|
|
@ -175,11 +178,14 @@ class _FromTorchTensor(torch.autograd.Function):
|
||||||
)
|
)
|
||||||
|
|
||||||
# We want a fresh Tensor object that shares memory with the input tensor
|
# We want a fresh Tensor object that shares memory with the input tensor
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
dist_tensor = DTensor(
|
dist_tensor = DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
input.view_as(input),
|
input.view_as(input),
|
||||||
dist_spec,
|
dist_spec,
|
||||||
# requires_grad of the dist tensor depends on if input
|
# requires_grad of the dist tensor depends on if input
|
||||||
# requires_grad or not
|
# requires_grad or not
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=input.requires_grad,
|
requires_grad=input.requires_grad,
|
||||||
)
|
)
|
||||||
return dist_tensor
|
return dist_tensor
|
||||||
|
|
@ -304,9 +310,12 @@ class DTensor(torch.Tensor):
|
||||||
spec.placements,
|
spec.placements,
|
||||||
tensor_meta=unflatten_tensor_meta,
|
tensor_meta=unflatten_tensor_meta,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
local_tensor,
|
local_tensor,
|
||||||
unflatten_spec,
|
unflatten_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=requires_grad,
|
requires_grad=requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -820,9 +829,12 @@ def distribute_tensor(
|
||||||
dtype=tensor.dtype,
|
dtype=tensor.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
local_tensor.requires_grad_(tensor.requires_grad),
|
local_tensor.requires_grad_(tensor.requires_grad),
|
||||||
spec,
|
spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=tensor.requires_grad,
|
requires_grad=tensor.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1077,9 +1089,12 @@ def _dtensor_init_helper( # type: ignore[no-untyped-def]
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
local_tensor,
|
local_tensor,
|
||||||
spec,
|
spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=kwargs["requires_grad"],
|
requires_grad=kwargs["requires_grad"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -78,8 +78,11 @@ def found_inf_reduce_handler(
|
||||||
dtype=target_tensor.dtype,
|
dtype=target_tensor.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
found_inf_dtensor = dtensor.DTensor(
|
found_inf_dtensor = dtensor.DTensor(
|
||||||
local_tensor=target_tensor, spec=spec, requires_grad=False
|
local_tensor=target_tensor, # pyrefly: ignore [unexpected-keyword]
|
||||||
|
spec=spec, # pyrefly: ignore [unexpected-keyword]
|
||||||
|
requires_grad=False, # pyrefly: ignore [unexpected-keyword]
|
||||||
)
|
)
|
||||||
found_inf = found_inf_dtensor.full_tensor()
|
found_inf = found_inf_dtensor.full_tensor()
|
||||||
target_tensor.copy_(found_inf)
|
target_tensor.copy_(found_inf)
|
||||||
|
|
@ -189,7 +192,7 @@ class OpDispatcher:
|
||||||
local_tensor_args = (
|
local_tensor_args = (
|
||||||
pytree.tree_unflatten(
|
pytree.tree_unflatten(
|
||||||
cast(list[object], op_info.local_args),
|
cast(list[object], op_info.local_args),
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
op_info.args_tree_spec,
|
op_info.args_tree_spec,
|
||||||
)
|
)
|
||||||
if op_info.args_tree_spec
|
if op_info.args_tree_spec
|
||||||
|
|
@ -366,7 +369,7 @@ class OpDispatcher:
|
||||||
resharded_local_tensor = redistribute_local_tensor(
|
resharded_local_tensor = redistribute_local_tensor(
|
||||||
local_tensor,
|
local_tensor,
|
||||||
arg_spec,
|
arg_spec,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
reshard_arg_spec,
|
reshard_arg_spec,
|
||||||
)
|
)
|
||||||
new_local_args.append(resharded_local_tensor)
|
new_local_args.append(resharded_local_tensor)
|
||||||
|
|
@ -439,7 +442,7 @@ class OpDispatcher:
|
||||||
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
|
kwargs_schema[k] = self._try_replicate_spec_for_scalar_tensor(
|
||||||
op_call,
|
op_call,
|
||||||
v,
|
v,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
compute_mesh,
|
compute_mesh,
|
||||||
)
|
)
|
||||||
local_kwargs[k] = v
|
local_kwargs[k] = v
|
||||||
|
|
@ -456,7 +459,7 @@ class OpDispatcher:
|
||||||
OpSchema(
|
OpSchema(
|
||||||
op_call,
|
op_call,
|
||||||
(
|
(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
pytree.tree_unflatten(args_schema, args_spec)
|
pytree.tree_unflatten(args_schema, args_spec)
|
||||||
if args_spec
|
if args_spec
|
||||||
else tuple(args_schema)
|
else tuple(args_schema)
|
||||||
|
|
@ -478,6 +481,7 @@ class OpDispatcher:
|
||||||
assert isinstance(spec, DTensorSpec), (
|
assert isinstance(spec, DTensorSpec), (
|
||||||
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
|
f"output spec does not match with output! Expected DTensorSpec, got {spec}."
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-argument-type, bad-argument-count, unexpected-keyword]
|
||||||
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
|
return dtensor.DTensor(res, spec, requires_grad=res.requires_grad)
|
||||||
else:
|
else:
|
||||||
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
|
# if output does not have a DTensorSpec due to specific ops, it must be a scalar tensor
|
||||||
|
|
|
||||||
|
|
@ -883,9 +883,12 @@ class Redistribute(torch.autograd.Function):
|
||||||
output = local_tensor
|
output = local_tensor
|
||||||
target_spec = current_spec
|
target_spec = current_spec
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return dtensor.DTensor(
|
return dtensor.DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
output,
|
output,
|
||||||
target_spec,
|
target_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=input.requires_grad,
|
requires_grad=input.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -944,9 +947,12 @@ class Redistribute(torch.autograd.Function):
|
||||||
dtype=output.dtype,
|
dtype=output.dtype,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
output_dtensor = dtensor.DTensor(
|
output_dtensor = dtensor.DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
output,
|
output,
|
||||||
spec,
|
spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=grad_output.requires_grad,
|
requires_grad=grad_output.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -174,9 +174,12 @@ def _log_softmax_handler(
|
||||||
tensor_meta=output_tensor_meta,
|
tensor_meta=output_tensor_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
res,
|
res,
|
||||||
res_spec,
|
res_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=res.requires_grad,
|
requires_grad=res.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -251,7 +254,7 @@ def _nll_loss_forward(
|
||||||
if weight is not None:
|
if weight is not None:
|
||||||
new_shape = list(x.shape)
|
new_shape = list(x.shape)
|
||||||
new_shape[channel_dim] = -1
|
new_shape[channel_dim] = -1
|
||||||
# pyrefly: ignore # unbound-name
|
# pyrefly: ignore [unbound-name]
|
||||||
w = w.expand(new_shape)
|
w = w.expand(new_shape)
|
||||||
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
|
wsum = torch.gather(w, channel_dim, safe_target_).squeeze(channel_dim)
|
||||||
wsum = torch.where(target != ignore_index, wsum, 0)
|
wsum = torch.where(target != ignore_index, wsum, 0)
|
||||||
|
|
@ -309,9 +312,9 @@ def _nll_loss_forward_handler(
|
||||||
output_placements = all_replicate_placements
|
output_placements = all_replicate_placements
|
||||||
|
|
||||||
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
args = list(args)
|
args = list(args)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
args[1], args[2] = target, weight
|
args[1], args[2] = target, weight
|
||||||
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
||||||
|
|
||||||
|
|
@ -330,9 +333,12 @@ def _nll_loss_forward_handler(
|
||||||
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)
|
out_spec = DTensorSpec(spec.mesh, output_placements, tensor_meta=output_tensor_meta)
|
||||||
|
|
||||||
return (
|
return (
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
DTensor(
|
DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
result,
|
result,
|
||||||
out_spec,
|
out_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=result.requires_grad,
|
requires_grad=result.requires_grad,
|
||||||
),
|
),
|
||||||
total_weight,
|
total_weight,
|
||||||
|
|
@ -442,11 +448,11 @@ def _nll_loss_backward_handler(
|
||||||
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
|
weight = _cast_to_dtensor(weight, all_replicate_placements, spec.mesh)
|
||||||
|
|
||||||
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
# tensor inputs to _propagate_tensor_meta need to be DTensors
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
args = list(args)
|
args = list(args)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
args[2], args[3] = target, weight
|
args[2], args[3] = target, weight
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
|
args[6] = _cast_to_dtensor(total_weight, all_replicate_placements, spec.mesh)
|
||||||
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
output_tensor_meta = _propagate_tensor_meta(op_call, tuple(args), kwargs)
|
||||||
|
|
||||||
|
|
@ -470,9 +476,12 @@ def _nll_loss_backward_handler(
|
||||||
tensor_meta=output_tensor_meta,
|
tensor_meta=output_tensor_meta,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# pyrefly: ignore [bad-argument-type]
|
||||||
return DTensor(
|
return DTensor(
|
||||||
|
# pyrefly: ignore [bad-argument-count]
|
||||||
result,
|
result,
|
||||||
out_spec,
|
out_spec,
|
||||||
|
# pyrefly: ignore [unexpected-keyword]
|
||||||
requires_grad=result.requires_grad,
|
requires_grad=result.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -949,7 +949,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
|
||||||
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
|
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
|
||||||
]
|
]
|
||||||
target = node.target if node.op in ("call_function", "get_attr") else ""
|
target = node.target if node.op in ("call_function", "get_attr") else ""
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
|
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
|
||||||
nodes_idx[id(node)] = i
|
nodes_idx[id(node)] = i
|
||||||
return "\n".join(ret)
|
return "\n".join(ret)
|
||||||
|
|
@ -1206,6 +1206,7 @@ class _ModuleFrame:
|
||||||
for k in kwargs_spec.context
|
for k in kwargs_spec.context
|
||||||
}
|
}
|
||||||
assert self.parent_call_module is not None
|
assert self.parent_call_module is not None
|
||||||
|
# pyrefly: ignore [bad-assignment]
|
||||||
self.parent_call_module.args = tuple(arg_nodes)
|
self.parent_call_module.args = tuple(arg_nodes)
|
||||||
self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]
|
self.parent_call_module.kwargs = kwarg_nodes # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
@ -1393,6 +1394,7 @@ class _ModuleFrame:
|
||||||
|
|
||||||
def print(self, *args, **kwargs):
|
def print(self, *args, **kwargs):
|
||||||
if self.verbose:
|
if self.verbose:
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
print(*args, **kwargs)
|
print(*args, **kwargs)
|
||||||
|
|
||||||
def run_from(self, node_idx):
|
def run_from(self, node_idx):
|
||||||
|
|
@ -1486,7 +1488,7 @@ class _ModuleFrame:
|
||||||
self.seen_attrs[self.child_fqn].add(node.target)
|
self.seen_attrs[self.child_fqn].add(node.target)
|
||||||
|
|
||||||
self.copy_node(node)
|
self.copy_node(node)
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
node_idx += 1
|
node_idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1952,7 +1952,7 @@ def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(shape, (int, torch.SymInt)):
|
if isinstance(shape, (int, torch.SymInt)):
|
||||||
shape = torch.Size([shape]) # pyrefly: ignore # bad-argument-type
|
shape = torch.Size([shape]) # pyrefly: ignore [bad-argument-type]
|
||||||
else:
|
else:
|
||||||
for dim in shape:
|
for dim in shape:
|
||||||
torch._check_type(
|
torch._check_type(
|
||||||
|
|
|
||||||
|
|
@ -87,6 +87,7 @@ def _reify_object_slots(o, s):
|
||||||
@dispatch(slice, dict)
|
@dispatch(slice, dict)
|
||||||
def _reify(o, s):
|
def _reify(o, s):
|
||||||
"""Reify a Python ``slice`` object"""
|
"""Reify a Python ``slice`` object"""
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
return slice(*reify((o.start, o.stop, o.step), s))
|
return slice(*reify((o.start, o.stop, o.step), s))
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -59,7 +59,7 @@ Argument = Optional[
|
||||||
BaseArgumentTypes,
|
BaseArgumentTypes,
|
||||||
]
|
]
|
||||||
]
|
]
|
||||||
# pyrefly: ignore # invalid-annotation
|
# pyrefly: ignore [invalid-annotation]
|
||||||
ArgumentT = TypeVar("ArgumentT", bound=Argument)
|
ArgumentT = TypeVar("ArgumentT", bound=Argument)
|
||||||
_P = ParamSpec("_P")
|
_P = ParamSpec("_P")
|
||||||
_R = TypeVar("_R")
|
_R = TypeVar("_R")
|
||||||
|
|
@ -385,7 +385,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put before this node. Must be a member of the same graph.
|
x (Node): The node to put before this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._prepend(x)
|
self._prepend(x)
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=True)
|
@compatibility(is_backward_compatible=True)
|
||||||
|
|
@ -397,7 +397,7 @@ class Node(_NodeBase):
|
||||||
Args:
|
Args:
|
||||||
x (Node): The node to put after this node. Must be a member of the same graph.
|
x (Node): The node to put after this node. Must be a member of the same graph.
|
||||||
"""
|
"""
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
self._next._prepend(x)
|
self._next._prepend(x)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -698,7 +698,8 @@ class Node(_NodeBase):
|
||||||
if replace_hooks:
|
if replace_hooks:
|
||||||
for replace_hook in replace_hooks:
|
for replace_hook in replace_hooks:
|
||||||
replace_hook(old=self, new=replace_with.name, user=use_node)
|
replace_hook(old=self, new=replace_with.name, user=use_node)
|
||||||
use_node._replace_input_with(self, replace_with)
|
# pyrefly: ignore [missing-attribute]
|
||||||
|
use_node._replace_input_with(self, replace_with) # type: ignore[attr-defined]
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@compatibility(is_backward_compatible=False)
|
@compatibility(is_backward_compatible=False)
|
||||||
|
|
@ -835,7 +836,8 @@ class Node(_NodeBase):
|
||||||
for replace_hook in m._replace_hooks:
|
for replace_hook in m._replace_hooks:
|
||||||
replace_hook(old=old_input, new=new_input.name, user=self)
|
replace_hook(old=old_input, new=new_input.name, user=self)
|
||||||
|
|
||||||
self._replace_input_with(old_input, new_input)
|
# pyrefly: ignore [missing-attribute]
|
||||||
|
self._replace_input_with(old_input, new_input) # type: ignore[attr-defined]
|
||||||
|
|
||||||
def _rename(self, candidate: str) -> None:
|
def _rename(self, candidate: str) -> None:
|
||||||
if candidate == self.name:
|
if candidate == self.name:
|
||||||
|
|
|
||||||
|
|
@ -303,7 +303,7 @@ class StreamContext:
|
||||||
self.idx = _get_device_index(None, True)
|
self.idx = _get_device_index(None, True)
|
||||||
if not torch.jit.is_scripting():
|
if not torch.jit.is_scripting():
|
||||||
if self.idx is None:
|
if self.idx is None:
|
||||||
self.idx = -1 # pyrefly: ignore # bad-assignment
|
self.idx = -1 # pyrefly: ignore [bad-assignment]
|
||||||
|
|
||||||
self.src_prev_stream = (
|
self.src_prev_stream = (
|
||||||
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
|
None if not torch.jit.is_scripting() else torch.mtia.default_stream(None)
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ def _outer_to_inner_dim(ndim, dim, ragged_dim, canonicalize=False):
|
||||||
if canonicalize:
|
if canonicalize:
|
||||||
dim = canonicalize_dims(ndim, dim)
|
dim = canonicalize_dims(ndim, dim)
|
||||||
|
|
||||||
assert dim >= 0 and dim < ndim # pyrefly: ignore # unsupported-operation
|
assert dim >= 0 and dim < ndim # pyrefly: ignore [unsupported-operation]
|
||||||
|
|
||||||
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
|
# Map dim=0 (AKA batch dim) -> packed dim i.e. outer ragged dim - 1.
|
||||||
# For other dims, subtract 1 to convert to inner space.
|
# For other dims, subtract 1 to convert to inner space.
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ class _NormBase(Module):
|
||||||
torch.tensor(
|
torch.tensor(
|
||||||
0,
|
0,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
@ -222,7 +222,7 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
|
||||||
dtype=None,
|
dtype=None,
|
||||||
) -> None:
|
) -> None:
|
||||||
factory_kwargs = {"device": device, "dtype": dtype}
|
factory_kwargs = {"device": device, "dtype": dtype}
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
super().__init__(
|
super().__init__(
|
||||||
# affine and track_running_stats are hardcoded to False to
|
# affine and track_running_stats are hardcoded to False to
|
||||||
# avoid creating tensors that will soon be overwritten.
|
# avoid creating tensors that will soon be overwritten.
|
||||||
|
|
@ -236,29 +236,29 @@ class _LazyNormBase(LazyModuleMixin, _NormBase):
|
||||||
self.affine = affine
|
self.affine = affine
|
||||||
self.track_running_stats = track_running_stats
|
self.track_running_stats = track_running_stats
|
||||||
if self.affine:
|
if self.affine:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.weight = UninitializedParameter(**factory_kwargs)
|
self.weight = UninitializedParameter(**factory_kwargs)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.bias = UninitializedParameter(**factory_kwargs)
|
self.bias = UninitializedParameter(**factory_kwargs)
|
||||||
if self.track_running_stats:
|
if self.track_running_stats:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.running_mean = UninitializedBuffer(**factory_kwargs)
|
self.running_mean = UninitializedBuffer(**factory_kwargs)
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
self.running_var = UninitializedBuffer(**factory_kwargs)
|
self.running_var = UninitializedBuffer(**factory_kwargs)
|
||||||
self.num_batches_tracked = torch.tensor(
|
self.num_batches_tracked = torch.tensor(
|
||||||
0,
|
0,
|
||||||
dtype=torch.long,
|
dtype=torch.long,
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
**{k: v for k, v in factory_kwargs.items() if k != "dtype"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def reset_parameters(self) -> None:
|
def reset_parameters(self) -> None:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if not self.has_uninitialized_params() and self.num_features != 0:
|
if not self.has_uninitialized_params() and self.num_features != 0:
|
||||||
super().reset_parameters()
|
super().reset_parameters()
|
||||||
|
|
||||||
def initialize_parameters(self, input) -> None: # type: ignore[override]
|
def initialize_parameters(self, input) -> None: # type: ignore[override]
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
if self.has_uninitialized_params():
|
if self.has_uninitialized_params():
|
||||||
self.num_features = input.shape[1]
|
self.num_features = input.shape[1]
|
||||||
if self.affine:
|
if self.affine:
|
||||||
|
|
@ -352,6 +352,7 @@ class BatchNorm1d(_BatchNorm):
|
||||||
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
|
raise ValueError(f"expected 2D or 3D input (got {input.dim()}D input)")
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore [inconsistent-inheritance]
|
||||||
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
|
class LazyBatchNorm1d(_LazyNormBase, _BatchNorm):
|
||||||
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
|
r"""A :class:`torch.nn.BatchNorm1d` module with lazy initialization.
|
||||||
|
|
||||||
|
|
@ -463,6 +464,7 @@ class BatchNorm2d(_BatchNorm):
|
||||||
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
|
raise ValueError(f"expected 4D input (got {input.dim()}D input)")
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore [inconsistent-inheritance]
|
||||||
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
|
class LazyBatchNorm2d(_LazyNormBase, _BatchNorm):
|
||||||
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
|
r"""A :class:`torch.nn.BatchNorm2d` module with lazy initialization.
|
||||||
|
|
||||||
|
|
@ -574,6 +576,7 @@ class BatchNorm3d(_BatchNorm):
|
||||||
raise ValueError(f"expected 5D input (got {input.dim()}D input)")
|
raise ValueError(f"expected 5D input (got {input.dim()}D input)")
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore [inconsistent-inheritance]
|
||||||
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
|
class LazyBatchNorm3d(_LazyNormBase, _BatchNorm):
|
||||||
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.
|
r"""A :class:`torch.nn.BatchNorm3d` module with lazy initialization.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -38,13 +38,13 @@ T = TypeVar("T", bound="Module")
|
||||||
|
|
||||||
|
|
||||||
class _IncompatibleKeys(
|
class _IncompatibleKeys(
|
||||||
# pyrefly: ignore # invalid-inheritance
|
# pyrefly: ignore [invalid-inheritance]
|
||||||
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
namedtuple("IncompatibleKeys", ["missing_keys", "unexpected_keys"]),
|
||||||
):
|
):
|
||||||
__slots__ = ()
|
__slots__ = ()
|
||||||
|
|
||||||
def __repr__(self) -> str:
|
def __repr__(self) -> str:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if not self.missing_keys and not self.unexpected_keys:
|
if not self.missing_keys and not self.unexpected_keys:
|
||||||
return "<All keys matched successfully>"
|
return "<All keys matched successfully>"
|
||||||
return super().__repr__()
|
return super().__repr__()
|
||||||
|
|
@ -93,7 +93,7 @@ class _WrappedHook:
|
||||||
def __getstate__(self) -> dict:
|
def __getstate__(self) -> dict:
|
||||||
result = {"hook": self.hook, "with_module": self.with_module}
|
result = {"hook": self.hook, "with_module": self.with_module}
|
||||||
if self.with_module:
|
if self.with_module:
|
||||||
# pyrefly: ignore # unsupported-operation
|
# pyrefly: ignore [unsupported-operation]
|
||||||
result["module"] = self.module()
|
result["module"] = self.module()
|
||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
@ -979,7 +979,7 @@ class Module:
|
||||||
# Decrement use count of the gradient by setting to None
|
# Decrement use count of the gradient by setting to None
|
||||||
param.grad = None
|
param.grad = None
|
||||||
param_applied = torch.nn.Parameter(
|
param_applied = torch.nn.Parameter(
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
param_applied,
|
param_applied,
|
||||||
requires_grad=param.requires_grad,
|
requires_grad=param.requires_grad,
|
||||||
)
|
)
|
||||||
|
|
@ -992,13 +992,13 @@ class Module:
|
||||||
) from e
|
) from e
|
||||||
out_param = param
|
out_param = param
|
||||||
elif p_should_use_set_data:
|
elif p_should_use_set_data:
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
param.data = param_applied
|
param.data = param_applied
|
||||||
out_param = param
|
out_param = param
|
||||||
else:
|
else:
|
||||||
assert isinstance(param, Parameter)
|
assert isinstance(param, Parameter)
|
||||||
assert param.is_leaf
|
assert param.is_leaf
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
out_param = Parameter(param_applied, param.requires_grad)
|
out_param = Parameter(param_applied, param.requires_grad)
|
||||||
self._parameters[key] = out_param
|
self._parameters[key] = out_param
|
||||||
|
|
||||||
|
|
@ -1337,7 +1337,9 @@ class Module:
|
||||||
|
|
||||||
"""
|
"""
|
||||||
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
device, dtype, non_blocking, convert_to_format = torch._C._nn._parse_to(
|
||||||
*args, **kwargs
|
# pyrefly: ignore [not-iterable]
|
||||||
|
*args,
|
||||||
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
|
|
@ -2256,7 +2258,7 @@ class Module:
|
||||||
|
|
||||||
if destination is None:
|
if destination is None:
|
||||||
destination = OrderedDict()
|
destination = OrderedDict()
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
destination._metadata = OrderedDict()
|
destination._metadata = OrderedDict()
|
||||||
|
|
||||||
local_metadata = dict(version=self._version)
|
local_metadata = dict(version=self._version)
|
||||||
|
|
@ -2407,7 +2409,7 @@ class Module:
|
||||||
}
|
}
|
||||||
local_name_params = itertools.chain(
|
local_name_params = itertools.chain(
|
||||||
self._parameters.items(),
|
self._parameters.items(),
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
persistent_buffers.items(),
|
persistent_buffers.items(),
|
||||||
)
|
)
|
||||||
local_state = {k: v for k, v in local_name_params if v is not None}
|
local_state = {k: v for k, v in local_name_params if v is not None}
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
|
||||||
"""Prints messages based on `verbose`."""
|
"""Prints messages based on `verbose`."""
|
||||||
if verbose is False:
|
if verbose is False:
|
||||||
return lambda *_, **__: None
|
return lambda *_, **__: None
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
|
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -47,7 +48,7 @@ def _patch_dynamo_unsupported_functions():
|
||||||
|
|
||||||
# Replace torch.jit.isinstance with isinstance
|
# Replace torch.jit.isinstance with isinstance
|
||||||
jit_isinstance = torch.jit.isinstance
|
jit_isinstance = torch.jit.isinstance
|
||||||
# pyrefly: ignore # bad-assignment
|
# pyrefly: ignore [bad-assignment]
|
||||||
torch.jit.isinstance = isinstance
|
torch.jit.isinstance = isinstance
|
||||||
logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing")
|
logger.info("Replaced torch.jit.isinstance with isinstance to allow dynamo tracing")
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -132,10 +132,10 @@ class TorchTensor(ir.Tensor):
|
||||||
# view the tensor as that dtype so that it is convertible to NumPy,
|
# view the tensor as that dtype so that it is convertible to NumPy,
|
||||||
# and then view it back to the proper dtype (using ml_dtypes obtained by
|
# and then view it back to the proper dtype (using ml_dtypes obtained by
|
||||||
# calling dtype.numpy()).
|
# calling dtype.numpy()).
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
if self.dtype == ir.DataType.BFLOAT16:
|
if self.dtype == ir.DataType.BFLOAT16:
|
||||||
return (
|
return (
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
|
self.raw.view(torch.uint16).numpy(force=True).view(self.dtype.numpy())
|
||||||
)
|
)
|
||||||
if self.dtype in {
|
if self.dtype in {
|
||||||
|
|
@ -144,11 +144,11 @@ class TorchTensor(ir.Tensor):
|
||||||
ir.DataType.FLOAT8E5M2,
|
ir.DataType.FLOAT8E5M2,
|
||||||
ir.DataType.FLOAT8E5M2FNUZ,
|
ir.DataType.FLOAT8E5M2FNUZ,
|
||||||
}:
|
}:
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
|
return self.raw.view(torch.uint8).numpy(force=True).view(self.dtype.numpy())
|
||||||
if self.dtype == ir.DataType.FLOAT4E2M1:
|
if self.dtype == ir.DataType.FLOAT4E2M1:
|
||||||
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
|
return _type_casting.unpack_float4x2_as_uint8(self.raw).view(
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
self.dtype.numpy()
|
self.dtype.numpy()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -170,7 +170,7 @@ class TorchTensor(ir.Tensor):
|
||||||
|
|
||||||
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
|
if isinstance(tensor, torch._subclasses.fake_tensor.FakeTensor):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
# pyrefly: ignore # missing-attribute
|
# pyrefly: ignore [missing-attribute]
|
||||||
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
|
f"Cannot take content out from the FakeTensor ('{self.name}'). Please replace the tensor "
|
||||||
"with a tensor backed by real data using ONNXProgram.apply_weights() "
|
"with a tensor backed by real data using ONNXProgram.apply_weights() "
|
||||||
"or save the model without initializers by setting include_initializers=False."
|
"or save the model without initializers by setting include_initializers=False."
|
||||||
|
|
@ -251,7 +251,7 @@ def _set_shape_type(
|
||||||
if isinstance(dim, int):
|
if isinstance(dim, int):
|
||||||
dims.append(dim)
|
dims.append(dim)
|
||||||
else:
|
else:
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
dims.append(str(dim.node))
|
dims.append(str(dim.node))
|
||||||
|
|
||||||
# If the dtype is set already (e.g. by the onnx_symbolic ops),
|
# If the dtype is set already (e.g. by the onnx_symbolic ops),
|
||||||
|
|
@ -1232,7 +1232,7 @@ def _exported_program_to_onnx_program(
|
||||||
# so we need to get them from the name_* apis.
|
# so we need to get them from the name_* apis.
|
||||||
for name, torch_tensor in itertools.chain(
|
for name, torch_tensor in itertools.chain(
|
||||||
exported_program.named_parameters(),
|
exported_program.named_parameters(),
|
||||||
# pyrefly: ignore # bad-argument-type
|
# pyrefly: ignore [bad-argument-type]
|
||||||
exported_program.named_buffers(),
|
exported_program.named_buffers(),
|
||||||
exported_program.constants.items(),
|
exported_program.constants.items(),
|
||||||
):
|
):
|
||||||
|
|
@ -1265,6 +1265,7 @@ def _verbose_printer(verbose: bool | None) -> Callable[..., None]:
|
||||||
"""Prints messages based on `verbose`."""
|
"""Prints messages based on `verbose`."""
|
||||||
if verbose is False:
|
if verbose is False:
|
||||||
return lambda *_, **__: None
|
return lambda *_, **__: None
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
|
return lambda *args, **kwargs: print("[torch.onnx]", *args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -239,7 +239,7 @@ def _compare_onnx_pytorch_outputs_in_np(
|
||||||
if acceptable_error_percentage:
|
if acceptable_error_percentage:
|
||||||
error_percentage = 1 - np.sum(
|
error_percentage = 1 - np.sum(
|
||||||
np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
|
np.isclose(ort_out, pt_out, rtol=options.rtol, atol=options.atol)
|
||||||
) / np.prod(ort_out.shape) # pyrefly: ignore # missing-attribute
|
) / np.prod(ort_out.shape) # pyrefly: ignore [missing-attribute]
|
||||||
if error_percentage <= acceptable_error_percentage:
|
if error_percentage <= acceptable_error_percentage:
|
||||||
warnings.warn(
|
warnings.warn(
|
||||||
f"Suppressed AssertionError:\n{e}.\n"
|
f"Suppressed AssertionError:\n{e}.\n"
|
||||||
|
|
|
||||||
|
|
@ -13,6 +13,7 @@ from torch import optim
|
||||||
|
|
||||||
def partialclass(cls, *args, **kwargs): # noqa: D103
|
def partialclass(cls, *args, **kwargs): # noqa: D103
|
||||||
class NewCls(cls):
|
class NewCls(cls):
|
||||||
|
# pyrefly: ignore [not-iterable]
|
||||||
__init__ = partialmethod(cls.__init__, *args, **kwargs)
|
__init__ = partialmethod(cls.__init__, *args, **kwargs)
|
||||||
|
|
||||||
return NewCls
|
return NewCls
|
||||||
|
|
|
||||||
|
|
@ -326,7 +326,7 @@ def gaussian(
|
||||||
requires_grad=requires_grad,
|
requires_grad=requires_grad,
|
||||||
)
|
)
|
||||||
|
|
||||||
return torch.exp(-(k**2)) # pyrefly: ignore # unsupported-operation
|
return torch.exp(-(k**2)) # pyrefly: ignore [unsupported-operation]
|
||||||
|
|
||||||
|
|
||||||
@_add_docstr(
|
@_add_docstr(
|
||||||
|
|
|
||||||
|
|
@ -618,7 +618,7 @@ def _get_storage_from_sequence(sequence, dtype, device):
|
||||||
|
|
||||||
def _isint(x):
|
def _isint(x):
|
||||||
if HAS_NUMPY:
|
if HAS_NUMPY:
|
||||||
return isinstance(x, (int, np.integer)) # pyrefly: ignore # missing-attribute
|
return isinstance(x, (int, np.integer)) # pyrefly: ignore [missing-attribute]
|
||||||
else:
|
else:
|
||||||
return isinstance(x, int)
|
return isinstance(x, int)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@@ -77,6 +77,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
     def pop(self) -> T:
         if not self:
             raise KeyError("pop from an empty set")
+        # pyrefly: ignore [bad-return]
         return self._dict.popitem()[0]

     def copy(self) -> OrderedSet[T]:
@@ -158,7 +159,7 @@ class OrderedSet(MutableSet[T], Reversible[T]):
     def __and__(self, other: AbstractSet[T_co]) -> OrderedSet[T]:
         # MutableSet impl will iterate over other, iter over smaller of two sets
         if isinstance(other, OrderedSet) and len(self) < len(other):
-            # pyrefly: ignore # unsupported-operation, bad-return
+            # pyrefly: ignore [unsupported-operation, bad-return]
             return other & self
         return cast(OrderedSet[T], super().__and__(other))
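Both OrderedSet hunks lean on the dict that backs the set: pop removes the most recently inserted element via dict.popitem, and __and__ flips to other & self so iteration runs over the smaller operand. A simplified sketch of the backing idea (the real class implements the full MutableSet API):

class MiniOrderedSet:
    def __init__(self, iterable=()):
        # dict preserves insertion order; values are ignored.
        self._dict = dict.fromkeys(iterable)

    def pop(self):
        if not self._dict:
            raise KeyError("pop from an empty set")
        return self._dict.popitem()[0]  # last-inserted element


s = MiniOrderedSet("abc")
assert s.pop() == "c"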
@@ -202,7 +202,7 @@ def _generate_module_methods_for_privateuse1_backend(custom_backend_name: str) -
         Args:
             device (int, optional): if specified, all parameters will be copied to that device
         """
-        # pyrefly: ignore # missing-attribute
+        # pyrefly: ignore [missing-attribute]
         return self._apply(lambda t: getattr(t, custom_backend_name)(device))

     _check_register_once(torch.nn.Module, custom_backend_name)
@@ -252,11 +252,15 @@ def _generate_packed_sequence_methods_for_privateuse1_backend(
             device (int, optional): if specified, all parameters will be copied to that device
         """
         ex = torch.tensor((), dtype=self.data.dtype, device=self.data.device).to(
-            *args, **kwargs
+            # pyrefly: ignore [not-iterable]
+            *args,
+            **kwargs,
         )
         if ex.device.type == custom_backend_name:
+            # pyrefly: ignore [not-iterable]
             return self.to(*args, **kwargs)
         kwargs.update({"device": custom_backend_name})
+        # pyrefly: ignore [not-iterable]
         return self.to(*args, **kwargs)

     _check_register_once(torch.nn.utils.rnn.PackedSequence, custom_backend_name)
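These generated methods give a renamed "privateuse1" device its own helpers on Tensor, Module, and PackedSequence. As far as I can tell the entry points look like the following (a sketch -- actually running it requires an out-of-tree backend registered under the chosen name):

import torch

# Rename the reserved privateuse1 dispatch key to the backend's name...
torch.utils.rename_privateuse1_backend("npu")
# ...then generate npu()/is_npu-style helpers, including the Module and
# PackedSequence methods patched in the hunks above.
torch.utils.generate_methods_for_privateuse1_backend()

# With a real backend loaded, this would then work:
# model = torch.nn.Linear(2, 2).npu()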
@@ -48,7 +48,7 @@ MATH_TRANSPILATIONS = collections.OrderedDict(
     ]
 )

-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 CUDA_TYPE_NAME_MAP = collections.OrderedDict(
     [
         ("CUresult", ("hipError_t", CONV_TYPE, API_DRIVER)),
@@ -587,6 +587,7 @@ CUDA_TYPE_NAME_MAP = collections.OrderedDict(
     ]
 )

+# pyrefly: ignore [no-matching-overload]
 CUDA_INCLUDE_MAP = collections.OrderedDict(
     [
         # since pytorch uses "\b{pattern}\b" as the actual re pattern,
@@ -676,7 +677,7 @@ CUDA_INCLUDE_MAP = collections.OrderedDict(
     ]
 )

-# pyrefly: ignore # no-matching-overload
+# pyrefly: ignore [no-matching-overload]
 CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     [
         ("__CUDACC__", ("__HIPCC__", CONV_DEF, API_RUNTIME)),
@@ -8370,6 +8371,7 @@ CUDA_IDENTIFIER_MAP = collections.OrderedDict(
     ]
 )

+# pyrefly: ignore [no-matching-overload]
 CUDA_SPECIAL_MAP = collections.OrderedDict(
     [
         # SPARSE
@@ -8852,6 +8854,7 @@ CUDA_SPECIAL_MAP = collections.OrderedDict(
     ]
 )

+# pyrefly: ignore [no-matching-overload]
 PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
     [
         ("USE_CUDA", ("USE_ROCM", API_PYTORCH)),
@@ -9316,6 +9319,7 @@ PYTORCH_SPECIFIC_MAPPINGS = collections.OrderedDict(
     ]
 )

+# pyrefly: ignore [no-matching-overload]
 CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
     [
         ("PYTORCH_NO_CUDA_MEMORY_CACHING", ("PYTORCH_NO_CUDA_MEMORY_CACHING", API_CAFFE2)),
@@ -9401,6 +9405,7 @@ CAFFE2_SPECIFIC_MAPPINGS = collections.OrderedDict(
 #
 # NB: if you want a transformation to ONLY apply to the c10/ directory,
 # put it as API_CAFFE2
+# pyrefly: ignore [no-matching-overload]
 C10_MAPPINGS = collections.OrderedDict(
     [
         ("CUDA_VERSION", ("TORCH_HIP_VERSION", API_PYTORCH)),
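Each of these OrderedDicts maps a CUDA token to a tuple whose first element is the HIP replacement; hipify wraps the keys in word-boundary regexes (per the "\b{pattern}\b" comment in the CUDA_INCLUDE_MAP hunk) and substitutes the mapped text. A toy version of that substitution with a two-entry table for illustration:

import collections
import re

# Illustrative subset; the real tables live in the file being patched.
MAPPINGS = collections.OrderedDict(
    [
        ("CUDA_VERSION", ("TORCH_HIP_VERSION",)),
        ("__CUDACC__", ("__HIPCC__",)),
    ]
)


def toy_hipify(source: str) -> str:
    for cuda_token, hip_info in MAPPINGS.items():
        # Word-boundary anchoring keeps identifiers like MY_CUDA_VERSION_X intact.
        source = re.sub(rf"\b{re.escape(cuda_token)}\b", hip_info[0], source)
    return source


assert toy_hipify("#if CUDA_VERSION > 11000") == "#if TORCH_HIP_VERSION > 11000"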
@@ -120,6 +120,7 @@ class GeneratedFileCleaner:
     def open(self, fn, *args, **kwargs):
         if not os.path.exists(fn):
             self.files_to_clean.add(os.path.abspath(fn))
+        # pyrefly: ignore [not-iterable]
         return open(fn, *args, **kwargs)

     def makedirs(self, dn, exist_ok=False):
@@ -669,7 +670,7 @@ def is_caffe2_gpu_file(rel_filepath):
         return True
     filename = os.path.basename(rel_filepath)
     _, ext = os.path.splitext(filename)
-    # pyrefly: ignore # unsupported-operation
+    # pyrefly: ignore [unsupported-operation]
     return ('gpu' in filename or ext in ['.cu', '.cuh']) and ('cudnn' not in filename)

 class TrieNode:
@@ -1145,7 +1146,7 @@ def hipify(
         out_of_place_only=out_of_place_only,
         is_pytorch_extension=is_pytorch_extension))
     all_files_set = set(all_files)
-    # pyrefly: ignore # bad-assignment
+    # pyrefly: ignore [bad-assignment]
     for f in extra_files:
         if not os.path.isabs(f):
             f = os.path.join(output_directory, f)
@@ -292,9 +292,10 @@ class WeakIdKeyDictionary(MutableMapping):
             if o is not None:
                 return o, value

-    # pyrefly: ignore # bad-override
+    # pyrefly: ignore [bad-override]
     def pop(self, key, *args):
         self._dirty_len = True
+        # pyrefly: ignore [not-iterable]
         return self.data.pop(self.ref_type(key), *args)  # CHANGED

     def setdefault(self, key, default=None):
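WeakIdKeyDictionary keys entries by object identity through a weak-reference wrapper (the self.ref_type(key) call above), so an entry disappears once its key object is collected. A hedged usage sketch (relying on CPython's immediate refcount collection):

import torch
from torch.utils.weak import WeakIdKeyDictionary

cache = WeakIdKeyDictionary()
t = torch.ones(2)
cache[t] = "metadata"
assert cache.pop(t, None) == "metadata"  # the overridden pop from the hunk

u = torch.ones(2)
cache[u] = "stale"
del u  # key dies; its entry is dropped rather than leaking
assert len(cache) == 0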
@@ -328,7 +328,7 @@ class StreamContext:
         self.stream = stream
         self.idx = _get_device_index(None, True)
         if self.idx is None:
-            self.idx = -1  # pyrefly: ignore # bad-assignment
+            self.idx = -1  # pyrefly: ignore [bad-assignment]

     def __enter__(self):
         cur_stream = self.stream
@@ -126,7 +126,7 @@ class Event(torch._C._XpuEventBase):
         """
         if stream is None:
             stream = torch.xpu.current_stream()
-        super().record(stream)  # pyrefly: ignore # bad-argument-type
+        super().record(stream)  # pyrefly: ignore [bad-argument-type]

     def wait(self, stream=None) -> None:
         r"""Make all future work submitted to the given stream wait for this event.
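Event.record defaults to the current XPU stream before handing off to the C++ base class. Typical timing usage would look roughly like this (guarded, since it assumes an XPU device is present):

import torch

if torch.xpu.is_available():
    start = torch.xpu.Event(enable_timing=True)
    end = torch.xpu.Event(enable_timing=True)

    start.record()  # records on torch.xpu.current_stream()
    y = torch.ones(1024, device="xpu") * 2
    end.record()

    torch.xpu.synchronize()
    print(f"elapsed: {start.elapsed_time(end):.3f} ms")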