Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-06 12:20:52 +01:00
Add pyrefly suppressions (3/n) (#164588)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

Step 1: uncomment lines in the pyrefly.toml file
Step 2: run pyrefly check
Step 3: add suppressions, clean up unused suppressions

Before: https://gist.github.com/maggiemoss/bb31574ac8a59893c9cf52189e67bb2d
After: 0 errors (1,970 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164588
Approved by: https://github.com/oulgen
This commit is contained in:
parent e438db2546
commit f414aa8e0d
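The suppressions added throughout the diff below are line-level comments of the form `# pyrefly: ignore # <error-code>`, placed on the line above the flagged statement or appended inline. A minimal sketch of the pattern follows; the attribute name is illustrative, not taken from this PR. Step 3 of the test plan also deletes suppressions that no longer fire, which is why some hunks below remove `# pyrefly: ignore` lines.

```python
import torch

t = torch.zeros(3)
# Tensors accept dynamically assigned Python attributes at runtime, but a
# static checker cannot see them, so the read below would be reported.
t.my_dynamic_flag = True  # illustrative attribute, not from the PR

# pyrefly: ignore # missing-attribute
print(t.my_dynamic_flag)
```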
@@ -34,12 +34,6 @@ project-excludes = [
 "torch/jit/**",
 "torch/optim/**",
 "torch/_higher_order_ops/**",
-"torch/_functorch/**",
-"torch/masked/**",
-"torch/_subclasses/**",
-"torch/autograd/**",
-"torch/cuda/**",
-"torch/export/**",
 # formatting issues
 "torch/linalg/__init__.py",
 "torch/package/importer.py",
@@ -58,20 +58,20 @@ class TestBundledInputs(TestCase):
 # Make sure the model only grew a little bit,
 # despite having nominally large bundled inputs.
 augmented_size = model_size(sm)
-# pyrefly: ignore # missing-attribute
 self.assertLess(augmented_size, original_size + (1 << 12))

 loaded = save_and_load(sm)
 inflated = loaded.get_all_bundled_inputs()
 self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))
 self.assertEqual(len(inflated), len(samples))
-# pyrefly: ignore # missing-attribute
 self.assertTrue(loaded(*inflated[0]) is inflated[0][0])

 for idx, inp in enumerate(inflated):
-self.assertIsInstance(inp, tuple) # pyrefly: ignore # missing-attribute
+self.assertIsInstance(inp, tuple)
 self.assertEqual(len(inp), 1)
-# pyrefly: ignore # missing-attribute
 self.assertIsInstance(inp[0], torch.Tensor)
 if idx != 5:
 # Strides might be important for benchmarking.

@@ -139,7 +139,7 @@ class TestBundledInputs(TestCase):
 loaded = save_and_load(sm)
 inflated = loaded.get_all_bundled_inputs()
 self.assertEqual(inflated, samples)
-# pyrefly: ignore # missing-attribute
 self.assertTrue(loaded(*inflated[0]) == "first 1")

 def test_multiple_methods_with_inputs(self):

@@ -186,7 +186,7 @@ class TestBundledInputs(TestCase):
 self.assertEqual(inflated, loaded.get_all_bundled_inputs_for_foo())

 # Check running and size helpers
-# pyrefly: ignore # missing-attribute
 self.assertTrue(loaded(*inflated[0]) is inflated[0][0])
 self.assertEqual(loaded.get_num_bundled_inputs(), len(samples))

@@ -419,7 +419,7 @@ class TestBundledInputs(TestCase):
 )
 augmented_size = model_size(sm)
 # assert the size has not increased more than 8KB
-# pyrefly: ignore # missing-attribute
 self.assertLess(augmented_size, original_size + (1 << 13))

 loaded = save_and_load(sm)
@@ -48,7 +48,7 @@ class TestComplexTensor(TestCase):
 def test_all(self, device, dtype):
 # issue: https://github.com/pytorch/pytorch/issues/120875
 x = torch.tensor([1 + 2j, 3 - 4j, 5j, 6], device=device, dtype=dtype)
-self.assertTrue(torch.all(x)) # pyrefly: ignore # missing-attribute
+self.assertTrue(torch.all(x))

 @dtypes(*complex_types())
 def test_any(self, device, dtype):

@@ -56,7 +56,7 @@ class TestComplexTensor(TestCase):
 x = torch.tensor(
 [0, 0j, -0 + 0j, -0 - 0j, 0 + 0j, 0 - 0j], device=device, dtype=dtype
 )
-self.assertFalse(torch.any(x)) # pyrefly: ignore # missing-attribute
+self.assertFalse(torch.any(x))

 @onlyCPU
 @dtypes(*complex_types())

@@ -142,7 +142,6 @@ class TestTypeHints(TestCase):
 ]
 )
 if result != 0:
-# pyrefly: ignore # missing-attribute
 self.fail(f"mypy failed:\n{stderr}\n{stdout}")
@@ -125,7 +125,7 @@ class TestDTypeInfo(TestCase):
 # Regression test for https://github.com/pytorch/pytorch/issues/124868
 # If reference count is leaked this would be a set of 10 elements
 ref_cnt = {sys.getrefcount(torch.float32.to_complex()) for _ in range(10)}
-self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
+self.assertLess(len(ref_cnt), 3)

 self.assertEqual(torch.float64.to_complex(), torch.complex128)
 self.assertEqual(torch.float32.to_complex(), torch.complex64)

@@ -135,7 +135,7 @@ class TestDTypeInfo(TestCase):
 # Regression test for https://github.com/pytorch/pytorch/issues/124868
 # If reference count is leaked this would be a set of 10 elements
 ref_cnt = {sys.getrefcount(torch.cfloat.to_real()) for _ in range(10)}
-self.assertLess(len(ref_cnt), 3) # pyrefly: ignore # missing-attribute
+self.assertLess(len(ref_cnt), 3)

 self.assertEqual(torch.complex128.to_real(), torch.double)
 self.assertEqual(torch.complex64.to_real(), torch.float32)
@@ -2653,6 +2653,7 @@ def compile(
 dynamic=dynamic,
 disable=disable,
 guard_filter_fn=guard_filter_fn,
+# pyrefly: ignore # bad-argument-type
 )(model)(*args, **kwargs)

 return export_wrapped_fn

@@ -384,6 +384,7 @@ class AOTAutogradCacheDetails(FxGraphHashDetails):
 class AOTAutogradCachePickler(FxGraphCachePickler):
 def __init__(self, gm: torch.fx.GraphModule):
 super().__init__(gm)
+# pyrefly: ignore # bad-override
 self.dispatch_table: dict
 self.dispatch_table.update(
 {
@@ -86,8 +86,10 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):

 memory_format = MemoryFormatMeta.from_tensor(out)

+# pyrefly: ignore # missing-attribute
 if memory_format.memory_format is not None:
 was = out
+# pyrefly: ignore # bad-argument-type
 out = out.contiguous(memory_format=memory_format.memory_format)
 updated = was is not out

@@ -117,6 +119,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
 out = out.__coerce_tangent_metadata__() # type: ignore[attr-defined]

 if is_subclass:
+# pyrefly: ignore # missing-attribute
 attrs = out.__tensor_flatten__()[0]

 for attr in attrs:

@@ -126,6 +129,7 @@ def coerce_tangent_and_suggest_memory_format(x: Tensor):
 new_elem_memory_format,
 elem_updated,
 ) = coerce_tangent_and_suggest_memory_format(elem)
+# pyrefly: ignore # missing-attribute
 out_memory_format.append(new_elem_memory_format)
 if elem_updated:
 setattr(out, attr, new_elem)

@@ -492,6 +496,7 @@ def run_functionalized_fw_and_collect_metadata(
 curr_storage in inp_storage_refs
 and not functional_tensor_storage_changed
 ):
+# pyrefly: ignore # index-error
 base_idx = inp_storage_refs[curr_storage]
 is_input_tensor = id(o) in inp_tensor_ids
 num_aliased_outs = out_tensor_alias_counts[curr_storage]

@@ -699,6 +704,7 @@ from a multi-output view call"
 # Anything that aliases (inputs returned in the fw due to metadata mutations, or outputs that alias inputs/intermediates)
 # are *regenerated* later, and not used directly in the autograd graph
 def _plain_fake_tensor_like_subclass(x):
+# pyrefly: ignore # bad-context-manager
 with detect_fake_mode():
 return torch.empty(
 x.shape, dtype=x.dtype, device=x.device, layout=x.layout
@@ -78,6 +78,7 @@ def get_all_input_and_grad_nodes(
 continue
 if isinstance(desc, SubclassGetAttrAOTInput):
 _raise_autograd_subclass_not_implemented(n, desc)
+# pyrefly: ignore # unsupported-operation
 input_index[desc] = (n, None)
 elif n.op == "output":
 assert "desc" in n.meta, (n, n.meta)

@@ -129,6 +130,7 @@ def get_all_output_and_tangent_nodes(
 continue
 if isinstance(sub_d, SubclassGetAttrAOTOutput):
 _raise_autograd_subclass_not_implemented(sub_n, sub_d)
+# pyrefly: ignore # unsupported-operation
 output_index[sub_d] = (sub_n, None)
 for n in g.nodes:
 if n.op == "placeholder":

@@ -1305,10 +1305,12 @@ def aot_dispatch_subclass(
 # See Note: [Partitioner handling for Subclasses, Part 2] for more info.
 meta_updated = run_functionalized_fw_and_collect_metadata(
 without_output_descs(metadata_fn),
+# pyrefly: ignore # bad-argument-type
 flat_args_descs=primals_unwrapped_descs,
 static_input_indices=remapped_static_indices,
 keep_input_mutations=meta.keep_input_mutations,
 is_train=meta.is_train,
+# pyrefly: ignore # not-iterable
 )(*primals_unwrapped)

 subclass_meta.fw_metadata = meta_updated
@@ -425,6 +425,7 @@ def collect_fw_donated_buffer_idxs(
 """

 storage_refs = set()
+# pyrefly: ignore # bad-assignment
 for t in itertools.chain(fw_ins, user_fw_outs, bw_outs):
 # Only access storage if a tensor has storage (not sparse)
 if t is not None and isinstance(t, FakeTensor) and not is_sparse_any(t):

@@ -494,6 +495,7 @@ def collect_bw_donated_buffer_idxs(
 fw_ins,
 user_fw_outs,
 bw_outs,
+# pyrefly: ignore # bad-argument-type
 saved_tensors,
 )

@@ -1762,6 +1764,7 @@ def aot_stage2_autograd(
 # (2408448, 1, 21504, 192). The solution mentioned will
 # decide a stride of (802816, 1, 7168, 64) for this
 # tensor which is wrong.
+# pyrefly: ignore # bad-argument-type
 placeholder_list[i] = ph_arg.as_strided(ph_arg.size(), real_stride)

 compiled_bw_func = None

@@ -225,6 +225,7 @@ def make_output_handler(info, runtime_metadata, trace_joint):
 # not sure why AOTDispatcher needs to manually set this
 def maybe_mark_dynamic_helper(t: torch.Tensor, dims: set[int]):
 if hasattr(t, "_dynamo_weak_dynamic_indices"):
+# pyrefly: ignore # missing-attribute
 t._dynamo_weak_dynamic_indices |= dims
 else:
 t._dynamo_weak_dynamic_indices = dims.copy() # type: ignore[attr-defined]

@@ -1142,6 +1143,7 @@ class AOTSyntheticBaseWrapper(CompilerWrapper):

 def _unpack_synthetic_bases(primals: tuple[Any, ...]) -> list[Any]:
 f_args_inner = []
+# pyrefly: ignore # not-iterable
 for inner_idx_or_tuple in synthetic_base_info:
 if isinstance(inner_idx_or_tuple, int):
 f_args_inner.append(primals[inner_idx_or_tuple])
@@ -2112,6 +2114,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 return (ctx._autograd_function_id, *ctx.symints)

 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, *deduped_flat_tensor_args):
 args = deduped_flat_tensor_args
 if backward_state_indices:

@@ -2148,6 +2151,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 # in the fw output order.
 fw_outs = call_func_at_runtime_with_args(
 CompiledFunction.compiled_fw,
+# pyrefly: ignore # bad-argument-type
 args,
 disable_amp=disable_amp,
 )

@@ -2343,6 +2347,7 @@ To fix this, your tensor subclass must implement the dunder method __force_to_sa
 _aot_id = aot_config.aot_id

 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(double_ctx, *unused_args):
 return impl_fn(double_ctx)

@@ -1231,7 +1231,9 @@ class SerializableAOTDispatchCompiler(AOTDispatchCompiler):
 output_code_ty: type[TOutputCode],
 compiler_fn: Callable[[torch.fx.GraphModule, Sequence[InputType]], TOutputCode],
 ):
+# pyrefly: ignore # invalid-type-var
 self.output_code_ty = output_code_ty
+# pyrefly: ignore # invalid-type-var
 self.compiler_fn = compiler_fn

 def __call__(
@@ -90,6 +90,7 @@ def unwrap_tensor_subclass_parameters(module: torch.nn.Module) -> torch.nn.Modul
 """
 for name, tensor in itertools.chain(
 list(module.named_parameters(recurse=False)),
+# pyrefly: ignore # no-matching-overload
 list(module.named_buffers(recurse=False)),
 ):
 if is_traceable_wrapper_subclass(tensor):

@@ -232,11 +232,13 @@ def unwrap_tensor_subclasses(

 attrs, _ = t.__tensor_flatten__()

+# pyrefly: ignore # bad-assignment
 for attr in attrs:
 inner_tensor = getattr(t, attr)
 n_desc: Any = (
 SubclassGetAttrAOTInput(desc, attr)
 if isinstance(desc, AOTInput)
+# pyrefly: ignore # bad-argument-type
 else SubclassGetAttrAOTOutput(desc, attr)
 )
 flatten_subclass(inner_tensor, n_desc, out=out)

@@ -257,6 +259,7 @@ def unwrap_tensor_subclasses(
 descs_inner: list[AOTDescriptor] = []

 for x, desc in zip(wrapped_args, wrapped_args_descs):
+# pyrefly: ignore # bad-argument-type
 flatten_subclass(typing.cast(Tensor, x), desc, out=(xs_inner, descs_inner))

 return xs_inner, descs_inner

@@ -281,6 +284,7 @@ def runtime_unwrap_tensor_subclasses(

 for attr in attrs:
 inner_tensor = getattr(x, attr)
+# pyrefly: ignore # missing-attribute
 inner_meta = meta.attrs.get(attr)
 flatten_subclass(inner_tensor, inner_meta, out=out)

@@ -310,6 +314,7 @@ def runtime_unwrap_tensor_subclasses(

 for idx, x in enumerate(wrapped_args):
 if not is_traceable_wrapper_subclass(x):
+# pyrefly: ignore # bad-argument-type
 xs_inner.append(x)
 continue
@@ -328,6 +328,7 @@ def unlift_tokens(fw_module, fw_metadata, aot_config, bw_module=None):
 and out.args[1] == 0
 and out.args[0] in with_effect_nodes
 ):
+# pyrefly: ignore # missing-attribute
 output_token_nodes.append(out)
 else:
 other_output_nodes.append(out)

@@ -529,8 +530,10 @@ def without_output_descs(f: Callable[_P, tuple[_T, _S]]) -> Callable[_P, _T]:
 @wraps(f)
 @simple_wraps(f)
 def inner(*args, **kwargs):
+# pyrefly: ignore # invalid-param-spec
 return f(*args, **kwargs)[0]

+# pyrefly: ignore # bad-return
 return inner

@@ -753,6 +753,7 @@ class AutogradFunctionApply(HigherOrderOperator):

 class ApplyTemplate(torch.autograd.Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, *args):
 nonlocal saved_values
 output, saved_values = fwd(None, *fwd_args)
@@ -42,7 +42,9 @@ def create_names_map(
 This function creates a mapping from the names in named_params to the
 names in tied_named_params: {'A': ['A'], 'B': ['B', 'B_tied']}.
 """
+# pyrefly: ignore # no-matching-overload
 named_params = dict(named_params)
+# pyrefly: ignore # no-matching-overload
 tied_named_params = dict(tied_named_params)

 tensors_dict_keys = set(named_params.keys())

@@ -51,9 +53,11 @@ def create_names_map(

 tensor_to_mapping: dict[Tensor, tuple[str, list[str]]] = {}
 for key, tensor in named_params.items():
+# pyrefly: ignore # unsupported-operation
 tensor_to_mapping[tensor] = (key, [])
 for key, tensor in tied_named_params.items():
 assert tensor in tensor_to_mapping
+# pyrefly: ignore # bad-argument-type
 tensor_to_mapping[tensor][1].append(key)
 return dict(tensor_to_mapping.values())
@@ -1174,6 +1174,7 @@ def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
 # critical path first.
 cur_nodes += node.all_input_nodes

+# pyrefly: ignore # bad-assignment
 insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
 for node in insertable_nodes:
 env[node] = new_graph.node_copy(node, lambda x: env[x])

@@ -2849,6 +2850,7 @@ def min_cut_rematerialization_partition(
 fw_module, bw_module = _extract_fwd_bwd_modules(
 joint_module,
 saved_values,
+# pyrefly: ignore # bad-argument-type
 saved_sym_nodes=saved_sym_nodes,
 num_fwd_outputs=num_fwd_outputs,
 static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
@@ -131,6 +131,7 @@ class VmapInterpreter(FuncTorchInterpreter):
 self._cdata = cdata

 @cached_property
+# pyrefly: ignore # bad-override
 def _cptr(self):
 return CVmapInterpreterPtr(self._cdata)

@@ -170,6 +171,7 @@ class GradInterpreter(FuncTorchInterpreter):
 self._cdata = cdata

 @cached_property
+# pyrefly: ignore # bad-override
 def _cptr(self):
 return CGradInterpreterPtr(self._cdata)

@@ -207,6 +209,7 @@ class JvpInterpreter(FuncTorchInterpreter):
 self._cdata = cdata

 @cached_property
+# pyrefly: ignore # bad-override
 def _cptr(self):
 return CJvpInterpreterPtr(self._cdata)

@@ -243,6 +246,7 @@ class FunctionalizeInterpreter(FuncTorchInterpreter):
 self._cdata = cdata

 @cached_property
+# pyrefly: ignore # bad-override
 def _cptr(self):
 return CFunctionalizeInterpreterPtr(self._cdata)
@@ -279,7 +279,6 @@ def out_wrapper(
 TensorLikeType
 if is_tensor
 else NamedTuple(
-# pyrefly: ignore # bad-argument-count
 f"return_types_{fn.__name__}",
 # pyrefly: ignore # bad-argument-count
 [(o, TensorLikeType) for o in out_names],

@@ -33,7 +33,12 @@ class _DeconstructedSymNode:
 @staticmethod
 def from_node(node: SymNode) -> _DeconstructedSymNode:
 return _DeconstructedSymNode(
-node._expr, node.pytype, node._hint, node.constant, node.fx_node
+node._expr,
+node.pytype,
+node._hint,
+node.constant,
+# pyrefly: ignore # bad-argument-type
+node.fx_node,
 )

 def extract(self, shape_env: ShapeEnv) -> SymNode:
@@ -404,7 +404,9 @@ class FakeTensorConverter:
 with no_dispatch():
 return FakeTensor(
 fake_mode,
+# pyrefly: ignore # bad-argument-type
 make_meta_t(),
+# pyrefly: ignore # bad-argument-type
 device,
 # TODO: callback might be used in recursive contexts, in
 # which case using t is wrong! BUG!

@@ -679,6 +681,7 @@ class FakeTensor(Tensor):
 _mode_key = torch._C._TorchDispatchModeKey.FAKE

 @property
+# pyrefly: ignore # bad-override
 def device(self) -> torch.device:
 if self.fake_mode.in_kernel_invocation:
 return torch.device("meta")

@@ -706,6 +709,7 @@ class FakeTensor(Tensor):

 # We don't support named tensors; graph break
 @property
+# pyrefly: ignore # bad-override
 def names(self) -> list[str]:
 raise UnsupportedFakeTensorException(
 "torch.compile doesn't support named tensors"

@@ -764,6 +768,7 @@ class FakeTensor(Tensor):
 )
 else:
 device = torch.device(f"{device.type}:0")
+# pyrefly: ignore # read-only
 self.fake_device = device
 self.fake_mode = fake_mode
 self.constant = constant

@@ -1493,6 +1498,7 @@ class FakeTensorMode(TorchDispatchMode):
 # Do this dispatch outside the above except handler so if it
 # generates its own exception there won't be a __context__ caused by
 # the caching mechanism.
+# pyrefly: ignore # bad-argument-type
 return self._dispatch_impl(func, types, args, kwargs)

 assert state is not None
@@ -1510,22 +1516,27 @@ class FakeTensorMode(TorchDispatchMode):
 # This represents a negative cache entry - we already saw that the
 # output is uncachable. Compute it from first principals.
 FakeTensorMode.cache_bypasses[entry.reason] += 1
+# pyrefly: ignore # bad-argument-type
 return self._dispatch_impl(func, types, args, kwargs)

 # We have a cache entry.
+# pyrefly: ignore # bad-argument-type
 output = self._output_from_cache_entry(state, entry, key, func, args)
 FakeTensorMode.cache_hits += 1
 if self.cache_crosscheck_enabled:
 # For debugging / testing: Validate that the output synthesized
 # from the cache matches the output created by normal dispatch.
 with disable_fake_tensor_cache(self):
+# pyrefly: ignore # bad-argument-type
 self._crosscheck_cache_output(output, func, types, args, kwargs)
 return output

 # We don't have a cache entry.
+# pyrefly: ignore # bad-argument-type
 output = self._dispatch_impl(func, types, args, kwargs)

 try:
+# pyrefly: ignore # bad-argument-type
 self._validate_cache_key(func, args, kwargs)
 except _BypassDispatchCache as e:
 # We ran "extra" checks on the cache key and determined that it's no

@@ -1545,6 +1556,7 @@ class FakeTensorMode(TorchDispatchMode):
 return output

 try:
+# pyrefly: ignore # bad-argument-type
 entry = self._make_cache_entry(state, key, func, args, kwargs, output)
 except _BypassDispatchCache as e:
 # We had trouble making the cache entry. Record the reason and mark

@@ -1587,13 +1599,16 @@ class FakeTensorMode(TorchDispatchMode):
 if state.known_symbols:
 # If there are symbols then include the epoch - this is really more
 # of a Shape env var which lives on the FakeTensorMode.
+# pyrefly: ignore # bad-argument-type
 key_values.append(self.epoch)
 # Collect the id_hashed objects to attach a weakref finalize later
 id_hashed_objects: list[object] = []
 # Translate any FakeTensor args to metadata.
 if args:
+# pyrefly: ignore # bad-argument-type
 self._prep_args_for_hash(key_values, args, state, id_hashed_objects)
 if kwargs:
+# pyrefly: ignore # bad-argument-type
 self._prep_args_for_hash(key_values, kwargs, state, id_hashed_objects)
 key = _DispatchCacheKey(tuple(key_values))
@@ -1909,27 +1924,53 @@ class FakeTensorMode(TorchDispatchMode):
 if isinstance(output, tuple):
 for out_element in output:
 self._validate_output_for_cache_entry(
-state, key, func, args, kwargs, out_element
+state,
+key,
+# pyrefly: ignore # bad-argument-type
+func,
+args,
+kwargs,
+out_element,
 )
 else:
 self._validate_output_for_cache_entry(
-state, key, func, args, kwargs, output
+state,
+key,
+# pyrefly: ignore # bad-argument-type
+func,
+args,
+kwargs,
+output,
 )

 if isinstance(output, tuple):
 output_infos = [
 self._get_output_info_for_cache_entry(
-state, key, func, args, kwargs, out_elem
+state,
+key,
+# pyrefly: ignore # bad-argument-type
+func,
+args,
+kwargs,
+out_elem,
 )
 for out_elem in output
 ]
 return _DispatchCacheValidEntry(
-output_infos=tuple(output_infos), is_output_tuple=True
+# pyrefly: ignore # bad-argument-type
+output_infos=tuple(output_infos),
+is_output_tuple=True,
 )

 else:
 output_info = self._get_output_info_for_cache_entry(
-state, key, func, args, kwargs, output
+state,
+key,
+# pyrefly: ignore # bad-argument-type
+func,
+args,
+kwargs,
+output,
 )
 return _DispatchCacheValidEntry(
 output_infos=(output_info,), is_output_tuple=False
@@ -2472,6 +2513,7 @@ class FakeTensorMode(TorchDispatchMode):
 )

 with self, maybe_ignore_fresh_unbacked_symbols():
+# pyrefly: ignore # index-error
 return registered_hop_fake_fns[func](*args, **kwargs)

 self.invalidate_written_to_constants(func, flat_arg_fake_tensors, args, kwargs)

@@ -2625,6 +2667,7 @@ class FakeTensorMode(TorchDispatchMode):
 # TODO: Is this really needed?
 compute_unbacked_bindings(self.shape_env, fake_out, peek=True)

+# pyrefly: ignore # bad-return
 return fake_out

 # Try for fastpath

@@ -2906,6 +2949,7 @@ class FakeTensorMode(TorchDispatchMode):
 self, e, device or common_device
 )
 else:
+# pyrefly: ignore # bad-return
 return e

 return tree_map(wrap, r)
@@ -81,6 +81,7 @@ def safe_is_leaf(t: Union[MetaTensorDesc, torch.Tensor]) -> bool:

 def safe_grad(t: _TensorLikeT) -> Optional[_TensorLikeT]:
 with torch._logging.hide_warnings(torch._logging._internal.safe_grad_filter):
+# pyrefly: ignore # bad-return
 return t.grad

@@ -415,6 +416,7 @@ class MetaTensorDescriber:
 device=t.device,
 size=t.size(),
 stride=stride,
+# pyrefly: ignore # bad-argument-type
 storage_offset=storage_offset,
 dynamo_dynamic_indices=list(getattr(t, "_dynamo_dynamic_indices", set())),
 dynamo_hint_overrides=getattr(t, "_dynamo_hint_overrides", {}),

@@ -539,7 +541,11 @@ class _FakeTensorViewFunc(ViewFunc["FakeTensor"]):
 tensor_visitor_fn: Optional[Callable[[torch.Tensor], FakeTensor]] = None,
 ) -> FakeTensor:
 return torch._subclasses.fake_tensor.FakeTensor._view_func_unsafe(
-t, new_base, symint_visitor_fn, tensor_visitor_fn
+# pyrefly: ignore # bad-argument-type
+t,
+new_base,
+symint_visitor_fn,
+tensor_visitor_fn,
 )
@@ -1013,6 +1019,7 @@ class MetaConverter(Generic[_TensorT]):
 # Morally, the code here is same as transform_subclass, but we've
 # written it from scratch to read EmptyCreateSubclass
 outer_size = outer_size if outer_size is not None else t.size
+# pyrefly: ignore # bad-assignment
 outer_stride = outer_stride if outer_stride is not None else t.stride

 assert symbolic_context is None or isinstance(

@@ -1269,6 +1276,7 @@ class MetaConverter(Generic[_TensorT]):
 ) -> torch.Tensor:
 # It's possible to close over an undefined tensor (e.g. NJT's lengths).
 if visited_t is None:
+# pyrefly: ignore # bad-return
 return None

 # NB: visited_t being a Tensor here is very naughty! Should

@@ -1399,6 +1407,7 @@ class MetaConverter(Generic[_TensorT]):
 if t.requires_grad:
 r.requires_grad = True
 if t.requires_grad and not is_leaf:
+# pyrefly: ignore # bad-argument-type
 r = self._backward_error(r)
 elif t.is_nested and not t.is_traceable_wrapper_subclass:
 # TODO: Handle this better in Dynamo?

@@ -1437,6 +1446,7 @@ class MetaConverter(Generic[_TensorT]):
 if t.requires_grad:
 r.requires_grad = True
 if t.requires_grad and not is_leaf:
+# pyrefly: ignore # bad-argument-type
 r = self._backward_error(r)
 elif t.is_functorch_wrapped:
 if t.is_view:

@@ -1533,6 +1543,7 @@ class MetaConverter(Generic[_TensorT]):
 )
 assert t.data is not None
 _safe_copy(r.real_tensor, t.data) # type: ignore[attr-defined]
+# pyrefly: ignore # bad-return
 return r

 r = _to_fake_tensor(t)

@@ -1682,6 +1693,7 @@ class MetaConverter(Generic[_TensorT]):
 not (t.is_batchedtensor or t.is_gradtrackingtensor)
 and t.is_functorch_wrapped
 ) or t.is_legacy_batchedtensor:
+# pyrefly: ignore # bad-return
 return NotImplemented

 (
@@ -1728,6 +1740,7 @@ class MetaConverter(Generic[_TensorT]):
 # the metadata of the inner tensor.
 # So instead, we now have a dedicated fn to set autograd history,
 # without inadvertently changing other metadata.
+# pyrefly: ignore # bad-argument-type
 r = self._backward_error(r)

 s = t.storage

@@ -1839,6 +1852,7 @@ class MetaConverter(Generic[_TensorT]):
 nt_tensor_id=t.nested_int
 )

+# pyrefly: ignore # bad-argument-type
 self.set_tensor_memo(t, r)

 return self._checked_get_tensor_memo(t)

@@ -1882,11 +1896,13 @@ class MetaConverter(Generic[_TensorT]):
 (t._is_view() and t._base is not None and t._base.is_sparse)
 ):
 self.miss += 1
+# pyrefly: ignore # bad-return
 return NotImplemented
 else:
 self.hit += 1
 elif torch.overrides.is_tensor_like(t):
 self.miss += 1
+# pyrefly: ignore # bad-return
 return NotImplemented
 else:
 # non-Tensor types don't count as hit or miss
@@ -92,6 +92,7 @@ def _make_grads(
 is_grads_batched: bool,
 ) -> tuple[_OptionalTensor, ...]:
 new_grads: list[_OptionalTensor] = []
+# pyrefly: ignore # no-matching-overload
 for out, grad in zip(outputs, grads):
 out = cast(Union[torch.Tensor, graph.GradientEdge], out)
 out_size = None

@@ -341,6 +342,7 @@ def backward(
 Union[tuple[torch.Tensor], tuple[graph.GradientEdge]], (tensors,)
 )
 else:
+# pyrefly: ignore # bad-argument-type
 tensors = tuple(tensors)

 grad_tensors_ = _tensor_or_tensors_to_tuple(grad_tensors, len(tensors))

@@ -440,10 +442,12 @@ def grad(
 Union[Sequence[torch.Tensor], Sequence[graph.GradientEdge]], (outputs,)
 )
 else:
+# pyrefly: ignore # bad-argument-type
 outputs = tuple(outputs)
 if is_tensor_like(inputs) or isinstance(inputs, graph.GradientEdge):
 inputs = cast(_TensorOrTensorsOrGradEdge, (inputs,))
 else:
+# pyrefly: ignore # bad-argument-type
 inputs = tuple(inputs)
 t_outputs = tuple(i for i in outputs if is_tensor_like(i))
 t_inputs = tuple(i for i in inputs if is_tensor_like(i))
@@ -15,12 +15,14 @@ class Type(Function):
 "please use `torch.tensor.to(dtype=dtype)` instead.",
 category=FutureWarning,
 )
+# pyrefly: ignore # bad-override
 def forward(ctx, i, dest_type):
 ctx.input_type = type(i)
 ctx.input_device = -1 if not i.is_cuda else i.get_device()
 return i.type(dest_type)

 @staticmethod
+# pyrefly: ignore # bad-override
 def backward(ctx, grad_output):
 if ctx.input_device == -1:
 return grad_output.type(ctx.input_type), None

@@ -32,6 +34,7 @@ class Type(Function):
 # TODO: deprecate this
 class Resize(Function):
 @staticmethod
+# pyrefly: ignore # bad-override
 def forward(ctx, tensor, sizes):
 ctx.sizes = sizes
 ctx.numel = reduce(operator.mul, sizes, 1)

@@ -60,6 +63,7 @@ class Resize(Function):
 return tensor.contiguous().view(*sizes)

 @staticmethod
+# pyrefly: ignore # bad-override
 def backward(ctx, grad_output):
 assert grad_output.numel() == ctx.numel
 return grad_output.contiguous().view(ctx.input_sizes), None
@@ -9,6 +9,8 @@ from typing_extensions import deprecated

 import torch
 import torch.testing
+
+# pyrefly: ignore # deprecated
 from torch._vmap_internals import _vmap, vmap
 from torch.overrides import is_tensor_like
 from torch.types import _TensorOrTensors

@@ -229,6 +229,7 @@ def get_gradient_edge(tensor: torch.Tensor) -> GradientEdge:

 # Note that output_nr default to 0 which is the right value
 # for the AccumulateGrad node.
+# pyrefly: ignore # bad-argument-type
 return GradientEdge(grad_fn, tensor.output_nr, ownership_token=token)

@@ -531,6 +532,7 @@ def register_multi_grad_hook(
 "expected this hook to be called inside a backward call"
 )
 count[id] = count.get(id, 0)
+# pyrefly: ignore # unsupported-operation
 buffer[id] = buffer.get(id, [None] * len_tensors)

 with lock:
@@ -731,6 +731,7 @@ class profile:
 return all_function_events


+# pyrefly: ignore # invalid-inheritance
 class record_function(_ContextDecorator):
 """Context manager/function decorator that adds a label to a code block/function when running autograd profiler.
 Label will only appear if CPU activity tracing is enabled.

@@ -778,7 +779,9 @@ class record_function(_ContextDecorator):
 # TODO: TorchScript ignores standard type annotation here
 # self.record: Optional["torch.classes.profiler._RecordFunction"] = None
 self.record = torch.jit.annotate(
-Optional["torch.classes.profiler._RecordFunction"], None
+# pyrefly: ignore # not-a-type
+Optional["torch.classes.profiler._RecordFunction"],
+None,
 )

 def __enter__(self):

@@ -101,12 +101,14 @@ class profile:

 records = _disable_profiler_legacy()
 parsed_results = _parse_legacy_records(records)
+# pyrefly: ignore # bad-assignment
 self.function_events = EventList(
 parsed_results,
 use_device="cuda" if self.use_cuda else None,
 profile_memory=self.profile_memory,
 with_flops=self.with_flops,
 )
+# pyrefly: ignore # missing-attribute
 self.function_events._build_tree()
 return False
@@ -48,10 +48,14 @@ class EventList(list):
 def _remove_dup_nodes(self):
 while True:
 to_delete = set()
+# pyrefly: ignore # bad-assignment
 for idx in range(len(self)):
 if (
+# pyrefly: ignore # index-error
 self[idx].cpu_parent is not None
+# pyrefly: ignore # index-error
 and self[idx].cpu_parent.name == self[idx].name
+# pyrefly: ignore # index-error
 and len(self[idx].cpu_parent.cpu_children) == 1
 ):
 self[idx].cpu_parent.cpu_children = self[idx].cpu_children

@@ -61,8 +65,11 @@ class EventList(list):
 to_delete.add(idx)
 if len(to_delete) == 0:
 break
+# pyrefly: ignore # bad-argument-type
 new_evts = [ev for ind, ev in enumerate(self) if ind not in to_delete]
+# pyrefly: ignore # missing-attribute
 self.clear()
+# pyrefly: ignore # missing-attribute
 self.extend(new_evts)

 def _populate_cpu_children(self):
@@ -496,7 +503,9 @@ class FunctionEvent(FormattedTimesMixin):
 self.id: int = id
 self.node_id: int = node_id
 self.name: str = name
+# pyrefly: ignore # bad-assignment
 self.overload_name: str = overload_name
+# pyrefly: ignore # bad-assignment
 self.trace_name: str = trace_name
 self.time_range: Interval = Interval(start_us, end_us)
 self.thread: int = thread

@@ -505,9 +514,13 @@ class FunctionEvent(FormattedTimesMixin):
 self.count: int = 1
 self.cpu_children: list[FunctionEvent] = []
 self.cpu_parent: Optional[FunctionEvent] = None
+# pyrefly: ignore # bad-assignment
 self.input_shapes: tuple[int, ...] = input_shapes
+# pyrefly: ignore # bad-assignment
 self.concrete_inputs: list[Any] = concrete_inputs
+# pyrefly: ignore # bad-assignment
 self.kwinputs: dict[str, Any] = kwinputs
+# pyrefly: ignore # bad-assignment
 self.stack: list = stack
 self.scope: int = scope
 self.use_device: Optional[str] = use_device

@@ -732,6 +745,7 @@ class FunctionEventAvg(FormattedTimesMixin):
 self.self_device_memory_usage += other.self_device_memory_usage
 self.count += other.count
 if self.flops is None:
+# pyrefly: ignore # bad-assignment
 self.flops = other.flops
 elif other.flops is not None:
 self.flops += other.flops

@@ -967,6 +981,7 @@ def _build_table(
 "PFLOPs",
 ]
 assert flops > 0
+# pyrefly: ignore # no-matching-overload
 log_flops = max(0, min(math.log10(flops) / 3, float(len(flop_headers) - 1)))
 assert log_flops >= 0 and log_flops < len(flop_headers)
 return (pow(10, (math.floor(log_flops) * -3.0)), flop_headers[int(log_flops)])
@ -496,12 +496,14 @@ class cudaStatus:
|
||||||
|
|
||||||
class CudaError(RuntimeError):
|
class CudaError(RuntimeError):
|
||||||
def __init__(self, code: int) -> None:
|
def __init__(self, code: int) -> None:
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
|
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code))
|
||||||
super().__init__(f"{msg} ({code})")
|
super().__init__(f"{msg} ({code})")
|
||||||
|
|
||||||
|
|
||||||
def check_error(res: int) -> None:
|
def check_error(res: int) -> None:
|
||||||
r"""Raise an error if the result of a CUDA runtime API call is not success."""
|
r"""Raise an error if the result of a CUDA runtime API call is not success."""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if res != _cudart.cudaError.success:
|
if res != _cudart.cudaError.success:
|
||||||
raise CudaError(res)
|
raise CudaError(res)
|
||||||
|
|
||||||
|
|
@ -601,6 +603,7 @@ def get_device_capability(device: "Device" = None) -> tuple[int, int]:
|
||||||
return prop.major, prop.minor
|
return prop.major, prop.minor
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore # not-a-type
|
||||||
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
|
def get_device_properties(device: "Device" = None) -> _CudaDeviceProperties:
|
||||||
r"""Get the properties of a device.
|
r"""Get the properties of a device.
|
||||||
|
|
||||||
|
|
@ -651,6 +654,7 @@ class StreamContext:
|
||||||
self.idx = _get_device_index(None, True)
|
self.idx = _get_device_index(None, True)
|
||||||
if not torch.jit.is_scripting():
|
if not torch.jit.is_scripting():
|
||||||
if self.idx is None:
|
if self.idx is None:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
self.idx = -1
|
self.idx = -1
|
||||||
|
|
||||||
self.src_prev_stream = (
|
self.src_prev_stream = (
|
||||||
|
|
@ -953,7 +957,9 @@ def _device_count_amdsmi() -> int:
|
||||||
if raw_cnt <= 0:
|
if raw_cnt <= 0:
|
||||||
return raw_cnt
|
return raw_cnt
|
||||||
# Trim the list up to a maximum available device
|
# Trim the list up to a maximum available device
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
for idx, val in enumerate(visible_devices):
|
for idx, val in enumerate(visible_devices):
|
||||||
|
# pyrefly: ignore # redundant-cast
|
||||||
if cast(int, val) >= raw_cnt:
|
if cast(int, val) >= raw_cnt:
|
||||||
return idx
|
return idx
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
@ -987,7 +993,9 @@ def _device_count_nvml() -> int:
|
||||||
if raw_cnt <= 0:
|
if raw_cnt <= 0:
|
||||||
return raw_cnt
|
return raw_cnt
|
||||||
# Trim the list up to a maximum available device
|
# Trim the list up to a maximum available device
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
for idx, val in enumerate(visible_devices):
|
for idx, val in enumerate(visible_devices):
|
||||||
|
# pyrefly: ignore # redundant-cast
|
||||||
if cast(int, val) >= raw_cnt:
|
if cast(int, val) >= raw_cnt:
|
||||||
return idx
|
return idx
|
||||||
except OSError:
|
except OSError:
|
||||||
|
|
@ -1203,7 +1211,9 @@ def _get_pynvml_handler(device: "Device" = None):
|
||||||
if not _HAS_PYNVML:
|
if not _HAS_PYNVML:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
"pynvml does not seem to be installed or it can't be imported."
|
"pynvml does not seem to be installed or it can't be imported."
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
) from _PYNVML_ERR
|
) from _PYNVML_ERR
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from pynvml import NVMLError_DriverNotLoaded
|
from pynvml import NVMLError_DriverNotLoaded
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -1220,6 +1230,7 @@ def _get_amdsmi_handler(device: "Device" = None):
|
||||||
if not _HAS_PYNVML:
|
if not _HAS_PYNVML:
|
||||||
raise ModuleNotFoundError(
|
raise ModuleNotFoundError(
|
||||||
"amdsmi does not seem to be installed or it can't be imported."
|
"amdsmi does not seem to be installed or it can't be imported."
|
||||||
|
# pyrefly: ignore # invalid-inheritance
|
||||||
) from _PYNVML_ERR
|
) from _PYNVML_ERR
|
||||||
try:
|
try:
|
||||||
amdsmi.amdsmi_init()
|
amdsmi.amdsmi_init()
|
||||||
|
|
@ -1483,6 +1494,7 @@ def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int
|
||||||
return default_generator.get_offset()
|
return default_generator.get_offset()
|
||||||
|
|
||||||
|
|
||||||
|
# pyrefly: ignore # deprecated
|
||||||
from .memory import * # noqa: F403
|
from .memory import * # noqa: F403
|
||||||
from .random import * # noqa: F403
|
from .random import * # noqa: F403
|
||||||
|
|
||||||
|
|
@ -1699,6 +1711,7 @@ def _register_triton_kernels():
|
||||||
def kernel_impl(*args, **kwargs):
|
def kernel_impl(*args, **kwargs):
|
||||||
from torch.sparse._triton_ops import bsr_dense_mm
|
from torch.sparse._triton_ops import bsr_dense_mm
|
||||||
|
|
||||||
|
# pyrefly: ignore # not-callable
|
||||||
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
|
return bsr_dense_mm(*args, skip_checks=True, **kwargs)
|
||||||
|
|
||||||
@_WrappedTritonKernel
|
@_WrappedTritonKernel
|
||||||
|
|
|
||||||
|
|
@ -279,6 +279,7 @@ class _CudaModule:
|
||||||
return self._kernels[name]
|
return self._kernels[name]
|
||||||
|
|
||||||
# Import the CUDA library inside the method
|
# Import the CUDA library inside the method
|
||||||
|
# pyrefly: ignore # missing-module-attribute
|
||||||
from torch.cuda._utils import _get_gpu_runtime_library
|
from torch.cuda._utils import _get_gpu_runtime_library
|
||||||
|
|
||||||
libcuda = _get_gpu_runtime_library()
|
libcuda = _get_gpu_runtime_library()
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
# pyrefly: ignore # deprecated
|
||||||
from .autocast_mode import autocast, custom_bwd, custom_fwd
|
from .autocast_mode import autocast, custom_bwd, custom_fwd
|
||||||
from .common import amp_definitely_not_available
|
from .common import amp_definitely_not_available
|
||||||
from .grad_scaler import GradScaler
|
from .grad_scaler import GradScaler
|
||||||
|
|
|
||||||
|
|
@ -259,6 +259,7 @@ class graph:
|
||||||
self.cuda_graph.capture_begin(
|
self.cuda_graph.capture_begin(
|
||||||
# type: ignore[misc]
|
# type: ignore[misc]
|
||||||
*self.pool,
|
*self.pool,
|
||||||
|
# pyrefly: ignore # bad-keyword-argument
|
||||||
capture_error_mode=self.capture_error_mode,
|
capture_error_mode=self.capture_error_mode,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -524,6 +525,7 @@ def make_graphed_callables(
|
||||||
) -> Callable[..., object]:
|
) -> Callable[..., object]:
|
||||||
class Graphed(torch.autograd.Function):
|
class Graphed(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
|
def forward(ctx: object, *inputs: Tensor) -> tuple[Tensor, ...]:
|
||||||
# At this stage, only the user args may (potentially) be new tensors.
|
# At this stage, only the user args may (potentially) be new tensors.
|
||||||
for i in range(len_user_args):
|
for i in range(len_user_args):
|
||||||
|
|
@ -535,6 +537,7 @@ def make_graphed_callables(
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@torch.autograd.function.once_differentiable
|
@torch.autograd.function.once_differentiable
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
|
def backward(ctx: object, *grads: Tensor) -> tuple[Tensor, ...]:
|
||||||
assert len(grads) == len(static_grad_outputs)
|
assert len(grads) == len(static_grad_outputs)
|
||||||
for g, grad in zip(static_grad_outputs, grads):
|
for g, grad in zip(static_grad_outputs, grads):
|
||||||
|
|
@ -548,7 +551,9 @@ def make_graphed_callables(
|
||||||
# Input args that didn't require grad expect a None gradient.
|
# Input args that didn't require grad expect a None gradient.
|
||||||
assert isinstance(static_grad_inputs, tuple)
|
assert isinstance(static_grad_inputs, tuple)
|
||||||
return tuple(
|
return tuple(
|
||||||
b.detach() if b is not None else b for b in static_grad_inputs
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
b.detach() if b is not None else b
|
||||||
|
for b in static_grad_inputs
|
||||||
)
|
)
|
||||||
|
|
||||||
def functionalized(*user_args: object) -> object:
|
def functionalized(*user_args: object) -> object:
|
||||||
|
|
|
||||||
|
|
@ -770,6 +770,7 @@ def list_gpu_processes(device: "Device" = None) -> str:
|
||||||
import pynvml # type: ignore[import]
|
import pynvml # type: ignore[import]
|
||||||
except ModuleNotFoundError:
|
except ModuleNotFoundError:
|
||||||
return "pynvml module not found, please install pynvml"
|
return "pynvml module not found, please install pynvml"
|
||||||
|
# pyrefly: ignore # import-error
|
||||||
from pynvml import NVMLError_DriverNotLoaded
|
from pynvml import NVMLError_DriverNotLoaded
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
@ -852,6 +853,7 @@ def _record_memory_history_legacy(
|
||||||
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
|
_C._cuda_record_memory_history_legacy( # type: ignore[call-arg]
|
||||||
enabled,
|
enabled,
|
||||||
record_context,
|
record_context,
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
trace_alloc_max_entries,
|
trace_alloc_max_entries,
|
||||||
trace_alloc_record_context,
|
trace_alloc_record_context,
|
||||||
record_context_cpp,
|
record_context_cpp,
|
||||||
|
|
|
||||||
|
|
@ -53,6 +53,7 @@ def range_start(msg) -> int:
|
||||||
Args:
|
Args:
|
||||||
msg (str): ASCII message to associate with the range.
|
msg (str): ASCII message to associate with the range.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return _nvtx.rangeStartA(msg)
|
return _nvtx.rangeStartA(msg)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -63,6 +64,7 @@ def range_end(range_id) -> None:
|
||||||
Args:
|
Args:
|
||||||
range_id (int): an unique handle for the start range.
|
range_id (int): an unique handle for the start range.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
_nvtx.rangeEnd(range_id)
|
_nvtx.rangeEnd(range_id)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -83,6 +85,7 @@ def _device_range_start(msg: str, stream: int = 0) -> object:
|
||||||
msg (str): ASCII message to associate with the range.
|
msg (str): ASCII message to associate with the range.
|
||||||
stream (int): CUDA stream id.
|
stream (int): CUDA stream id.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
return _nvtx.deviceRangeStart(msg, stream)
|
return _nvtx.deviceRangeStart(msg, stream)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -95,6 +98,7 @@ def _device_range_end(range_handle: object, stream: int = 0) -> None:
|
||||||
range_handle: an unique handle for the start range.
|
range_handle: an unique handle for the start range.
|
||||||
stream (int): CUDA stream id.
|
stream (int): CUDA stream id.
|
||||||
"""
|
"""
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
_nvtx.deviceRangeEnd(range_handle, stream)
|
_nvtx.deviceRangeEnd(range_handle, stream)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -436,6 +436,7 @@ def load(
|
||||||
print(ep(torch.randn(5)))
|
print(ep(torch.randn(5)))
|
||||||
"""
|
"""
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
extra_files = extra_files or {}
|
extra_files = extra_files or {}
|
||||||
|
|
|
||||||
|
|
@ -295,6 +295,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
||||||
|
|
||||||
self.logger.addHandler(self)
|
self.logger.addHandler(self)
|
||||||
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
|
self.prev_get_dtrace = torch._logging._internal.GET_DTRACE_STRUCTURED
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
torch._logging._internal.GET_DTRACE_STRUCTURED = True
|
torch._logging._internal.GET_DTRACE_STRUCTURED = True
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -302,6 +303,7 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler):
|
||||||
self.log_record = LogRecord()
|
self.log_record = LogRecord()
|
||||||
self.expression_created_logs = {}
|
self.expression_created_logs = {}
|
||||||
self.logger.removeHandler(self)
|
self.logger.removeHandler(self)
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
|
torch._logging._internal.GET_DTRACE_STRUCTURED = self.prev_get_dtrace
|
||||||
self.prev_get_dtrace = False
|
self.prev_get_dtrace = False
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -107,8 +107,11 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None:
|
||||||
return
|
return
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
arg.op == "call_function"
|
arg.op == "call_function"
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
and arg.target == operator.getitem
|
and arg.target == operator.getitem
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
and arg.args[1] == i
|
and arg.args[1] == i
|
||||||
):
|
):
|
||||||
log.debug(
|
log.debug(
|
||||||
|
|
|
||||||
|
|
@ -185,6 +185,7 @@ def _ignore_backend_decomps():
|
||||||
def _disable_custom_triton_op_functional_decomposition():
|
def _disable_custom_triton_op_functional_decomposition():
|
||||||
old = torch._functorch.config.decompose_custom_triton_ops
|
old = torch._functorch.config.decompose_custom_triton_ops
|
||||||
try:
|
try:
|
||||||
|
# pyrefly: ignore # bad-assignment
|
||||||
torch._functorch.config.decompose_custom_triton_ops = False
|
torch._functorch.config.decompose_custom_triton_ops = False
|
||||||
yield torch._functorch.config.decompose_custom_triton_ops
|
yield torch._functorch.config.decompose_custom_triton_ops
|
||||||
finally:
|
finally:
|
||||||
|
|
@ -365,6 +366,7 @@ def _normalize_nn_module_stack(gm_torch_level, root_cls):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def __getitem__(self, idx):
|
def __getitem__(self, idx):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
parts.append(str(idx))
|
parts.append(str(idx))
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -660,6 +662,7 @@ def _rename_constants_nodes(
|
||||||
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
|
if spec.kind == InputKind.CONSTANT_TENSOR and not spec.arg.name.startswith(
|
||||||
const_prefix
|
const_prefix
|
||||||
):
|
):
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
|
if spec.arg.name.startswith(buffer_prefix): # map from buffer to constants
|
||||||
c_name = rename_constant(
|
c_name = rename_constant(
|
||||||
const_prefix + spec.arg.name[len(buffer_prefix) :]
|
const_prefix + spec.arg.name[len(buffer_prefix) :]
|
||||||
|
|
|
||||||
|
|
@ -293,6 +293,7 @@ class _ExportPackage:
|
||||||
if isinstance(fn, torch.nn.Module):
|
if isinstance(fn, torch.nn.Module):
|
||||||
dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type]
|
dynamic_shapes = v(fn, *args, **kwargs) # type: ignore[arg-type]
|
||||||
else:
|
else:
|
||||||
|
# pyrefly: ignore # invalid-param-spec
|
||||||
dynamic_shapes = v(*args, **kwargs)
|
dynamic_shapes = v(*args, **kwargs)
|
||||||
except AssertionError:
|
except AssertionError:
|
||||||
continue
|
continue
|
||||||
|
|
@ -340,6 +341,7 @@ class _ExportPackage:
|
||||||
assert not hasattr(fn, "_define_overload")
|
assert not hasattr(fn, "_define_overload")
|
||||||
_exporter_context._define_overload = _define_overload # type: ignore[attr-defined]
|
_exporter_context._define_overload = _define_overload # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return _exporter_context
|
return _exporter_context
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|
@ -376,6 +378,7 @@ class _ExportPackage:
|
||||||
kwargs=ep.example_inputs[1],
|
kwargs=ep.example_inputs[1],
|
||||||
options=options,
|
options=options,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
aoti_files_map[name] = aoti_files
|
aoti_files_map[name] = aoti_files
|
||||||
|
|
||||||
from torch._inductor.package import package
|
from torch._inductor.package import package
|
||||||
|
|
|
||||||
|
|
@ -1500,6 +1500,7 @@ class ExportedProgram:
|
||||||
transformed_gm = res.graph_module if res is not None else self.graph_module
|
transformed_gm = res.graph_module if res is not None else self.graph_module
|
||||||
assert transformed_gm is not None
|
assert transformed_gm is not None
|
||||||
|
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
if transformed_gm is self.graph_module and not res.modified:
|
if transformed_gm is self.graph_module and not res.modified:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
@ -1578,6 +1579,7 @@ class ExportedProgram:
|
||||||
verifiers=self.verifiers,
|
verifiers=self.verifiers,
|
||||||
)
|
)
|
||||||
transformed_ep.graph_module.meta.update(self.graph_module.meta)
|
transformed_ep.graph_module.meta.update(self.graph_module.meta)
|
||||||
|
# pyrefly: ignore # missing-attribute
|
||||||
transformed_ep.graph_module.meta.update(res.graph_module.meta)
|
transformed_ep.graph_module.meta.update(res.graph_module.meta)
|
||||||
return transformed_ep
|
return transformed_ep
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,7 @@ def move_to_device_pass(
|
||||||
and node.target == torch.ops.aten.to.device
|
and node.target == torch.ops.aten.to.device
|
||||||
):
|
):
|
||||||
args = list(node.args)
|
args = list(node.args)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
args[1] = _get_new_device(args[1], location)
|
args[1] = _get_new_device(args[1], location)
|
||||||
node.args = tuple(args)
|
node.args = tuple(args)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -172,8 +172,10 @@ class PT2ArchiveWriter:
|
||||||
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
|
os.path.isfile, glob.glob(f"{folder_dir}/**", recursive=True)
|
||||||
)
|
)
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
filename = os.path.relpath(file_path, folder_dir)
|
filename = os.path.relpath(file_path, folder_dir)
|
||||||
archive_path = os.path.join(archive_dir, filename)
|
archive_path = os.path.join(archive_dir, filename)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
self.write_file(archive_path, file_path)
|
self.write_file(archive_path, file_path)
|
||||||
|
|
||||||
def close(self) -> None:
|
def close(self) -> None:
|
||||||
|
|
@ -593,6 +595,7 @@ def package_pt2(
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
|
(isinstance(f, (io.IOBase, IO)) and f.writable() and f.seekable())
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
||||||
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
|
or (isinstance(f, tempfile._TemporaryFileWrapper) and f.name.endswith(".pt2"))
|
||||||
):
|
):
|
||||||
|
|
@ -604,8 +607,10 @@ def package_pt2(
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
with PT2ArchiveWriter(f) as archive_writer:
|
with PT2ArchiveWriter(f) as archive_writer:
|
||||||
_package_exported_programs(
|
_package_exported_programs(
|
||||||
archive_writer, exported_programs, pickle_protocol=pickle_protocol
|
archive_writer, exported_programs, pickle_protocol=pickle_protocol
|
||||||
|
|
@ -620,6 +625,7 @@ def package_pt2(
|
||||||
|
|
||||||
if isinstance(f, (io.IOBase, IO)):
|
if isinstance(f, (io.IOBase, IO)):
|
||||||
f.seek(0)
|
f.seek(0)
|
||||||
|
# pyrefly: ignore # bad-return
|
||||||
return f
|
return f
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -992,6 +998,7 @@ def load_pt2(
|
||||||
|
|
||||||
if not (
|
if not (
|
||||||
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
|
(isinstance(f, (io.IOBase, IO)) and f.readable() and f.seekable())
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
or (isinstance(f, (str, os.PathLike)) and os.fspath(f).endswith(".pt2"))
|
||||||
):
|
):
|
||||||
# TODO: turn this into an error in 2.9
|
# TODO: turn this into an error in 2.9
|
||||||
|
|
@ -1002,10 +1009,12 @@ def load_pt2(
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(f, (str, os.PathLike)):
|
if isinstance(f, (str, os.PathLike)):
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
f = os.fspath(f)
|
f = os.fspath(f)
|
||||||
|
|
||||||
weights = {}
|
weights = {}
|
||||||
weight_maps = {}
|
weight_maps = {}
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
with PT2ArchiveReader(f) as archive_reader:
|
with PT2ArchiveReader(f) as archive_reader:
|
||||||
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
|
version = archive_reader.read_string(ARCHIVE_VERSION_PATH)
|
||||||
if version != ARCHIVE_VERSION_VALUE:
|
if version != ARCHIVE_VERSION_VALUE:
|
||||||
|
|
@ -1070,7 +1079,12 @@ def load_pt2(
|
||||||
else:
|
else:
|
||||||
aoti_runners = {
|
aoti_runners = {
|
||||||
model_name: _load_aoti(
|
model_name: _load_aoti(
|
||||||
f, model_name, run_single_threaded, num_runners, device_index
|
# pyrefly: ignore # bad-argument-type
|
||||||
|
f,
|
||||||
|
model_name,
|
||||||
|
run_single_threaded,
|
||||||
|
num_runners,
|
||||||
|
device_index,
|
||||||
)
|
)
|
||||||
for model_name in aoti_model_names
|
for model_name in aoti_model_names
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -937,6 +937,7 @@ def _check_graph_equivalence(x: torch.nn.Module, y: torch.nn.Module):
|
||||||
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
|
for key, value in pytree.tree_map(arg_dump, node.kwargs).items()
|
||||||
]
|
]
|
||||||
target = node.target if node.op in ("call_function", "get_attr") else ""
|
target = node.target if node.op in ("call_function", "get_attr") else ""
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
|
ret.append(f"{i}: {node.op}[{target}]({', '.join(args_dump)})")
|
||||||
nodes_idx[id(node)] = i
|
nodes_idx[id(node)] = i
|
||||||
return "\n".join(ret)
|
return "\n".join(ret)
|
||||||
|
|
@ -1473,6 +1474,7 @@ class _ModuleFrame:
|
||||||
self.seen_attrs[self.child_fqn].add(node.target)
|
self.seen_attrs[self.child_fqn].add(node.target)
|
||||||
|
|
||||||
self.copy_node(node)
|
self.copy_node(node)
|
||||||
|
# pyrefly: ignore # unsupported-operation
|
||||||
node_idx += 1
|
node_idx += 1
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -483,6 +483,7 @@ def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
|
||||||
raise IndexError(
|
raise IndexError(
|
||||||
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
|
f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
dims.append(d % ndim)
|
dims.append(d % ndim)
|
||||||
return tuple(sorted(dims))
|
return tuple(sorted(dims))
|
||||||
|
|
||||||
|
|
@ -641,6 +642,7 @@ def _sparse_coo_scatter_reduction_helper(
|
||||||
|
|
||||||
# promote dtype if specified
|
# promote dtype if specified
|
||||||
if values.dtype != output_dtype:
|
if values.dtype != output_dtype:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
values = values.to(output_dtype)
|
values = values.to(output_dtype)
|
||||||
|
|
||||||
if keepdim:
|
if keepdim:
|
||||||
|
|
@ -765,6 +767,7 @@ def _sparse_csr_segment_reduction_helper(
|
||||||
|
|
||||||
# promote dtype if specified
|
# promote dtype if specified
|
||||||
if values.dtype != output_dtype:
|
if values.dtype != output_dtype:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
values = values.to(output_dtype)
|
values = values.to(output_dtype)
|
||||||
|
|
||||||
if len(dims) == 0:
|
if len(dims) == 0:
|
||||||
|
|
@ -1015,6 +1018,7 @@ def _combine_input_and_mask(
|
||||||
|
|
||||||
class Combine(torch.autograd.Function):
|
class Combine(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, input, mask):
|
def forward(ctx, input, mask):
|
||||||
"""Return input with masked-out elements eliminated for the given operations."""
|
"""Return input with masked-out elements eliminated for the given operations."""
|
||||||
ctx.save_for_backward(mask)
|
ctx.save_for_backward(mask)
|
||||||
|
|
@ -1025,6 +1029,7 @@ def _combine_input_and_mask(
|
||||||
return helper(input, mask)
|
return helper(input, mask)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
(mask,) = ctx.saved_tensors
|
(mask,) = ctx.saved_tensors
|
||||||
grad_data = (
|
grad_data = (
|
||||||
|
|
@ -1399,15 +1404,18 @@ elements, have ``nan`` values.
|
||||||
if input.layout == torch.strided:
|
if input.layout == torch.strided:
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# TODO: compute count analytically
|
# TODO: compute count analytically
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
count = sum(
|
count = sum(
|
||||||
torch.ones(input.shape, dtype=torch.int64, device=input.device),
|
torch.ones(input.shape, dtype=torch.int64, device=input.device),
|
||||||
dim,
|
dim,
|
||||||
keepdim=keepdim,
|
keepdim=keepdim,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
|
total = sum(input, dim, keepdim=keepdim, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
inmask = _input_mask(input, mask=mask)
|
inmask = _input_mask(input, mask=mask)
|
||||||
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
|
count = inmask.sum(dim=dim, keepdim=bool(keepdim))
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
|
total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
|
||||||
return total / count
|
return total / count
|
||||||
elif input.layout == torch.sparse_csr:
|
elif input.layout == torch.sparse_csr:
|
||||||
|
|
@ -1618,15 +1626,18 @@ def _std_var(
|
||||||
if input.layout == torch.strided:
|
if input.layout == torch.strided:
|
||||||
if mask is None:
|
if mask is None:
|
||||||
# TODO: compute count analytically
|
# TODO: compute count analytically
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
count = sum(
|
count = sum(
|
||||||
torch.ones(input.shape, dtype=torch.int64, device=input.device),
|
torch.ones(input.shape, dtype=torch.int64, device=input.device),
|
||||||
dim,
|
dim,
|
||||||
keepdim=True,
|
keepdim=True,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
|
sample_total = sum(input, dim, keepdim=True, dtype=dtype)
|
||||||
else:
|
else:
|
||||||
inmask = _input_mask(input, mask=mask)
|
inmask = _input_mask(input, mask=mask)
|
||||||
count = inmask.sum(dim=dim, keepdim=True)
|
count = inmask.sum(dim=dim, keepdim=True)
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
|
sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
|
||||||
# TODO: replace torch.subtract/divide/square/maximum with
|
# TODO: replace torch.subtract/divide/square/maximum with
|
||||||
# masked subtract/divide/square/maximum when these will be
|
# masked subtract/divide/square/maximum when these will be
|
||||||
|
|
@ -1634,6 +1645,7 @@ def _std_var(
|
||||||
sample_mean = torch.divide(sample_total, count)
|
sample_mean = torch.divide(sample_total, count)
|
||||||
x = torch.subtract(input, sample_mean)
|
x = torch.subtract(input, sample_mean)
|
||||||
if mask is None:
|
if mask is None:
|
||||||
|
# pyrefly: ignore # no-matching-overload
|
||||||
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
|
total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
|
||||||
else:
|
else:
|
||||||
total = sum(
|
total = sum(
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,7 @@ def _check_args_kwargs_length(
|
||||||
|
|
||||||
class _MaskedContiguous(torch.autograd.Function):
|
class _MaskedContiguous(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
if not is_masked_tensor(input):
|
if not is_masked_tensor(input):
|
||||||
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
|
raise ValueError("MaskedContiguous forward: input must be a MaskedTensor.")
|
||||||
|
|
@ -60,12 +61,14 @@ class _MaskedContiguous(torch.autograd.Function):
|
||||||
return MaskedTensor(data.contiguous(), mask.contiguous())
|
return MaskedTensor(data.contiguous(), mask.contiguous())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
return grad_output
|
return grad_output
|
||||||
|
|
||||||
|
|
||||||
class _MaskedToDense(torch.autograd.Function):
|
class _MaskedToDense(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
if not is_masked_tensor(input):
|
if not is_masked_tensor(input):
|
||||||
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
|
raise ValueError("MaskedToDense forward: input must be a MaskedTensor.")
|
||||||
|
|
@ -80,6 +83,7 @@ class _MaskedToDense(torch.autograd.Function):
|
||||||
return MaskedTensor(data.to_dense(), mask.to_dense())
|
return MaskedTensor(data.to_dense(), mask.to_dense())
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
layout = ctx.layout
|
layout = ctx.layout
|
||||||
|
|
||||||
|
|
@ -94,6 +98,7 @@ class _MaskedToDense(torch.autograd.Function):
|
||||||
|
|
||||||
class _MaskedToSparse(torch.autograd.Function):
|
class _MaskedToSparse(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
if not is_masked_tensor(input):
|
if not is_masked_tensor(input):
|
||||||
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
|
raise ValueError("MaskedToSparse forward: input must be a MaskedTensor.")
|
||||||
|
|
@ -110,12 +115,14 @@ class _MaskedToSparse(torch.autograd.Function):
|
||||||
return MaskedTensor(sparse_data, sparse_mask)
|
return MaskedTensor(sparse_data, sparse_mask)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
return grad_output.to_dense()
|
return grad_output.to_dense()
|
||||||
|
|
||||||
|
|
||||||
class _MaskedToSparseCsr(torch.autograd.Function):
|
class _MaskedToSparseCsr(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, input):
|
def forward(ctx, input):
|
||||||
if not is_masked_tensor(input):
|
if not is_masked_tensor(input):
|
||||||
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
|
raise ValueError("MaskedToSparseCsr forward: input must be a MaskedTensor.")
|
||||||
|
|
@ -136,18 +143,21 @@ class _MaskedToSparseCsr(torch.autograd.Function):
|
||||||
return MaskedTensor(sparse_data, sparse_mask)
|
return MaskedTensor(sparse_data, sparse_mask)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
return grad_output.to_dense()
|
return grad_output.to_dense()
|
||||||
|
|
||||||
|
|
||||||
class _MaskedWhere(torch.autograd.Function):
|
class _MaskedWhere(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, cond, self, other):
|
def forward(ctx, cond, self, other):
|
||||||
ctx.mark_non_differentiable(cond)
|
ctx.mark_non_differentiable(cond)
|
||||||
ctx.save_for_backward(cond)
|
ctx.save_for_backward(cond)
|
||||||
return torch.ops.aten.where(cond, self, other)
|
return torch.ops.aten.where(cond, self, other)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
(cond,) = ctx.saved_tensors
|
(cond,) = ctx.saved_tensors
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -174,6 +174,7 @@ class MaskedTensor(torch.Tensor):
|
||||||
UserWarning,
|
UserWarning,
|
||||||
stacklevel=2,
|
stacklevel=2,
|
||||||
)
|
)
|
||||||
|
# pyrefly: ignore # bad-argument-type
|
||||||
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
|
return torch.Tensor._make_wrapper_subclass(cls, data.size(), **kwargs)
|
||||||
|
|
||||||
def _preprocess_data(self, data, mask):
|
def _preprocess_data(self, data, mask):
|
||||||
|
|
@ -243,10 +244,12 @@ class MaskedTensor(torch.Tensor):
|
||||||
|
|
||||||
class Constructor(torch.autograd.Function):
|
class Constructor(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, data, mask):
|
def forward(ctx, data, mask):
|
||||||
return MaskedTensor(data, mask)
|
return MaskedTensor(data, mask)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
return grad_output, None
|
return grad_output, None
|
||||||
|
|
||||||
|
|
@ -333,10 +336,12 @@ class MaskedTensor(torch.Tensor):
|
||||||
def get_data(self):
|
def get_data(self):
|
||||||
class GetData(torch.autograd.Function):
|
class GetData(torch.autograd.Function):
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def forward(ctx, self):
|
def forward(ctx, self):
|
||||||
return self._masked_data.detach()
|
return self._masked_data.detach()
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
# pyrefly: ignore # bad-override
|
||||||
def backward(ctx, grad_output):
|
def backward(ctx, grad_output):
|
||||||
if is_masked_tensor(grad_output):
|
if is_masked_tensor(grad_output):
|
||||||
return grad_output
|
return grad_output
|
||||||
|
|
|
||||||
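
For readers new to the suppression syntax above: each "# pyrefly: ignore # <error-code>" comment sits on the line directly before the expression it silences and names the one diagnostic being suppressed. Below is a minimal, hypothetical sketch, not taken from this commit; the function and variable names are made up, and it assumes that pyrefly would report bad-assignment for the last line (an Optional[int] assigned to an int-annotated name).

from typing import Optional


def first_or_none(items: list[int]) -> Optional[int]:
    # May return None for an empty list, hence the Optional[int] return type.
    return items[0] if items else None


# The comment below suppresses exactly one diagnostic on the next line,
# keyed by its error code, instead of widening the annotation of "value".
# pyrefly: ignore # bad-assignment
value: int = first_or_none([1, 2, 3])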