[BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
Xuehai Pan 2025-02-28 15:35:13 +08:00 committed by PyTorch MergeBot
parent 34d726011f
commit 1cb4e2df65
88 changed files with 1157 additions and 954 deletions
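Nearly every hunk below is a mechanical restyle with no behavior change. Two patterns dominate: long assert statements now keep the condition on one line and wrap only the message in parentheses, and stacked context managers are collapsed into a single parenthesized `with`. A minimal illustrative sketch of both shapes (hypothetical names, not lines taken from the diff):

```python
from contextlib import nullcontext


def check_kwargs(kwargs=None):
    # Old (black) wrapping parenthesized the condition:
    #     assert (
    #         kwargs is None or len(kwargs) == 0
    #     ), "kwargs must be empty"
    # ruff format instead keeps the condition unwrapped and, when the statement
    # is too long for one line, parenthesizes the message (shown here with a
    # short message purely for illustration):
    assert kwargs is None or len(kwargs) == 0, (
        "kwargs must be empty"
    )


def nested_contexts():
    # Stacked managers that black spread across continuation lines become a
    # single parenthesized `with` (valid syntax on Python 3.10+):
    with (
        nullcontext("a") as a,
        nullcontext("b") as b,
    ):
        return a, b
```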

View File

@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
"torch/_i*/**",
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
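Dropping the `torch/_i*/**` glob from `USE_BLACK_FILELIST` above is what takes `torch/_inductor` off black so that ruff format owns it; the rest of the commit is the resulting reformat. A simplified, hypothetical sketch of how such a filelist could be consulted (the real check is a compiled regex in the lint tooling; `fnmatch` is used here only for illustration):

```python
import fnmatch

# Hypothetical stand-in for the lint tooling's filelist check.
USE_BLACK_FILELIST = [
    "torch/_[e-h]*/**",
    # "torch/_i*/**",  # removed by this commit: torch/_inductor now uses ruff format
    "torch/_[j-z]*/**",
    "torch/[a-c]*/**",
]


def formatted_with_black(path: str) -> bool:
    return any(fnmatch.fnmatch(path, pattern) for pattern in USE_BLACK_FILELIST)


print(formatted_with_black("torch/_inductor/codecache.py"))  # False after this change
print(formatted_with_black("torch/_prims/__init__.py"))      # still matched by the remaining globs
```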

View File

@ -66,7 +66,9 @@ def aoti_compile_and_package(
.. code-block:: python
ep = torch.export.export(M(), ...)
aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2")
aoti_file = torch._inductor.aoti_compile_and_package(
ep, package_path="my_package.pt2"
)
compiled_model = torch._inductor.aoti_load_package("my_package.pt2")
To compile and save multiple models into a single ``.pt2`` artifact, you can do
@ -75,11 +77,16 @@ def aoti_compile_and_package(
.. code-block:: python
ep1 = torch.export.export(M1(), ...)
aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True})
aoti_file1 = torch._inductor.aot_compile(
ep1, ..., options={"aot_inductor.package": True}
)
ep2 = torch.export.export(M2(), ...)
aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True})
aoti_file2 = torch._inductor.aot_compile(
ep2, ..., options={"aot_inductor.package": True}
)
from torch._inductor.package import package_aoti, load_package
package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2})
compiled_model1 = load_package("my_package.pt2", "model1")
@ -123,7 +130,9 @@ def aoti_compile_and_package(
isinstance(package_path, (str, os.PathLike))
and os.fspath(package_path).endswith(".pt2")
)
), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
), (
f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
)
inductor_configs = inductor_configs or {}
inductor_configs["aot_inductor.package"] = True
@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner(
"""
if check_accuracy:
assert (
kwargs is None or len(kwargs) == 0
), "when checking for accuracy, the inputs must have been flattened and kwargs is None"
assert kwargs is None or len(kwargs) == 0, (
"when checking for accuracy, the inputs must have been flattened and kwargs is None"
)
from .package import package_aoti

View File

@ -156,8 +156,9 @@ def can_codegen_without_upcasts(
low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops)
# Need to turn off upcasting to do analysis of whether we can turn it off
with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler(
low_prec_analysis
with (
config.patch("triton.codegen_upcast_to_fp32", False),
V.set_ops_handler(low_prec_analysis),
):
prologue._body(*prologue.get_ranges())

View File

@ -245,8 +245,7 @@ class AsyncCompile:
def use_process_pool(self):
return (
get_compile_threads() > 1
and self.process_pool().ready_future.done() # type: ignore[union-attr]
get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr]
)
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):

View File

@ -24,8 +24,7 @@ if TYPE_CHECKING:
class Sortable(typing.Protocol):
"""Anything that can be used as a list.sort() key (int/tuple/etc)"""
def __lt__(self, other: typing.Self) -> bool:
...
def __lt__(self, other: typing.Self) -> bool: ...
class InductorChoices:
@ -100,7 +99,9 @@ class InductorChoices:
# to pick the faster one.
if config.triton.multi_kernel:
threshold *= 16
return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types]
return V.graph.sizevars.statically_known_leq(
features.reduction_numel, threshold
) # type: ignore[arg-types]
@staticmethod
def want_no_x_dim(features: SIMDKernelFeatures) -> bool:

View File

@ -417,9 +417,9 @@ def write_atomic(
) -> None:
# Write into temporary file first to avoid conflicts between threads
# Avoid using a named temporary file, as those have restricted permissions
assert isinstance(
content, (str, bytes)
), "Only strings and byte arrays can be saved in the cache"
assert isinstance(content, (str, bytes)), (
"Only strings and byte arrays can be saved in the cache"
)
path = Path(path_)
if make_dirs:
path.parent.mkdir(parents=True, exist_ok=True)
@ -975,9 +975,9 @@ class FxGraphCache:
symints = FxGraphCache._filter_backed_symints(example_inputs)
hints = [hint_int(s) for s in symints]
def iterate_over_candidates() -> (
Generator[tuple[CompiledFxGraph, bytes], None, None]
):
def iterate_over_candidates() -> Generator[
tuple[CompiledFxGraph, bytes], None, None
]:
if local:
subdir = FxGraphCache._get_tmp_dir_for_key(key)
if os.path.exists(subdir):
@ -1123,9 +1123,9 @@ class FxGraphCache:
"""
from .compile_fx import CompiledFxGraph
assert isinstance(
compiled_graph, CompiledFxGraph
), f"serialization for {type(compiled_graph)} NYI"
assert isinstance(compiled_graph, CompiledFxGraph), (
f"serialization for {type(compiled_graph)} NYI"
)
disk_compiled_graph = copy(compiled_graph)
disk_compiled_graph.prepare_for_serialization()
@ -1315,9 +1315,8 @@ class FxGraphCache:
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns
)
ephemeral_increase
:= add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
) != 0:
cache_info["ephemeral_timeout_increase"] = ephemeral_increase
else:
@ -1556,9 +1555,9 @@ class AotCodeCompiler:
cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json")
)
for k, v in config.aot_inductor.metadata.items():
assert isinstance(k, str) and isinstance(
v, (str)
), "Metadata must only contain strings"
assert isinstance(k, str) and isinstance(v, (str)), (
"Metadata must only contain strings"
)
with open(meta_json, "w") as f:
f.write(json.dumps(config.aot_inductor.metadata))

View File

@ -341,7 +341,7 @@ class BackendFeature(Enum):
def get_backend_features(
device: Union[torch.device, str, None]
device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
if device is None:
return OrderedSet()
@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
if cls._is_unimplemented(funcname):
setattr(cls, funcname, cls._unimplemented(funcname))
else:
assert (
funcname not in cls.__dict__
), f"multiple definitions of {funcname} on {cls.__name__}"
assert funcname not in cls.__dict__, (
f"multiple definitions of {funcname} on {cls.__name__}"
)
impl.__name__ = funcname
setattr(cls, funcname, staticmethod(impl))
@ -2229,7 +2229,7 @@ class KernelTemplate:
@staticmethod
def _fake_get_dtype(
fake_outs: Union[list[Buffer], Buffer]
fake_outs: Union[list[Buffer], Buffer],
) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
if isinstance(fake_outs, (list, tuple)):

View File

@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
outer_loop_fusion_depth,
):
self.outer_fused_nodes: list[
Union[FusedSchedulerNode, SchedulerNode]
] = outer_fused_nodes
self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
outer_fused_nodes
)
self.outer_loop_fusion_depth = outer_loop_fusion_depth
flatten_snodes = []
for _node in self.outer_fused_nodes:
@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def remainder(a, b):
assert (
a.dtype == b.dtype
), "remainder vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"remainder vec implementation expect the same inputs' dtype."
)
return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"
@staticmethod
@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def floordiv(a, b):
if is_float_dtype(a.dtype):
assert (
a.dtype == b.dtype
), "div_floor_floating_vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"div_floor_floating_vec implementation expect the same inputs' dtype."
)
return f"div_floor_floating_vec({a}, {b})"
else:
assert all(is_integer_dtype(item.dtype) for item in [a, b])
@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides):
assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[
Union[CppOverrides, CppVecOverrides]
] = V.kernel.overrides # type: ignore[has-type]
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
V.kernel.overrides
) # type: ignore[has-type]
code.writeline(
f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
)
@ -2108,9 +2108,9 @@ class CppKernel(Kernel):
def set_ranges(self, lengths, reduction_lengths):
if self.call_ranges:
assert self.call_ranges == tuple(lengths) + tuple(
reduction_lengths
), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
)
assert self.reduction_depth == len(lengths)
else:
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel):
self.weight_recps_val = self.weight_recps_cse.generate(
self.compute, f"reduction {self.weight_recp_vec_range}", write=False
)
self.weight_recps_cse.reduction_cache[
self.weight_recp_vec_range
] = self.weight_recps_val
self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = (
self.weight_recps_val
)
self.non_parallel_reduction_prefix.writeline(
self.welford_weight_reciprocal_vec(dtype)
)
@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling):
]
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
assert self.is_cpp_template(template_node), (
"Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()
@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling):
epilogue_ir_nodes: list[Optional[ir.Operation]] = [
n.node for n in epilogue_nodes
]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
"Epilogue nodes must all be instances of ir.ComputedBuffer"
)
def template_buffer_has_other_users(
template_buffer, outputs_by_name, epilogue_nodes
@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling):
if is_multi_outputs_template(template_node.node):
# For a multi-output template, allocate buffers for each output after the epilogue
# codegen, which determines whether the buffer has been removed.
assert (
len(template_node.outputs) == 1
), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
assert len(template_node.outputs) == 1, (
"Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
)
for user in template_node.outputs[0].users:
assert isinstance(
user.node, ExternKernelSchedulerNode
), "Multi outputs template should be with ExternKernelSchedulerNode"
assert isinstance(
user.node.node, ir.MultiOutput
), "Multi outputs template has multi users with MultiOutput"
assert isinstance(user.node, ExternKernelSchedulerNode), (
"Multi outputs template should be with ExternKernelSchedulerNode"
)
assert isinstance(user.node.node, ir.MultiOutput), (
"Multi outputs template has multi users with MultiOutput"
)
user.node.mark_run()
kernel.call_kernel(kernel_name, ctb)
@ -5347,9 +5347,9 @@ class LoopNest:
return self.loops is not None and self.loops[0].is_reduction
def mark_parallel(self, par_depth):
assert (
par_depth <= self.max_parallel_depth()
), "Parallel depth cannot exceed the maximal allowed parallel depth"
assert par_depth <= self.max_parallel_depth(), (
"Parallel depth cannot exceed the maximal allowed parallel depth"
)
assert self.loops is not None
assert len(self.loops) >= par_depth
loop = self.loops[0]

View File

@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate):
assert all(
mem.buffer_name in kernel_group.args.input_buffers
for mem in body.memory_usage[MemoryUsageType.LOAD]
), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
), (
"All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
)
bodies.append(body)
var_sizes_list.append((var_sizes, ()))

View File

@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate):
thread_block_m = math.ceil(m_blocks / m_factor)
return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
assert (
not self.is_dynamic_M
), "Unable to determine thread blocking for dynamic M."
assert not self.is_dynamic_M, (
"Unable to determine thread blocking for dynamic M."
)
register_blocking = self.register_blocking
m_blocks = math.ceil(self.m / register_blocking.block_m)
n_blocks = math.ceil(self.n / register_blocking.block_n)
@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate):
L1_cache_size = (
torch._C._cpu._L1d_cache_size()
) # per core cache size in Bytes
assert (
L1_cache_size > 0
), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
assert L1_cache_size > 0, (
f"Expect L1_cache_size > 0 but got {L1_cache_size}"
)
L1 = L1_cache_size * L1_limit_factor
L2_cache_size = (
torch._C._cpu._L2_cache_size()
) # per core cache size in Bytes
assert (
L2_cache_size > 0
), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
assert L2_cache_size > 0, (
f"Expect L2_cache_size > 0 but got {L2_cache_size}"
)
L2 = L2_cache_size * L2_limit_factor
def get_num_byte(dtype):
@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate):
return Mc_blocks, Nc_blocks, Kc_blocks
assert (
not self.is_dynamic_M
), "Unable to determine cache blocking for dynamic M."
assert not self.is_dynamic_M, (
"Unable to determine cache blocking for dynamic M."
)
register_blocking = self.register_blocking
thread_blocking = self.thread_blocking(num_threads)
@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate):
LayoutType.VNNI4,
], f"We only support {layout_str} for now"
vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
assert (
k % vnni_size == 0
), f"k should be divisible by vnni_size for {layout_str} layout"
assert k % vnni_size == 0, (
f"k should be divisible by vnni_size for {layout_str} layout"
)
vnni_view_size = list(new_size)
vnni_view_size[-2] = k // vnni_size
vnni_view_size.insert(-1, vnni_size)

View File

@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
] = W_tensor # type: ignore[assignment]
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
W_tensor # type: ignore[assignment]
)
new_input_nodes, _ = pack_weight(
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
)
@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
W_packed = new_input_nodes[idx]
assert isinstance(W_packed, torch.Tensor)
W_packed_constant = V.graph.add_tensor_constant(W_packed)
template_buffer.inputs[
idx
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
template_buffer.inputs[idx] = (
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
)
return output
template = DataProcessorTemplateWrapper(
@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
)
assert (
not self.epilogue_creator
), "epilogue_creator is not supported yet in Grouped GEMM Template"
assert not self.epilogue_creator, (
"epilogue_creator is not supported yet in Grouped GEMM Template"
)
kernel_args: dict[str, Optional[ir.IRNode]] = {}
for x_idx in range(wgt_start_idx):

View File

@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
def register_micro_gemm(*configs):
def inner(cls):
assert (
cls not in micro_gemm_configs
), f"Duplicate micro_gemm registration for {cls}"
assert cls not in micro_gemm_configs, (
f"Duplicate micro_gemm registration for {cls}"
)
assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
micro_gemm_configs[cls] = list(configs)
return cls

View File

@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate):
def generate(self, **kwargs):
kernel_name = f"cpp_{self.name}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
patch.object(ir.FlexibleLayout, "allow_indexing", True),
CppTemplateKernel(
kernel_name=kernel_name, num_threads=self.num_threads
) as kernel:
) as kernel,
):
code = kernel.render(self, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
log.debug("Generated Code:\n%s", code)

View File

@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel):
)
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
return self.store_pointwise_nodes(
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
dst,
epilogue_nodes, # type: ignore[arg-type]
offsets,
reindexers,
)
else:
if dst.get_name() != src.get_name():

View File

@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
Only valid when cuda == True.
"""
assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
assert arg_types is not None and len(call_args) == len(
arg_types
), "Mismatch call_args and arg_types in generate_kernel_call"
assert arg_types is not None and len(call_args) == len(arg_types), (
"Mismatch call_args and arg_types in generate_kernel_call"
)
new_args = []
for idx, arg in enumerate(call_args):
if "*" in arg_types[idx]:
@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
assert dtype is not None, (
"Fails to get the dtype of the sympy.Expr"
)
self.codegen_tensor_item(
dtype, f"inputs[{idx}]", input_key, self.prefix
)
@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
code.writeline(f"int32_t {name}_dtype;")
code.writeline(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
f"({name}, &{name}_dtype));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
)
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
# Tell compiler we need to link with the non-mangled symbols
for kernel in self.initialized_kernels.values():
assert hasattr(
kernel, "get_signature"
), f"{kernel} must have get_signature implemented"
assert hasattr(kernel, "get_signature"), (
f"{kernel} must have get_signature implemented"
)
signature = kernel.get_signature()
self.prefix.writeline(f'extern "C" {signature};')
@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
)
)
for name, kernel in self.initialized_kernels.items():
assert hasattr(
kernel, "get_signature"
), f"{kernel} must have get_signature implemented"
assert hasattr(kernel, "get_signature"), (
f"{kernel} must have get_signature implemented"
)
kernel_ptr = f"(*{name})"
signature = kernel.get_signature().replace(name, kernel_ptr)
self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")
@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
with self.prefix.indent():
for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
assert not isinstance(
inp, sympy.Expr
), f"input {name=} cannot be symbolic"
assert not isinstance(inp, sympy.Expr), (
f"input {name=} cannot be symbolic"
)
self.write_input_output_info("inputs_info_", idx, name)
all_cuda = all(
@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
tensor
)
assert (
opaque_metadata_tensor.dim() == 1
), "Expect opaque_metadata_tensor to be 1-D"
assert opaque_metadata_tensor.dim() == 1, (
"Expect opaque_metadata_tensor to be 1-D"
)
opaque_metadata_list = opaque_metadata_tensor.tolist()
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
)
for idx, output in enumerate(V.graph.graph_outputs):
assert not isinstance(
output, sympy.Expr
), f"output {name=} cannot be symbolic"
assert not isinstance(output, sympy.Expr), (
f"output {name=} cannot be symbolic"
)
name = f"output{idx}"
self.write_input_output_info("outputs_info_", idx, name)
@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
for idx, (name, _) in enumerate(V.graph.constants.items()):
if name in V.graph.const_output_index:
const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload]
assert (
None not in const_index_mapping
), "Not all constant gets mapped for constant folding graph."
assert None not in const_index_mapping, (
"Not all constant gets mapped for constant folding graph."
)
self.prefix.writeline(
f"""
@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
name = f"{output.get_name()}"
output_handle_name = f"{name}_handle"
if output.indices:
assert (
output.indices[0][1] == idx
), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
assert output.indices[0][1] == idx, (
f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
)
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_args.append(f"&{output_handle_name}")
output_raii_handles.append(
@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
args = args + output_args
device = d.type if (d := fallback_kernel.get_device()) else self.device
self.generate_c_shim_extern_kernel_call(
fallback_kernel.cpp_kernel_name, args, device # type: ignore[arg-type]
fallback_kernel.cpp_kernel_name, # type: ignore[arg-type]
args,
device,
)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
else:
assert (
reduce is None
), "Expect reduce to be None for aten.scatter_ with scalar src"
assert reduce is None, (
"Expect reduce to be None for aten.scatter_ with scalar src"
)
line += ");"
self.writeline(line)
@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen):
# Only treat int Scalar as dynamic
is_int_type = [isinstance(a, int) for a in arg]
if any(is_int_type):
assert all(
is_int_type
), "AOTInductor only supports int scalars of the same type"
assert all(is_int_type), (
"AOTInductor only supports int scalars of the same type"
)
new_int_args.extend([str(a) for a in arg])
else:
assert isinstance(
arg_type.getElementType(), static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
arg_type.getElementType(),
static_arg_types, # type: ignore[arg-type]
), (
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
)
else:
assert isinstance(
arg_type, static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
arg_type,
static_arg_types, # type: ignore[arg-type]
), (
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
)
for arg, arg_type in zip(raw_args, arg_types):
if arg is not None:
@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) {
return f"&{var_name}"
if isinstance(type_, torch.ListType):
assert isinstance(
val, (list, tuple)
), f"{val} does not match with arg type {type_}"
assert isinstance(val, (list, tuple)), (
f"{val} does not match with arg type {type_}"
)
element_type = type_.getElementType()
var_name = f"var_array_{next(self.var_array_id)}"
if len(val) == 0:

View File

@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
self.allow_stack_allocation: Optional[
bool
] = config.aot_inductor.allow_stack_allocation
self.allow_stack_allocation: Optional[bool] = (
config.aot_inductor.allow_stack_allocation
)
self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
@staticmethod
@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
Otherwise it uses the CUDA language for codegen.
Only valid when cuda == True.
"""
assert (
not gpu
), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
assert arg_types is not None and len(call_args) == len(
arg_types
), "Mismatch call_args and arg_types in generate_kernel_call"
assert not gpu, (
"CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
)
assert arg_types is not None and len(call_args) == len(arg_types), (
"Mismatch call_args and arg_types in generate_kernel_call"
)
new_args = []
for idx, arg in enumerate(call_args):
if "*" in arg_types[idx]:
@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
assert dtype is not None, (
"Fails to get the dtype of the sympy.Expr"
)
self.codegen_tensor_item(
dtype, f"inputs[{idx}]", input_key, self.prefix
)
@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
else:
assert (
reduce is None
), "Expect reduce to be None for aten.scatter_ with scalar src"
assert reduce is None, (
"Expect reduce to be None for aten.scatter_ with scalar src"
)
line += ");"
self.writeline(line)

View File

@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase):
# MultiKernel will select one kernel after running the autotune block
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
for key in self.keys:
assert (
key in params
), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
assert key in params, (
f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
)
if key == get_cpp_wrapper_cubin_path_name():
assert os.path.exists(params[key]), f"{params[key]} does not exist"
self.additional_files.append(params[key])
@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid:
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
return grid_fn(params["meta"])
@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase):
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
if self.autotune_configs is not None:
# This indicates the Triton kernel is a user-defined one.
@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu):
if V.graph.aot_mode and V.graph.inputs_to_check:
for idx in V.graph.inputs_to_check:
input_name = V.graph.graph_input_names[idx]
assert (
input_name in V.graph.graph_inputs
), f"{input_name} not found in graph inputs"
assert input_name in V.graph.graph_inputs, (
f"{input_name} not found in graph inputs"
)
value = V.graph.graph_inputs[input_name]
assert isinstance(
value, TensorBox
), f"{input_name} is expected to be tensor but found as {type(value)}"
assert isinstance(value, TensorBox), (
f"{input_name} is expected to be tensor but found as {type(value)}"
)
warn_msg = (
f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
"but it is not aligned at run time. Copying to an aligned tensor "

View File

@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling):
Codegen a CUDA template, possibly with fused epilogues
"""
counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cuda_cpp_template(
template_node
), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
assert self.is_cuda_cpp_template(template_node), (
"Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_numel, rnumel) = template_node.group
assert rnumel == 1

View File

@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller):
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
bmreq: CUDABenchmarkRequest,
template: "CUDATemplate", # type: ignore[name-defined]
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
info_kwargs: Optional[
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
], # type: ignore[type-arg]
description: str,
) -> None:
super().__init__(name, input_nodes, layout, description)

View File

@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate):
A CUDATemplateCaller object representing the generated CUDA template caller.
"""
kernel_name = f"cuda_{self.name}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), CUDATemplateKernel(
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
CUDATemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel:
) as kernel,
):
code = self.render(kernel=kernel, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
autotuning_log.debug("Generated Code:\n%s", code)

View File

@ -147,7 +147,9 @@ if try_import_cutlass():
"element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined]
"layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined]
"element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined]
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950
"opcode_class": OpcodeClassTag[ # type: ignore[name-defined]
operation.tile_description.math_instruction.opcode_class
],
"arch": f"cutlass::arch::Sm{operation.arch:d}",
"tile_shape_m": str(operation.tile_description.tile_shape[0]),
"tile_shape_n": str(operation.tile_description.tile_shape[1]),
@ -168,7 +170,9 @@ if try_import_cutlass():
operation.tile_description.math_instruction.instruction_shape[2]
),
"kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined]
"epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined]
"epilogue_schedule": str(
EpilogueScheduleTag[operation.epilogue_schedule] # type: ignore[name-defined]
),
"epilogue_functor": epilogue_functor,
"stages": stage_count_string,
"align_a": str(operation.A.alignment),

View File

@ -56,9 +56,9 @@ def try_import_cutlass() -> bool:
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
)
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
assert os.path.isdir(
cutlass_library_dir
), f"{cutlass_library_dir} is not a directory"
assert os.path.isdir(cutlass_library_dir), (
f"{cutlass_library_dir} is not a directory"
)
config.cuda.cutlass_dir = os.path.abspath(
os.path.join(
cutlass_library_dir,
@ -86,9 +86,9 @@ def try_import_cutlass() -> bool:
if os.path.isdir(cutlass_py_full_path):
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(
dst_link
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
assert os.path.islink(dst_link), (
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
)
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"

View File

@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
assert isinstance(
op, cutlass_gemm_op.GemmOperation
), "op argument is required and has to be an instance of GemmOperation"
assert isinstance(op, cutlass_gemm_op.GemmOperation), (
"op argument is required and has to be an instance of GemmOperation"
)
assert len(self.input_nodes) >= 2 and self.output_node is not None
X, W = self.input_nodes[0], self.input_nodes[1]
@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
else:
input_reorder = None
kernel_call_signature = kernel.def_kernel(
inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type]
inputs=inputs, # type: ignore[arg-type]
outputs=[Y],
names_str=names_str,
input_reorder=input_reorder,
)
test_call_statement = self.test_call_statement(kernel, inputs, names_str)
# The layouts might have changed between autotuning and this call if they were FlexibleLayout

View File

@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter):
val, n = expr.args
val = self._print(val)
n = int(n)
return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
texpr = HalidePrinter().doprint
@ -856,12 +856,12 @@ class HalideKernel(SIMDKernel):
for sym, size in added_sym_size:
full_index += stride * sym
stride *= size
self.index_replacements[
node.symbol()
] = V.graph.sizevars.simplify_with_ranges(
self.index_replacements[node.symbol()] = (
V.graph.sizevars.simplify_with_ranges(
ModularIndexing(full_index, node.divisor, node.length),
self.halide_vars, # type: ignore[arg-type]
)
)
# codegen the variable definitions
for sym in self.halide_vars:
@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel):
if isinstance(value, tuple):
assert reduction_type == "welford_combine"
self.cse.reduction_cache[
cache_key
] = result_tuple = self.welford_combine_impl(*value)
self.cse.reduction_cache[cache_key] = result_tuple = (
self.welford_combine_impl(*value)
)
return result_tuple
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel):
scan = f"{scan_dom}.x"
self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
assert (
len(self.reduction_renames) == 1
), "multi-dimensional scan not implemented"
assert len(self.reduction_renames) == 1, (
"multi-dimensional scan not implemented"
)
(scan_var,) = [*self.reduction_renames] # type: ignore[misc]
scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}

View File

@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol):
get_size_hint: CachedMethod[[], int]
get_symbolic_size: CachedMethod[[], sympy.Expr]
def _allocate(self, block: Allocation, is_last: bool) -> bool:
...
def _allocate(self, block: Allocation, is_last: bool) -> bool: ...
class ClearCacheOnAllocateMixin(MemorySplitProtocol):

View File

@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel):
threads = [self.pexpr(v.numel) for v in self.active_range_trees()] # type: ignore[misc]
args += [f"threads=[{', '.join(threads)}]"]
if self.inside_reduction:
threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()] # type: ignore[misc]
threads = [
self.pexpr(v.numel) if v.is_reduction else "1" # type: ignore[misc]
for v in self.active_range_trees()
]
args += [f"group_size=[{', '.join(threads)}]"]
wrapper.generate_kernel_call(

View File

@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None):
all_args = max(args_list, key=len)[:]
arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
for args in args_list:
assert OrderedSet(args).issubset(
OrderedSet(all_args)
), f"{args} v.s. {all_args}"
assert OrderedSet(args).issubset(OrderedSet(all_args)), (
f"{args} v.s. {all_args}"
)
return all_args, arg_types
@ -149,7 +149,9 @@ class MultiKernel:
Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
The generated definition for the multi-kernel will look like:
```
multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
multi_kernel_kernel1 = MultiKernelCall(
[kernel1, kernel2], multi_kernel_definition_code
)
```
Here is a concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39

View File

@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate):
template_params=(",\n" + 12 * " ").join(template_params),
), self._template_from_string(template_type).render(operation_name=op.name())
def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str: # type: ignore[override, name-defined]
def render( # type: ignore[override]
self,
kernel: ROCmTemplateKernel,
op: "CKGroupedConvFwdOp", # type: ignore[name-defined]
**kwargs,
) -> str:
template_buffer_node = kwargs.get("template_buffer_node", None)
if template_buffer_node is not None:
self.output_node = template_buffer_node

View File

@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate):
operation_name=operation_name
)
def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str: # type: ignore[override]
def render( # type: ignore[override]
self,
kernel: ROCmTemplateKernel,
op: "CKGemmOperation",
**kwargs,
) -> str:
"""
The primary entry point for the code rendering process used in this template.
"""
@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate):
* Template instance {op}
*
* {torch.__version__=}
* torch.version.git_version={getattr(torch.version, 'git_version', 'None')}
* torch.version.git_version={getattr(torch.version, "git_version", "None")}
*/
"""
epilogue = None

View File

@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling):
"""
Codegen a ROCm template, possibly with fused epilogues
"""
assert self.is_rocm_cpp_template(
template_node
), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
assert self.is_rocm_cpp_template(template_node), (
"Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_numel, rnumel) = template_node.group
assert rnumel == 1

View File

@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller):
],
bmreq: ROCmBenchmarkRequest,
template: "ROCmTemplate", # type: ignore[name-defined]
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
info_kwargs: Optional[
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
], # type: ignore[type-arg]
) -> None:
super().__init__(name, input_nodes, layout, description="")
self.category = category

View File

@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate):
"""
kernel_name = f"rocm_{self.name}"
kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), ROCmTemplateKernel(
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
ROCmTemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel:
) as kernel,
):
code = self.render(kernel=kernel, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)

View File

@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
continue
while current_group < len(remaining) and sv.statically_known_equals(
remaining[current_group], 1 # type: ignore[arg-type]
remaining[current_group],
1, # type: ignore[arg-type]
):
# scroll to next group with remaining elements
current_group += 1
@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
)
return_getters_groups.append(return_getters)
assert all(
V.graph.sizevars.size_hint(s) == 1 for s in remaining
), f"failed to set ranges {remaining} {lengths}"
assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
f"failed to set ranges {remaining} {lengths}"
)
return new_ranges, return_getters_groups
@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
if len(replacements) > 0:
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
self.range_tree_nodes[sym].expr,
replacements, # type: ignore[index]
)
self.range_tree_nodes[sym].codegen() # type: ignore[index]
return expr
@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling):
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
)
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
with config.patch(
"benchmark_kernel", benchmark_kernel
), V.set_kernel_handler(kernel):
with (
config.patch("benchmark_kernel", benchmark_kernel),
V.set_kernel_handler(kernel),
):
src_code = kernel.codegen_kernel()
else:
prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(

View File

@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.block_ptr_id = itertools.count()
self.block_ptr_to_buffer = dict[str, str]()
self.helper_functions = HelperFunctions()
self.pointer_advancements: dict[
SymT, dict[str, list[sympy.Expr]]
] = collections.defaultdict(dict)
self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
collections.defaultdict(dict)
)
self._load_counts: collections.Counter[str] = collections.Counter()
# A set of autotuning hints to pass as part of triton_meta
@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
continue
advancements = self.pointer_advancements[symt]
assert (
block_ptr not in advancements
), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
assert block_ptr not in advancements, (
"duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
)
advancements[block_ptr] = advance_offsets
else:
block_ptr = indexing.format(var)
@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
buffer.splice(
f"""\
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {self.reduction_resize(f'{result_var}_idx')}
{result_var} = {self.reduction_resize(f"{result_var}_idx")}
"""
)
@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
"""
)
final_argreduce(
@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
)
self.compute.splice(
f"""\
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
{accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
{accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
"""
)
result_mean = result_var
@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.filter_masks(masks)
masks = sorted(masks)
assert not self._load_mask, "ops.sort not supported inside ops.masked"
assert (
self.persistent_reduction
), "ops.sort is only supported in persistent reductions"
assert self.persistent_reduction, (
"ops.sort is only supported in persistent reductions"
)
cse_compute = functools.partial(self.cse.generate, self.compute)
dim = self.triton_tensor_ndim() - self.num_reduction_dims
@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
def _get_heuristic(self):
@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
inductor_meta["profile_bandwidth"] = config.profile_bandwidth
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
inductor_meta[
"profile_bandwidth_with_do_bench_using_profiling"
] = config.profile_bandwidth_with_do_bench_using_profiling
inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
config.profile_bandwidth_with_do_bench_using_profiling
)
if config.coordinate_descent_tuning:
inductor_meta[
"coordinate_descent_tuning"
] = config.coordinate_descent_tuning
inductor_meta[
"coordinate_descent_search_radius"
] = config.coordinate_descent_search_radius
inductor_meta[
"coordinate_descent_check_all_directions"
] = config.coordinate_descent_check_all_directions
inductor_meta["coordinate_descent_tuning"] = (
config.coordinate_descent_tuning
)
inductor_meta["coordinate_descent_search_radius"] = (
config.coordinate_descent_search_radius
)
inductor_meta["coordinate_descent_check_all_directions"] = (
config.coordinate_descent_check_all_directions
)
return inductor_meta
def codegen_kernel(self, name=None):
@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling):
) -> tuple[float, str]:
"""Benchmark an already compiled module"""
device_interface = get_interface_for_device(V.graph.device_type)
with preserve_rng_state(), device_interface.device(
V.graph.get_current_device_or_throw()
): # type: ignore[attr-defined]
with (
preserve_rng_state(),
device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined]
):
ms = None
def cache_file_path():
@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
device = node.get_device()
assert device is not None
backend = node.scheduler.get_backend(device)
assert isinstance(
backend, (SIMDScheduling, CUDACombinedScheduling)
), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
)
with V.graph.set_current_device(device):
# Don't increment kernel count when generating debug string.

View File

@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition(
# rnumel > 2048 usually has long execution time
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
long_reduction = [
n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
n
for n in reduction
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
]
short_reduction = [n for n in reduction if n not in long_reduction]
if long_reduction:
@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition(
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
],
list[list[BaseSchedulerNode]],
]
],
) -> None:
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
@ -593,9 +595,9 @@ class ComboKernel(Kernel):
num_persistent_reduction = len(
[e for e in heuristics_list if e == "persistent_reduction"]
)
assert (
num_reduction == 0
), "combining pointwise and reduction are not supported yet."
assert num_reduction == 0, (
"combining pointwise and reduction are not supported yet."
)
heuristics = (
"pointwise_with_reduction"
if num_persistent_reduction > 0
@ -784,13 +786,13 @@ class ComboKernel(Kernel):
name, tree, suffix=str(num)
)
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert (
grid[i][num] == numel_sign + numel_name
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
assert grid[i][num] == numel_sign + numel_name, (
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
@ -807,13 +809,13 @@ class ComboKernel(Kernel):
continue
expr = V.graph.sizevars.size_hint(tree.numel)
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert (
grid[i][num] == numel_sign + numel_name
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
assert grid[i][num] == numel_sign + numel_name, (
f"grid mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr)
@ -1015,9 +1017,7 @@ class ComboKernel(Kernel):
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
def uniquify_block_sizes(

View File

@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel):
def initialize_range_tree(self, pid_cache):
prefixes = ["y", "x", "r0_"]
assert len(self.numels) <= len(
prefixes
), "z dimension not supported for split scan"
assert len(self.numels) <= len(prefixes), (
"z dimension not supported for split scan"
)
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
grid_dims = {"r0_": 0, "x": 1, "y": 2}

View File

@ -184,7 +184,8 @@ def config_of(
if isinstance(x, TensorArg):
if include_tensor:
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
x.offset * x.dtype.itemsize,
alignment, # type: ignore[arg-type]
)
return offset_aligned and not is_unaligned_buffer(x)
else:

View File

@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
# NB: this is symbolic so that we don't try to reuse a buffer
# allocated for s0 for s1, just because they happen to share the same
# size hint
sympy_str(input_size)
== sympy_str(output_size)
sympy_str(input_size) == sympy_str(output_size)
) or (
# statically known that 0.95 * input_size <= output_size <= input_size
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str:
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
if len(container_match) == 1:
contained_type = container_match[0]
assert (
contained_type in PYTHON_TO_CPP
), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
assert contained_type in PYTHON_TO_CPP, (
f"unsupported {py_container} type in convert_arg_type: {contained_type}"
)
cpp_contained_type = PYTHON_TO_CPP[contained_type]
return f"{cpp_container}<{cpp_contained_type}>"
@ -367,9 +366,9 @@ class SymbolicCallArg:
class MemoryPlanningState:
def __init__(self):
super().__init__()
self.reuse_pool: dict[
ReuseKey, list[FreeIfNotReusedLine]
] = collections.defaultdict(list)
self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
collections.defaultdict(list)
)
self.total_allocated_buffer_size: int = 0
def __contains__(self, key: ReuseKey) -> bool:
@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
)
else:
assert (
self.last_seen_device_guard_index == self.device_idx
), "AOTInductor only supports running on one CUDA device"
assert self.last_seen_device_guard_index == self.device_idx, (
"AOTInductor only supports running on one CUDA device"
)
else:
if self.last_seen_device_guard_index is None:
code.writeline(
@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen):
equals_1 = isinstance(
arg, (int, sympy.Integer)
) and V.graph.sizevars.statically_known_equals(
arg, 1 # type: ignore[arg-type]
arg,
1, # type: ignore[arg-type]
)
add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen):
buf_name = arg
buf = V.graph.get_buffer(arg)
else:
assert (
raw_arg is not None
), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
assert raw_arg is not None, (
"V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
)
buf_name = f"tmp_arg_{index}"
buf = raw_arg
@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen):
and kernel_name not in self.kernel_autotune_names
):
# Create example args for autotune in a separate epilogue
assert arg_types is not None and len(call_args) == len(
arg_types
), "call_args and arg_types do not match"
assert arg_types is not None and len(call_args) == len(arg_types), (
"call_args and arg_types do not match"
)
tensor_args = {}
all_args = []
@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen):
# create a dummy raw_args for uniform behavior in the following loop
raw_args = [None] * len(call_args)
else:
assert len(raw_args) == len(
call_args
), "call_args and raw_args do not match"
assert len(raw_args) == len(call_args), (
"call_args and raw_args do not match"
)
for i, (arg, arg_type, raw_arg) in enumerate(
zip(call_args, arg_types, raw_args)
@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(layout, ir.NoneLayout):
return
if isinstance(layout, ir.NonOwningLayout):
assert isinstance(
layout.view, ir.ReinterpretView
), f"unexpected {type(layout.view)}: {layout.view}"
assert isinstance(layout.view, ir.ReinterpretView), (
f"unexpected {type(layout.view)}: {layout.view}"
)
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
self.codegen_allocation(layout.view.data.data)
@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen):
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
# All inputs of hops must be explicitly passed in.
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
assert len(outer_inputs) == len(
subgraph.graph.graph_input_names
), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
assert len(outer_inputs) == len(subgraph.graph.graph_input_names), (
f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
)
for inner_input, outer_input in zip(
subgraph.graph.graph_input_names, outer_inputs
):

View File

@ -219,8 +219,7 @@ def _schedule_for_comm(
for snode, deps in unmet_deps.items():
assert len(deps) == 0, (
"Detected unscheduled nodes. "
f"Nodes with unmet dependencies: {unmet_deps}"
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
)
return scheduled
@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
):
assert (
node.args[0].op == "placeholder"
), f"""\
assert node.args[0].op == "placeholder", f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
"""
graph_input = node.args[0]
@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
fsdp_copy_node = node
unsharded_param = node.args[0]
assert (
unsharded_param.op == "placeholder"
), f"""
assert unsharded_param.op == "placeholder", f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph}
"""

View File

@ -281,9 +281,9 @@ def _unlift_graph(
elif node_name in graph_signature.inputs_to_buffers:
buffer_name = graph_signature.inputs_to_buffers[node_name]
lifted_inputs.append(buffer_name)
gm.meta[
get_cloned_parameter_buffer_name(buffer_name)
] = clone_preserve_strides(state_dict[buffer_name])
gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
clone_preserve_strides(state_dict[buffer_name])
)
else:
assert node_name in graph_signature.user_inputs
lifted_inputs.append(None)
@ -542,7 +542,7 @@ def fake_tensor_prop(
# pass config dict back to user
def get_patched_config_dict(
config_patches: Optional[Union[str, dict[str, Any]]] = None
config_patches: Optional[Union[str, dict[str, Any]]] = None,
) -> dict[str, Any]:
with config.patch(config_patches):
return config.get_config_copy()
@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol):
gm: GraphModule,
example_inputs: Sequence[InputType],
**kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
...
) -> OutputCode: ...
def compile_fx_inner(
@ -662,9 +661,9 @@ def _compile_fx_inner(
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
assert isinstance(
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
)
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
@ -679,9 +678,10 @@ def _compile_fx_inner(
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
with _WaitCounter(
"pytorch.wait_counter.fx_codegen_and_compile"
).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard():
with (
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
):
use_cache = (
not config.force_disable_caches
and (config.fx_graph_cache or fx_graph_remote_cache)
@ -865,8 +865,7 @@ class FxCompile(ABC):
example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs,
) -> OutputCode:
...
) -> OutputCode: ...
class _InProcessFxCompile(FxCompile):
@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile):
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
aot_mode: bool = V.aot_compilation
is_inference: bool = graph_kwargs.get("is_inference", False)
extern_node_serializer: Optional[
Callable[[list[ExternKernelNode]], Any]
] = graph_kwargs.get("extern_node_serializer", None)
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
graph_kwargs.get("extern_node_serializer", None)
)
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
"boxed_forward_device_index", None
)
with _WaitCounter(
"pytorch.wait_counter.actual_codegen_and_compile"
).guard(), dynamo_utils.preserve_rng_state():
with (
_WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
dynamo_utils.preserve_rng_state(),
):
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
import time
@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile):
# See details in vllm/compilation/pass_manager.py.
log.warning("failed to log pt2_configs")
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
example_inputs
), maybe_disable_graph_partition(cpp_wrapper, aot_mode):
with (
V.set_fake_mode(fake_mode),
maybe_disable_comprehensive_padding(example_inputs),
maybe_disable_graph_partition(cpp_wrapper, aot_mode),
):
const_output_index = None
const_graph = None
const_code = None
@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile):
if graph.aot_mode:
from .codecache import AotCodeCompiler
assert (
graph.cpp_wrapper
), "AOT mode only supports C++ wrapper"
assert graph.cpp_wrapper, (
"AOT mode only supports C++ wrapper"
)
code, linemap = graph.codegen_with_cpp_wrapper()
output_code_log.debug("Output code: \n%s", code)
@ -1509,10 +1511,13 @@ def cudagraphify(
def run(new_inputs: Sequence[InputType]) -> Any:
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.dynamo_timed(
with (
dynamo_utils.dynamo_timed(
"cudagraphify",
log_pt2_compile_event=True,
), dynamo_utils.preserve_rng_state():
),
dynamo_utils.preserve_rng_state(),
):
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
return compiled_fn(new_inputs)
@ -1669,13 +1674,16 @@ def compile_fx_aot(
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
with V.set_aot_compilation(True), torch._guards.compile_context(
saved_compile_context
), chromium_event_timed(
with (
V.set_aot_compilation(True),
torch._guards.compile_context(saved_compile_context),
chromium_event_timed(
"compile_fx_aot",
log_pt2_compile_event=True,
reset_event_log_on_exit=True,
), get_metrics_context():
),
get_metrics_context(),
):
compiled_artifacts = compile_fx(
model_,
example_inputs_,
@ -1875,12 +1883,15 @@ def compile_fx(
# TODO: This probably shouldn't be a recursive call
if config.cpp_wrapper:
with config.patch(
with (
config.patch(
{
"cpp_wrapper": False, # reset to break recursive call to compile_fx
**get_cpp_wrapper_config(),
}
), V.set_real_inputs(example_inputs_):
),
V.set_real_inputs(example_inputs_),
):
inputs_: Sequence[InputType] = example_inputs_
if isinstance(model_, GraphModule):
@ -1940,10 +1951,10 @@ def compile_fx(
# Do the actual work
with _use_lazy_graph_module(
dynamo_config.use_lazy_graph_module
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta(
config.trace.enabled
with (
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
enable_python_dispatcher(),
torch.fx.traceback.preserve_node_meta(config.trace.enabled),
):
# Pre-grad passes cannot be run if we weren't given a GraphModule.
# Dynamo will always produce a GraphModule, but this handles cases
@ -2085,9 +2096,9 @@ def compile_fx(
boxed_forward_device_index=forward_device,
)
fw_compiler: Callable[
[GraphModule, Sequence[InputType]], OutputCode
] = functools.partial(fw_compiler_base, is_inference=False)
fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
functools.partial(fw_compiler_base, is_inference=False)
)
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
if config.freezing and not torch.is_grad_enabled():
@ -2124,9 +2135,10 @@ def compile_fx(
) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock
with dynamo_utils.dynamo_timed(
"compile_fx.<locals>.bw_compiler"
), compile_lock:
with (
dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
compile_lock,
):
model_outputs_node = output_node(gm)
if config.bw_outputs_user_visible:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
@ -2194,10 +2206,11 @@ def compile_fx(
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
return inference_compiler(unlifted_gm, example_inputs_)
with V.set_fake_mode(fake_mode), torch._guards.tracing(
tracing_context
), compiled_autograd._disable(), functorch_config.patch(
unlift_effect_tokens=True
with (
V.set_fake_mode(fake_mode),
torch._guards.tracing(tracing_context),
compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
):
try:
return aot_autograd(

View File

@ -530,7 +530,8 @@ class CompilerBisector:
)
if result:
curr_subsystem = cls.get_subsystem_object(
curr_backend, cls.get_subsystem() # type: ignore[arg-type]
curr_backend,
cls.get_subsystem(), # type: ignore[arg-type]
)
if isinstance(curr_subsystem, BinarySubsystem):

View File

@ -80,9 +80,9 @@ fx_graph_cache: bool = Config(
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
# should we bundle triton caching into fx graph cache
bundle_triton_into_fx_graph_cache: Optional[
bool
] = bundle_triton_into_fx_graph_cache_default()
bundle_triton_into_fx_graph_cache: Optional[bool] = (
bundle_triton_into_fx_graph_cache_default()
)
# Enable autotune local cache.
#
@ -1390,12 +1390,12 @@ class halide:
# Halide autoscheduler to use, choices are:
# "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
scheduler_cuda: Literal[
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
] = "Anderson2021"
scheduler_cpu: Literal[
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
] = "Adams2019"
scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Anderson2021"
)
scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Adams2019"
)
# Controls `no_asserts` flag passed to Halide target (warning: can false positive)
asserts = False

View File

@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter):
and is_woq_int8_pattern(next(iter(node.users)))
)
) and is_const_source(
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
node.args[0], # type: ignore[arg-type]
self.lifted_constant_names,
):
# Case 1: int8_weight -> dq -> bf16_weight
# Case 2: int8_weight -> permute -> dq -> bf16_weight

View File

@ -1633,8 +1633,8 @@ class CppBuilder:
"""
)
assert os.path.exists(
cmake_path
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
assert os.path.exists(cmake_path), (
f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
)
with open(cmake_path, "a") as f:
f.write(contents)

View File

@ -119,6 +119,7 @@ from . import config
@dataclasses.dataclass(frozen=True)
class GraphID:
"Unique counter of a cuda graph recording"
id: int
@ -622,11 +623,15 @@ class CUDAWarmupNode:
refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
with torch.cuda.device(
self.device_index
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
with (
torch.cuda.device(self.device_index),
disable_conv_cache_emptying(),
clear_cublas_manager(),
_use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
), get_history_recording():
),
get_history_recording(),
):
out = self.wrapped_function.model(new_inputs)
# We need to know which outputs are allocated within the cudagraph pool
@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage()
class AliasesPriorGraphOutput(OutputAliasInfo):
"Marks that the graph output aliases an output of a prior graph"
__slots__ = ["index"]
index: PathOutputIndex
@ -1200,14 +1206,18 @@ class CUDAGraphNode:
]
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
with preserve_rng_state(), torch.cuda.device(
self.device
), clear_cublas_manager(), torch.cuda.graph(
with (
preserve_rng_state(),
torch.cuda.device(self.device),
clear_cublas_manager(),
torch.cuda.graph(
self.graph,
stream=self.stream,
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
), get_history_recording():
),
get_history_recording(),
):
static_outputs = model(inputs)
# running model should reclaim memory
@ -1247,12 +1257,14 @@ class CUDAGraphNode:
self.output_storage_alias.append(UnaliasedStorage)
continue
(
torch._check(
o.is_cuda or o.untyped_storage().data_ptr() == 0,
lambda: (
"Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
),
),
)
ref = static_input_persistent_storage_ptrs.get(
@ -1291,9 +1303,9 @@ class CUDAGraphNode:
if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))]
else:
assert len(self.stack_traces) == len(
outputs
), "Wrong number of stack traces passed in"
assert len(self.stack_traces) == len(outputs), (
"Wrong number of stack traces passed in"
)
assert not self.outputs_weakrefs
for out, static_output_tensor in zip(outputs, self.static_output_tensors):
@ -1599,12 +1611,14 @@ class CUDAGraphNode:
self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: list[InputType] = []
with warnings.catch_warnings(record=True), torch.cuda.device(
self.device
), _use_cuda_memory_pool_manager(
with (
warnings.catch_warnings(record=True),
torch.cuda.device(self.device),
_use_cuda_memory_pool_manager(
self.device,
mem_pool=self.cuda_graphs_pool,
stream=self.stream,
),
):
for i, inp in enumerate(inputs):
if not isinstance(inp, torch.Tensor):
@ -1736,12 +1750,8 @@ def check_memory_pool(
pool_id: tuple[int, int],
live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None:
assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
) # noqa: C419
unique_storages = {
stor.data_ptr() for stor in live_storages_ptrs if stor()
} # noqa: set_linter
assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419
unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter
# check if there is a divergence first, then do the expensive snapshot call after
# we know it will error
@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager:
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
with warnings.catch_warnings(record=True), torch.cuda.graph(
with (
warnings.catch_warnings(record=True),
torch.cuda.graph(
self.graph,
pool=self.cuda_graphs_thread_pool,
stream=self.stream,
capture_error_mode="thread_local",
),
):
pass
@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager:
constants: tuple[torch.Tensor, ...],
placeholders: tuple[PlaceholderInfo, ...],
mutated_input_idxs: tuple[int, ...],
) -> tuple[ModelType, OutputType,]:
) -> tuple[
ModelType,
OutputType,
]:
id = self.new_func_id()
self.ids_to_stack_traces[id] = stack_traces
self.ids_to_funcs[id] = WrappedFunction(

View File

@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType]
@dataclasses.dataclass(frozen=True)
class FunctionID:
"Unique counter of a function wrapped in cudagraphify_impl"
id: int
@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes(
def check_lowering_disable_cudagraph(
device_node_mapping: dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -276,9 +277,9 @@ def log_data_ptr_mismatch(
Logs the mismatch between input data pointers and recorded data pointers.
This checks only idxs in target_idxs.
"""
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
placeholders
), "length mismatch between inputs, recorded_data_ptr, and placeholders"
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), (
"length mismatch between inputs, recorded_data_ptr, and placeholders"
)
t_tensors = [inputs[i] for i in target_idxs]
t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]

View File

@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name(
def get_node_name_to_buf_meta(
node_name_to_buf_name: dict[str, str]
node_name_to_buf_name: dict[str, str],
) -> dict[str, BufMeta]:
buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items():

View File

@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition(
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions:

View File

@ -194,7 +194,9 @@ class MemoryDep(Dep):
)
new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR
out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type]
out = MemoryDep(
self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
) # type: ignore[arg-type]
return out
@property
@ -649,11 +651,16 @@ def extract_loop_body_with_args(
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
for entry in fn.memory_usage[MemoryUsageType.STORE]:
inner.store(
entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type]
entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
entry.mode,
)
for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
inner.store_reduction(
entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type]
entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
)
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
inner.index_expr(name_to_index[entry.index_name], None)
@ -661,7 +668,11 @@ def extract_loop_body_with_args(
# All that matters is that we record the buffer name, so place it in the
# "boundaries" name position to ensure that it's recorded.
inner.bucketize(
None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type]
None,
(entry.buffer_name, None, None, None),
None,
None, # type: ignore[arg-type]
None, # type: ignore[arg-type]
)
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
return inner
@ -801,8 +812,9 @@ def extract_free_unbacked_symbols(
handler = FreeUnbackedSymbolsOpsHandler()
# NB: I cargo culted the allow_indexing patch here, I don't understand why
# people do this all over
with V.set_ops_handler(handler), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(handler),
patch.object(FlexibleLayout, "allow_indexing", True),
):
fn(*args)
return handler.symbols

View File

@ -19,8 +19,7 @@ T = TypeVar("T")
class DTypeVar(Protocol):
@property
def dtype(self) -> torch.dtype:
...
def dtype(self) -> torch.dtype: ...
DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]

View File

@ -526,6 +526,7 @@ class ConfigFuzzer:
```python
import torch._inductor.config as cfg
def create_simple_test_model_gpu() -> FactoryOutputType:
batch_size = 32
seq_length = 50
@ -539,6 +540,8 @@ class ConfigFuzzer:
return True
return test_fn
fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2)
# Test every pair of configs:
@ -550,7 +553,9 @@ class ConfigFuzzer:
ret = fuzzer.bisect(num_attempts=10)
# reproduce a failing config
fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}])
fuzzer.reproduce(
[{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]
)
```
    The list of known failures on the inductor config is:

View File

@ -531,7 +531,11 @@ def tuned_b2b_gemm(
A.realize()
B.realize()
C.realize()
layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]]) # type: ignore[index]
layout = FixedLayout(
A.get_device_or_error(),
A.get_dtype(),
[A.shape[0], C.shape[1]], # type: ignore[index]
)
subgraph_buffer = build_subgraph_buffer(
[create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())],
subgraph,

View File

@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None:
node_indices = {node: i for i, node in enumerate(graph.nodes)}
for allreduce in comm_blocks:
# Find the earliest/first user -- target_node.
assert (
len(allreduce.outputs) >= 1
), f"Found a allreduce that has zero outputs/users -- {allreduce}."
assert len(allreduce.outputs) >= 1, (
f"Found a allreduce that has zero outputs/users -- {allreduce}."
)
# Initialize the target node to avoid typing issues.
target_node = next(iter(next(iter(allreduce.outputs)).users))
target_node_index = 2**31

View File

@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
        # argument. `graph.get_attr` and
        # `graph.call_function` do not allow the `name` argument.
conv_get_node = graph.create_node(
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
op="get_attr",
target=conv_node.target, # type: ignore[union-attr]
name="get_conv",
)
bn_get_node = graph.create_node(
op="get_attr", target=bn_node.target, name="get_bn"

View File

@ -866,7 +866,9 @@ def _get_sfdp_patterns():
name += "_bs1"
training_name = name + "_training"
yield training_name, {
yield (
training_name,
{
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
@ -874,7 +876,8 @@ def _get_sfdp_patterns():
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
},
)
if workaround:
assert len(workaround) == 1 and "dropout_p" in workaround
@ -886,7 +889,9 @@ def _get_sfdp_patterns():
workaround = {}
inference_name = name + "_inference"
yield inference_name, {
yield (
inference_name,
{
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
@ -897,7 +902,8 @@ def _get_sfdp_patterns():
# with dropout turned into clone, we end up with a number of
# semantically identical graphs
"skip_duplicates": True,
}
},
)
@functools.lru_cache(None)

View File

@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion):
args=(batch_biases[i],),
kwargs={"size": broadcast_shape},
)
broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment]
broadcast_bias.meta["val"] = aten.broadcast_to(
batch_biases_meta[i]["val"], broadcast_shape
) # type: ignore[assignment]
new_bias_add = graph.call_function( # type: ignore[operator]
aten.add.Tensor, args=((broadcast_bias, new_mm))
)
@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion):
group_biases = None # type: ignore[assignment]
if all(weight is None for weight in group_weights):
group_weights = None # type: ignore[assignment]
assert all(
eps == group_epss[0] for eps in group_epss
), "all epsilon values must be equal"
assert all(eps == group_epss[0] for eps in group_epss), (
"all epsilon values must be equal"
)
with graph.inserting_before(subset[0]): # type: ignore[operator]
stack_input = graph.call_function( # type: ignore[operator]
@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
# for relu op, we also use the inplace to construct the key
# we batch the ops with same parent to enable followup split cat
parent = node.args[0]
parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr]
parent = (
parent.target # type: ignore[union-attr]
if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
else ""
)
group_key = (
"batch_aten_" + self.op.__name__.lower().split(".")[0],
str(input.meta["val"].shape),
@ -1293,9 +1299,9 @@ def get_fusion_candidates(
"""
q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()
candidate_dict: collections.defaultdict[
Any, list[torch.fx.Node]
] = collections.defaultdict(list)
candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
collections.defaultdict(list)
)
if root_node.target in SEARCH_EXCLUSIONS:
return candidate_dict

View File

@ -763,9 +763,7 @@ def _get_node_to_ancestors(
"""
Compute the ancestors for all nodes in a graph.
"""
node_to_ancestors = defaultdict(
OrderedSet[torch.fx.Node]
) # type: ignore[var-annotated]
node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated]
for node in graph.nodes:
node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
for dep in node.all_input_nodes:

View File

@ -558,9 +558,9 @@ if torch._C._has_mkldnn:
binary_nodes = filter_nodes(match.nodes, binary_op)
def _get_compute_node(_binary_node, _other_index):
assert (
len(_binary_node.all_input_nodes) == 2
), "Binary node should have 2 input nodes."
assert len(_binary_node.all_input_nodes) == 2, (
"Binary node should have 2 input nodes."
)
_compute_index = 1 if (_other_index == 0) else 0
return _binary_node.args[_compute_index]
@ -614,9 +614,9 @@ if torch._C._has_mkldnn:
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
len(match.nodes)
)
return L[fusion_op](*computation_args)
return fn
@ -659,9 +659,9 @@ if torch._C._has_mkldnn:
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
len(match.nodes)
)
        # Make sure the other is not an alias or mutation (the fx side doesn't have such info).
other.realize()
if not _can_be_inplace(other) or other.data.shape != list(
@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn:
)
batch_size = input.meta.get("val").shape[0]
if has_free_symbols(batch_size):
assert (
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
)
            # For the bfloat16 dynamic shape path, use the input size hint to pack the weight for better performance.
packed_weight_inputs = (
transpose_weight_node,

View File

@ -437,7 +437,7 @@ def _should_pad_bench(
return False
def realize_symbols(
ds: Union[torch.Size, tuple[torch.SymInt, ...]]
ds: Union[torch.Size, tuple[torch.SymInt, ...]],
) -> list[int]:
return [d if isinstance(d, int) else d.node.hint for d in ds]

View File

@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
pattern_matcher_pass.apply
)
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_post_grad"
] = upload_graph(gm.graph)
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
upload_graph(gm.graph)
)
if config.b2b_gemm_pass:
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]

View File

@ -277,9 +277,9 @@ def pre_grad_passes(
for _ in range(counter):
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_pre_grad"
] = upload_graph(gm.graph)
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
upload_graph(gm.graph)
)
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]

View File

@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering(
accum.realize()
from .mkldnn_fusion import _can_be_inplace
assert _can_be_inplace(
accum
), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
assert _can_be_inplace(accum), (
"QConv Binary Inplace Fusion requires accum is not an alias or mutation."
)
computation_args = (
x,
@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
def clone_to_new_node(graph, source_node, user_node):
# Clone the source_node to a new node
# Replace user_node's input from source_node to new_node
assert (
source_node.op == "call_function"
), "clone_to_new_node only support node.op call_function"
assert source_node.op == "call_function", (
"clone_to_new_node only support node.op call_function"
)
with graph.inserting_before(user_node):
new_node = graph.call_function(
source_node.target,
@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
            # For a dequant pattern, we expect the start node to be a dequantize_per_tensor node
return _node
else:
assert (
len(_node.args) >= 1
), "In in dequant pattern, each node should have more than 1 arg."
assert len(_node.args) >= 1, (
"In in dequant pattern, each node should have more than 1 arg."
)
return _find_first_node_in_dequant_pattern(_node.args[0])
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(

View File

@ -616,7 +616,8 @@ def merge_splits(
dim=first_split_dim,
)
first_split_num_to_user = {
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
user.args[1]: user
for user in first_split.users.keys() # type: ignore[union-attr]
}
new_split_num = 0
@ -706,7 +707,11 @@ class SplitCatSimplifier:
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
)
self.replace_cat(
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
graph,
split_node,
next_users,
user_inputs_list_new,
transform_params_list, # type: ignore[arg-type]
)
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
counters["inductor"]["unbind_stack_pass"] += 1
@ -913,7 +918,9 @@ class SplitCatSimplifier:
)
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
new_split.meta["example_value"] = torch.split(
split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr]
split_input.meta["example_value"], # type: ignore[union-attr]
[r[1] - r[0] for r in split_ranges],
dim=split_dim,
)
counters["inductor"]["scmerge_split_added"] += 1
split_items = []
@ -1005,7 +1012,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
)
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
to_stack, to_stack_meta = [], []
stack_dim = None
user_inputs_new_transformed.append(stacked_input)
@ -1023,19 +1033,28 @@ class SplitCatSimplifier:
user_input_new = graph.call_function(
torch.unflatten, args=(user_input_new, *unflatten_params)
)
user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*unflatten_params, # type: ignore[arg-type]
)
if movedim_params:
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.movedim, args=(user_input_new, *movedim_params)
)
user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*movedim_params, # type: ignore[arg-type]
)
if flatten_params:
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.flatten, args=(user_input_new, *flatten_params)
)
user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type]
user_input_new_meta,
*flatten_params, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(user_input_new)
user_inputs_new_transformed_meta.append(
user_input_new.meta["example_value"]
@ -1044,7 +1063,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
)
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(stacked_input)
user_inputs_new_transformed_meta.append(
stacked_input.meta["example_value"]
@ -1058,14 +1080,15 @@ class SplitCatSimplifier:
kwargs={"dim": cat_dim},
)
new_cat_node.meta["example_value"] = torch.cat(
user_inputs_new_transformed_meta, dim=cat_dim
user_inputs_new_transformed_meta,
dim=cat_dim,
)
counters["inductor"]["scmerge_cat_added"] += 1
else:
new_cat_node = user_inputs_new_transformed[-1]
new_cat_node.meta[
"example_value"
] = user_inputs_new_transformed_meta[-1]
new_cat_node.meta["example_value"] = (
user_inputs_new_transformed_meta[-1]
)
if (
user_node.target == torch.cat
@ -1077,7 +1100,11 @@ class SplitCatSimplifier:
new_cat_node = graph.call_function(
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
)
new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr]
new_cat_node.meta["example_value"] = torch.flatten(
new_cat_node_meta,
cat_dim,
cat_dim + 1,
)
user_node.replace_all_uses_with(new_cat_node)
new_cats.append(new_cat_node)
@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier):
]
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
getitem_indices
) != len(
unbind_node.meta["example_value"]
):
) != len(unbind_node.meta["example_value"]):
return
num_unbind = len(getitem_indices)
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# update the split sections
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
split_node, indices # type: ignore[arg-type]
split_node,
indices, # type: ignore[arg-type]
)
# padding others with zeros to keep the same dict size
for i in indices[1:]:
@ -1613,10 +1639,12 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
# check the split dim, and construct the slice tuple
start_fused_size = calculate_fused_tensor_size(
split_node, list(range(indices[0])) # type: ignore[arg-type]
split_node,
list(range(indices[0])), # type: ignore[arg-type]
)
end_fused_size = start_fused_size + calculate_fused_tensor_size(
split_node, indices # type: ignore[arg-type]
split_node,
indices, # type: ignore[arg-type]
)
slice_list = []
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
continue
# check the cat node has consecutive indices
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
and len(getitem_nodes) != len(cat_inputs)
):
continue
        # replace all uses of the cat node with the input of the split node
cat_node.replace_all_uses_with(split_input)
@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
continue
# check the cat node has consecutive indices
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type]
if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
or len(select_nodes) != len(cat_inputs)
):
continue
# check all the select nodes can be merged to the cat node input
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
args=(new_cat_args,),
kwargs={"dim": cat_dim},
)
new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type]
new_cat_node.meta["example_value"] = torch.cat(
new_cat_args_meta, dim=cat_dim
) # type: ignore[arg-type]
cat_node.replace_all_uses_with(new_cat_node)
new_cat_node.meta.update(cat_node.meta)
# remove inputs of cat_node if they have no users
@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack(
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
)
reshape_node.meta["example_value"] = torch.Tensor.view(
permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type]
permute_node.meta["example_value"],
tuple(stack_node_shape), # type: ignore[arg-type]
)
return reshape_node
@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
cat_inputs.append(decomposed_stack_node)
# cat_arg must be the split input
view_shape_list = get_view_shape_list(cat_arg, stack_dim)
stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr]
stack_node_shape = torch.reshape(
cat_arg.meta["example_value"], tuple(view_shape_list)
).shape # type: ignore[union-attr]
cat_inputs.append(
convert_reshape_cat_arg_to_stack(
graph,

View File

@ -105,9 +105,9 @@ class FakeTensorUpdater:
if new is None:
return old is None
if not isinstance(new, torch.Tensor):
assert isinstance(
new, (torch.SymInt, torch.SymBool, torch.SymFloat)
), f"Unknown type {type(new)} in {self.graph}"
assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
f"Unknown type {type(new)} in {self.graph}"
)
return (
new.node.shape_env._maybe_evaluate_static(
sympy.Eq(new.node.expr, old.node.expr)

View File

@ -136,7 +136,9 @@ else:
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
assert isinstance(
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
), (
"get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
)
if isinstance(constant_buffer, sympy.core.numbers.Integer):
return torch.int64
@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter):
self.reuse_shape_env = True
self._shape_env = shape_env
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: dict[
Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
shape_env.deferred_runtime_asserts.copy()
)
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: list[str] = []
@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter):
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: list[
tuple[int, str]
] = (
[]
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs_reason: Optional[str] = None
@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter):
)
def placeholder(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter):
return target(*args, **kwargs)
if target not in lowerings:
assert isinstance(
target, torch._ops.OpOverload
), f"{target} is not an OpOverload"
assert isinstance(target, torch._ops.OpOverload), (
f"{target} is not an OpOverload"
)
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target, warn=False, override_decomp=True)
@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[()], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1241,7 +1247,10 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError
def output(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter):
if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node(
n
), V.set_current_node(n):
with (
ir.IRNode.current_origins(origins),
self.set_current_node(n),
V.set_current_node(n),
):
if (
n.op == "call_function"
and n.target is not operator.getitem
@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter):
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, **kwargs # type: ignore[possibly-undefined]
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter):
wrapper_code_gen_cls = get_wrapper_codegen_for_device(
self.device_type, self.cpp_wrapper
)
assert (
wrapper_code_gen_cls is not None
), f"Device {self.device_type} not supported"
assert wrapper_code_gen_cls is not None, (
f"Device {self.device_type} not supported"
)
self.wrapper_code = wrapper_code_gen_cls.create(
is_subgraph,
subgraph_name,
@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter):
compiled = self.compile_to_module().call
def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]:
if x is None:
return None
@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter):
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(
x, torch.Tensor
), "Unknown type when creating real inputs" + str(type(x))
assert isinstance(x, torch.Tensor), (
"Unknown type when creating real inputs" + str(type(x))
)
return x
tracing_context = torch._guards.TracingContext.try_get()

View File

@ -20,6 +20,7 @@ printers. So simple operations like minimum and maximum cannot be translated to
SymPy expressions yet, despite sympy.Min and sympy.Max existing.
"""
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
@ -179,9 +180,9 @@ class IndexPropVar:
return IndexPropVar(expr, is_symbolic=True)
def __post_init__(self):
assert not self.is_symbolic or isinstance(
self.value, TypedExpr
), "Symbolic IndexPropVar must contain a TypedExpr"
assert not self.is_symbolic or isinstance(self.value, TypedExpr), (
"Symbolic IndexPropVar must contain a TypedExpr"
)
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler):
name: Literal["indirect_indexing"],
args: Sequence[Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
) -> IndexPropVar: ...
@overload
def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
...
) -> IndexPropResult: ...
def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler):
is_valid_expr = new_expr is not NotImplemented and (
# Inductor doesn't expect floating point in sympy expressions, but
# allow floating point constants to be propagated
new_expr.is_constant()
or new_expr.expr.is_integer
new_expr.is_constant() or new_expr.expr.is_integer
)
if not is_valid_expr:
return self.fallback(name, args, kwargs)

View File

@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
int,
EffectfulKernel,
),
), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
), (
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
)
# Be picky about the accepted data structure (don't use pytree here)
_check_tensorbox(node_or_nodes)
@ -298,13 +300,11 @@ def get_stride_order(
@overload
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None:
...
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...
@overload
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor:
...
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
def ir_node_to_tensor(
@ -346,7 +346,7 @@ def may_convert_to_optional(
def get_device_type(
x: Union[IRNode, OutputSpec, torch.device, None, str]
x: Union[IRNode, OutputSpec, torch.device, None, str],
) -> Optional[str]:
if isinstance(x, str) or x is None:
return x
@ -698,8 +698,7 @@ class IRNode:
if TYPE_CHECKING:
@property
def dtype(self) -> torch.dtype:
...
def dtype(self) -> torch.dtype: ...
@ir_dataclass(frozen=False)
@ -839,8 +838,9 @@ class Loops(IRNode):
@cache_on_self
def inner_fn_opcount(self) -> OpCountResult:
opcounter = OpCounterCSE(V.MockHandler())
with V.set_ops_handler(opcounter), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(opcounter),
patch.object(FlexibleLayout, "allow_indexing", True),
):
self.inner_fn(*self.inner_fn_args())
return opcounter.getvalue()
@ -1364,9 +1364,9 @@ class Reduction(Loops):
# "all" is desugared to `!any(!val)`
}
assert (
reduction_type in rtypes_to_inits.keys()
), f"{reduction_type} not supported for zero-dimension tensors!"
assert reduction_type in rtypes_to_inits.keys(), (
f"{reduction_type} not supported for zero-dimension tensors!"
)
def const_fn(index: int) -> OpsValue:
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
@ -1575,9 +1575,9 @@ class Reduction(Loops):
new_ranges: Sequence[Integer],
new_reduction_ranges: Sequence[Integer],
) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
assert all(
r == 1 for r in original_ranges
), f"Only enabled for numel_hint == 1, found {original_ranges=}"
assert all(r == 1 for r in original_ranges), (
f"Only enabled for numel_hint == 1, found {original_ranges=}"
)
reindex = View.dynamic_reshape_indexer(
original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
)
@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction):
if reduction_numel == 1:
def copy(
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
) -> TensorBox:
def inner_fn(idx: Sequence[Expr]) -> OpsValue:
reduction_index = [sympy.S.Zero for _ in reduction_ranges]
@ -2571,9 +2571,9 @@ class ExpandView(BaseView):
# NB: new_size[i] == old_size[i] is expected to already be
# guarded because the meta formula was expected to have taught
# us this equality.
assert (
sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
)
return new_size
@classmethod
@ -3382,9 +3382,9 @@ class Layout(OutputSpec):
)
def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
assert (
FlexibleLayout.allow_indexing
), f"convert {type(self).__name__} to FixedLayout first"
assert FlexibleLayout.allow_indexing, (
f"convert {type(self).__name__} to FixedLayout first"
)
return self.as_fixed().make_indexer()
def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout):
return target
result = unwrap_views(self.target)
assert isinstance(
result, Buffer
), "MutationLayoutSHOULDREMOVE must refer to a buffer"
assert isinstance(result, Buffer), (
"MutationLayoutSHOULDREMOVE must refer to a buffer"
)
return result
def real_layout(self): # type: ignore[no-untyped-def]
@ -3803,7 +3803,9 @@ class Buffer(IRNode):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride)
def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None: # type: ignore[no-untyped-def]
def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def]
self, exact_strides, allow_padding=False
) -> None:
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_exact_strides(
exact_strides, allow_padding=allow_padding
@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer):
torch.ops.higher_order.flex_attention_backward,
)
current_node = V.graph.current_node.target
assert (
current_node in allowed_set
), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
assert current_node in allowed_set, (
f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
)
device = self.inputs[0].get_device()
self.outputs += [
MutationOutput(NoneLayout(device=device), buf, self)
@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel):
x_unwrap_view.freeze_layout()
index_args, var_ranges = dependencies.index_vars_squeeze(
x.get_size(), prefix="r" # type: ignore[arg-type]
x.get_size(),
prefix="r", # type: ignore[arg-type]
)
range_vars = index_args[0]
index = x.make_indexer()(range_vars)
@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel):
# pass in a list of const arg names for arg_properties lookup.
name_to_arg_properties = None
if names and self.arg_properties:
assert len(self.constant_args) == len(
names
), "names passed to codegen_const_args does not match self.constant_args"
assert len(self.constant_args) == len(names), (
"names passed to codegen_const_args does not match self.constant_args"
)
name_to_arg_properties = {
arg.get("name"): arg for arg in self.arg_properties
}
@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel):
args = []
for i, x in enumerate(inputs):
if V.graph.cpp_wrapper:
assert self.arg_properties and i < len(
self.arg_properties
), "Invalid access to ExternKernel.arg_properties"
assert self.arg_properties and i < len(self.arg_properties), (
"Invalid access to ExternKernel.arg_properties"
)
type_ = self.arg_properties[i].get("type")
args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
else:
@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None: # type: ignore[no-untyped-def]
def __init__( # type: ignore[no-untyped-def]
self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args
) -> None:
inputs = []
kwargs = {}
constant_args = []
@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc):
elif isinstance(output, torch.SymInt):
return output.node.expr
else:
assert (
output is None
), f"FallbackKernel output type {type(output)} is not supported"
assert output is None, (
f"FallbackKernel output type {type(output)} is not supported"
)
return None
outputs = generate_output(example_output, [])
@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel):
)
self.codegen_size_asserts(wrapper)
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
def __init__( # type: ignore[no-untyped-def]
self,
layout: OutputSpec,
input,
indices: list[tuple[Any, ...]],
) -> None:
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel):
assert p.get_dtype() == torch.bool, p
assert len(p.get_size()) == 0, p
assert (
len(all_inputs) > 0
), "torch.while_loop is assumed to have at least one operand."
assert len(all_inputs) > 0, (
"torch.while_loop is assumed to have at least one operand."
)
device = all_inputs[0].get_device()
@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel):
# This is identical to FallbackKernel.set_cpp_kernel(), minus the
# part that checks against input aliasing and mutation.
def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
assert (
type(self.op_overload) is torch._ops.OpOverload
), "Setting cpp kernel needs a valid op_overload"
assert type(self.op_overload) is torch._ops.OpOverload, (
"Setting cpp kernel needs a valid op_overload"
)
kernel = self.op_overload
self.cpp_kernel_name = kernel._schema.name

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""
"""Triton Implementation of the flex_attention Kernel"""
import copy
import logging
@ -60,9 +60,9 @@ def construct_strides(
) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides
assert len(sizes) == len(
fill_order
), "Length of sizes must match the length of the fill order"
assert len(sizes) == len(fill_order), (
"Length of sizes must match the length of the fill order"
)
strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension
@ -1151,10 +1151,14 @@ def lower_cpu(
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
), (
"Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
), (
"KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
)
CppFlexAttentionTemplate.add_choices(
choices=_choices,
input_nodes=input_nodes,
@ -1364,15 +1368,15 @@ def flex_attention(
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(seq_len_q, 0)
), "Query length must be greater than 0"
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(seq_len_kv, 0)
), "Key length must be greater than 0"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
"Query length must be greater than 0"
)
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
"Key length must be greater than 0"
)
B = Bq
@ -2291,9 +2295,9 @@ def process_joint_outputs(
JointOutputResult containing processed buffers and gradients
"""
assert isinstance(all_joint_outputs, list)
assert (
all_joint_outputs[0] is not None
), "joint_subgraph_buffer is None - this is a bug!"
assert all_joint_outputs[0] is not None, (
"joint_subgraph_buffer is None - this is a bug!"
)
joint_buffer = all_joint_outputs[0]
other_grads = all_joint_outputs[num_placeholders - 1 :]
@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards.
@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs):
grad_key = broadcasted_grad_key
grad_value = broadcasted_grad_value
else:
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" # noqa: B950
assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. "
f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
)
grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any
import sympy
@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
B = Bq
kernel_options = dict(kernel_options)
@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
max(
next_power_of_2(
V.graph.sizevars.size_hint(
seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
seq_len_q,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
* gqa_shared_heads
),

View File

@ -65,7 +65,8 @@ def filtered_configs(
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
m,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
@ -73,7 +74,8 @@ def filtered_configs(
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
n,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
@ -81,7 +83,8 @@ def filtered_configs(
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
k,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
"""
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
== config.kwargs["BLOCK_K"]
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
not inductor_config.force_same_precision

View File

@ -193,12 +193,13 @@ class LoopBody:
#
# There is indeed an issue due to symbol name conflicting.
# y0 maybe reused for the y dimension later.
(
(
iter_vars,
reduce_vars,
), var_ranges = dependencies.index_vars_no_squeeze(
iter_sizes, reduce_sizes, prefix="t"
)
),
var_ranges,
) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t")
new_body = LoopBody(
old_body,
[iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
@ -234,7 +235,8 @@ class LoopBody:
new_sizes = (new_iter_size, reduce_size)
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="t" # type: ignore[arg-type]
*new_sizes,
prefix="t", # type: ignore[arg-type]
)
inverse_order = {b: a for a, b in enumerate(new_order)}
@ -254,7 +256,8 @@ class LoopBody:
# use the original symbol prefix so we can do multiple rounds of reordering
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="p" # type: ignore[arg-type]
*new_sizes,
prefix="p", # type: ignore[arg-type]
)
new_body = LoopBody(
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@ -385,9 +388,9 @@ class LoopBody:
def indexing_from_args(self, indices):
index = [*itertools.chain.from_iterable(indices)]
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
assert all(
v not in self.var_ranges for v in index
), f"{self.var_ranges=}, {indices=}"
assert all(v not in self.var_ranges for v in index), (
f"{self.var_ranges=}, {indices=}"
)
replacements = dict(zip(self.var_ranges.keys(), index))
return {
name: sympy_subs(expr, replacements)

View File

@ -346,7 +346,8 @@ def transform_args(
# only consider tensor kwargs for promotion, for now
promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type]
*promoting_args,
type_promotion_kind=type_promotion_kind, # type: ignore[arg-type]
)
device = (
@ -448,9 +449,9 @@ def _register_lowering(
(fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
):
# explicitly assert for "out=" ops for better error messages
assert not any(
x == "out" for x in kwargs.keys()
), "out= ops aren't yet supported"
assert not any(x == "out" for x in kwargs.keys()), (
"out= ops aren't yet supported"
)
args, kwargs = transform_args(
args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b):
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
assert (
override_return_dtype is None or type_promotion_kind is None
), "only one of override_return_dtype or type_promotion_kind may be given"
assert override_return_dtype is None or type_promotion_kind is None, (
"only one of override_return_dtype or type_promotion_kind may be given"
)
if override_return_dtype is None and type_promotion_kind is None:
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
if isinstance(input, (list, tuple)):
a_list_input = input
break
assert (
a_list_input is not None
), "at least one input must be a list to a foreach op"
assert a_list_input is not None, (
"at least one input must be a list to a foreach op"
)
# broadcast scalar inputs to match length of list inputs
broadcast_inputs = []
@ -1321,12 +1322,12 @@ def quantized_decomposed_quantize_per_channel(
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert axis < len(
input.get_size()
), f"Expecting axis to be < {len(input.get_size())}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
assert axis < len(input.get_size()), (
f"Expecting axis to be < {len(input.get_size())}"
)
input_loader = input.make_loader()
scales_loader = scales.make_loader()
@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel(
) -> TensorBox:
assert len(scales.get_size()) == 1, "expect scales 1 dim"
assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert axis < len(
input.get_size()
), f"Expecting axis to be < {len(input.get_size())}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
assert axis < len(input.get_size()), (
f"Expecting axis to be < {len(input.get_size())}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default(
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
input_loader = input.make_loader()
@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default(
*,
out_dtype: Optional[torch.dtype] = None,
) -> TensorBox:
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor(
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor(
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
assert (
op not in decompositions or override_decomp
), f"both a fallback and a decomp for same op: {op}"
assert op not in decompositions or override_decomp, (
f"both a fallback and a decomp for same op: {op}"
)
if (
warn
and bool(os.getenv("CI"))
@ -2086,9 +2087,9 @@ def native_dropout(x, p, train):
@register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args):
assert config.fallback_random or x.get_device() == torch.device(
"cpu"
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"this should be handled in decomps unless config.fallback_random or the device is CPU"
)
x.realize()
op_overload = (
aten.bernoulli_.float
@ -2101,9 +2102,9 @@ def bernoulli_(x, *args):
@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args):
assert config.fallback_random or x.get_device() == torch.device(
"cpu"
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"this should be handled in decomps unless config.fallback_random or the device is CPU"
)
return bernoulli_(clone(x), *args)
@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device):
i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
for i in indices
if i is not None
), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
), (
f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
)
if any(
i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
):
@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
)
result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
if isinstance(
result.data.data, Reduction # type: ignore[attr-defined]
result.data.data, # type: ignore[attr-defined]
Reduction,
): # Only realize if reduction isn't unrolled
result.realize()
return result
@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
return None
handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
with V.set_ops_handler(handler), patch.object(
ir.FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(handler),
patch.object(ir.FlexibleLayout, "allow_indexing", True),
):
out = x.inner_fn(*x.inner_fn_args())
@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload):
A context manager to force fallback an op. Used in unit test
for FallbackKernel.
"""
assert isinstance(
op, torch._ops.OpOverload
), "Only OpOverload to make the clean up easier"
assert isinstance(op, torch._ops.OpOverload), (
"Only OpOverload to make the clean up easier"
)
old_handler = lowerings.get(op)
try:
register_lowering(op)(fallback_handler(op))

View File

@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer:
class MemoryPlanningInfoForNode:
index: int = 0
size: int = 0
pred_buffers: OrderedSet[
Union[SchedulerBuffer, FreeableInputBuffer]
] = dataclasses.field(default_factory=OrderedSet)
pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
dataclasses.field(default_factory=OrderedSet)
)
pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
default_factory=OrderedSet
)
@ -87,9 +87,9 @@ def get_freeable_input_buf(
# get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
collections.defaultdict(OrderedSet)
)
dep_name_to_size: dict[str, int] = dict()
for node in nodes:
for dep in node.read_writes.reads:
@ -112,7 +112,7 @@ def get_freeable_input_buf(
def compute_size_for_scheduler_buffer(
name_to_buf: dict[str, SchedulerBuffer]
name_to_buf: dict[str, SchedulerBuffer],
) -> dict[str, tuple[int, int]]:
"""
Compute the size of each scheduler buffer, including (1) memory allocated when
@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers(
# get buffer's successor nodes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
collections.defaultdict(OrderedSet)
)
for node in nodes:
for dep in node.unmet_dependencies:
dep_name_to_succ_nodes[dep.name].add(node)

View File

@ -138,12 +138,12 @@ class MetricTable:
return
row_dict = row_fn()
assert len(self.column_names) == len(
row_dict
), f"{len(self.column_names)} v.s. {len(row_dict)}"
assert OrderedSet(self.column_names) == OrderedSet(
row_dict.keys()
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
assert len(self.column_names) == len(row_dict), (
f"{len(self.column_names)} v.s. {len(row_dict)}"
)
assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
)
bn = get_benchmark_name()
# assert bn is not None
@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
name = name.strip()
if not name:
continue
assert (
name in REGISTERED_METRIC_TABLES
), f"Metric table name {name} is not registered"
assert name in REGISTERED_METRIC_TABLES, (
f"Metric table name {name} is not registered"
)
enabled.add(name)
return enabled

View File

@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
unary_algorithm,
]
assert (
binary_attr == "sum"
), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
assert binary_attr == "sum", (
"For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
)
V.graph.mark_buffer_mutated(qaccum.get_name())
packed = QConvPointWiseBinaryPT2E(

View File

@ -575,9 +575,9 @@ def register_onednn_fusion_ops():
algorithm,
layout=None,
):
assert (
packed_weight.get_dtype() is torch.int8
), "Only int8 weights are supported by oneDNN qlinear."
assert packed_weight.get_dtype() is torch.int8, (
"Only int8 weights are supported by oneDNN qlinear."
)
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
@ -928,9 +928,9 @@ def register_onednn_fusion_ops():
# we will do accum dtype conversion here.
x2 = to_dtype(x2, output_dtype)
else:
assert (
x2.get_dtype() == output_dtype
), "dtype of accum for qlinear post op sum should be the same as output"
assert x2.get_dtype() == output_dtype, (
"dtype of accum for qlinear post op sum should be the same as output"
)
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: list[ChoiceCaller] = []

View File

@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]):
assert self_arg == "self"
code.write(
f"""
def {target}(self, {', '.join(args)}):
return self._default({target!r}, ({', '.join(args)}, ), {{}})
def {target}(self, {", ".join(args)}):
return self._default({target!r}, ({", ".join(args)}, ), {{}})
""".strip()
)
code.write("\n\n")
@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler):
)
formatter._output.writeline(f"{lhs} = {name}")
with V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(formatter),
patch.object(FlexibleLayout, "allow_indexing", True),
):
result = ir_fn(*args)
return formatter.getvalue(result)

View File

@ -188,7 +188,9 @@ def package_aoti(
) or (
isinstance(archive_file, (str, os.PathLike))
and os.fspath(archive_file).endswith(".pt2")
), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
), (
f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
)
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
@ -285,9 +287,9 @@ class AOTICompiledModel:
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
assert (
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
) or (
isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
)
if isinstance(path, (io.IOBase, IO)):
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:

View File

@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node]
class SearchFn(Protocol):
__name__: str
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
class ReplaceFn(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
class TraceFn(Protocol):
def __call__(
self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
) -> torch.fx.GraphModule:
...
) -> torch.fx.GraphModule: ...
T = TypeVar("T")
@ -365,8 +362,7 @@ class PatternExpr(ABC):
"""
@abstractmethod
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
...
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ...
def match(self, node: torch.fx.Node) -> MatchResult:
try:
@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr):
@property
@abstractmethod
def op(self) -> str:
...
def op(self) -> str: ...
def fns_repr(self) -> str:
first_repr = self.fns[0]
@ -997,8 +992,9 @@ class PatternPrettyPrinter:
class _PassDictsType(Protocol):
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
...
def __getitem__(
self, k: tuple[str, torch.fx.node.Target]
) -> list[PatternEntry]: ...
@dataclasses.dataclass
@ -1925,7 +1921,10 @@ def fx_to_pattern(
get_attr = _not_implemented
def placeholder(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Union[ExclusiveKeywordArg, KeywordArg]:
n = next(argnum)
if n < len(argnames):
@ -1942,7 +1941,10 @@ def fx_to_pattern(
return KeywordArg(name)
def call_function(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> PatternExpr:
process_arg_fn = process_arg
# Indexing is critical for matching getitem nodes, so we can't ignore int args here

View File

@ -24,7 +24,7 @@ T = TypeVar("T")
def time_and_count(
fn: Callable[Concatenate[Any, P], T]
fn: Callable[Concatenate[Any, P], T],
) -> Callable[Concatenate[Any, P], T]:
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
counters. It is expected that `fn` is a method of `Benchmarker` or one of its

View File

@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None:
# right now, if a pre-hook is attached to the config, it will not be saved;
# and then it won't be used when the config is loaded from cache.
# So we assert - if we do get a pre_hook, it might get ignored after caching.
assert (
getattr(cfg, "pre_hook", None) is None
), "triton configs with pre_hooks not supported"
assert getattr(cfg, "pre_hook", None) is None, (
"triton configs with pre_hooks not supported"
)
def create_bandwidth_info_str(

View File

@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface):
self.launchers = []
def __getstate__(self) -> dict[str, Any]:
assert (
not self.launchers
), "pickle should not be called with after make_launchers()"
assert not self.launchers, (
"pickle should not be called with after make_launchers()"
)
return {
**self.__dict__,
"lock": None,
@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance(
arg,
torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names"
), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_()
for name, arg in kwargs.items():
@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance(
arg,
torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names"
), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_()
def maybe_clone_args(
@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface):
assert not (
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
and "R0_BLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
), (
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
)
start_time = time.time_ns()
best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None
@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface):
)
return config2launcher.get(best_config)
def run(
self, *args, grid, stream, benchmark_run=False, **kwargs
): # type:ignore[override]
def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
if self.triton_interpret:
return self.fn[grid](
*args,
@ -1192,12 +1196,12 @@ class TritonCompileResult:
exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
def launcher({", ".join(def_args)}, grid, stream):
if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
runner({', '.join(runner_args)})
runner({", ".join(runner_args)})
return bin
""".lstrip(),
scope,
@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]):
if block_suffix in var:
prefix = var.removesuffix(block_suffix)
max_block = TRITON_MAX_BLOCK[prefix]
assert (
val <= max_block
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
assert val <= max_block, (
f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
)
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, in
prefix = f"r{idx}_"
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
dim = min(max_size, remaining)
assert (
remaining % dim == 0
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
assert remaining % dim == 0, (
f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
)
rnumels[prefix] = dim
remaining //= dim
# Sanity check the results.
final_numel = conditional_product(*rnumels.values())
assert (
r == final_numel
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
assert all(
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
assert r == final_numel, (
f"Expected ND reduction size ({rnumels}) to have {r} elements."
)
assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
f"rnumels exceed size_hints. {rnumels} > {size_hints}"
)
return rnumels
@ -1967,9 +1971,9 @@ def cooperative_reduction(
size_hints["x"] = 1
# Cooperative reductions currently only support a single reduction dimension.
assert (
len(size_hints) == 2
), "Cooperative reductions don't support tiling reduction dims"
assert len(size_hints) == 2, (
"Cooperative reductions don't support tiling reduction dims"
)
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
# TODO(jansel): we should base target on the SM count of the local GPU
@ -2274,9 +2278,9 @@ def grid_combo_kernels(
assert min_blocks_d is not None
min_blocks = min_blocks_d
else:
assert (
min_blocks_d is None or min_blocks == min_blocks_d
), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
assert min_blocks_d is None or min_blocks == min_blocks_d, (
f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
)
else:
# sequential dispatch
seq_numels = list(numels)

View File

@ -200,9 +200,9 @@ class BaseSchedulerNode:
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[
[BaseSchedulerNode], list[str]
] = lambda *args, **kwargs: []
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
@ -232,7 +232,7 @@ class BaseSchedulerNode:
buf = IndentedBuffer()
buf.splice(
f"""\
{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})
{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
{name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
@ -525,9 +525,9 @@ class BaseSchedulerNode:
V.kernel.mutations.add(input_buf.get_name())
V.kernel.mutations.add(buf.get_name())
V.kernel.inplace_update_buffers[
buf.get_name()
] = input_buf.get_name()
V.kernel.inplace_update_buffers[buf.get_name()] = (
input_buf.get_name()
)
break
def codegen_originating_info(
@ -693,7 +693,7 @@ class BaseSchedulerNode:
continue
def get_buf_bytes(
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]]
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
) -> int:
if not buf:
return 0
@ -794,12 +794,11 @@ class BaseSchedulerNode:
# runtime for that today
return 0
with FakeTensorMode() as fake_mode, FlopCounterMode(
display=False
) as flop_counter_mode, V.set_current_node(
self.node.fx_node
), V.set_fake_mode(
fake_mode
with (
FakeTensorMode() as fake_mode,
FlopCounterMode(display=False) as flop_counter_mode,
V.set_current_node(self.node.fx_node),
V.set_fake_mode(fake_mode),
):
from .ir import ir_node_to_tensor
@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode):
return self._sizes
def is_reduction(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return bool(self.node.get_reduction_type())
def is_split_scan(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return isinstance(self.node, ir.ComputedBuffer) and isinstance(
self.node.data, ir.SplitScan
)
@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode):
def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
var_ranges = self.ranges_from_index_vars(index_vars)
try:
with V.set_ops_handler(
SimplifyIndexing(V.get_ops_handler(), var_ranges)
), V.kernel.set_current_node(self):
with (
V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
V.kernel.set_current_node(self),
):
self._body(*index_vars)
except Exception:
log.fatal("Error in codegen for %s", self.node)
@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode):
def refresh_group_node_dependencies(
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode]
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
) -> None:
snodes = group_snode.snodes
group_snode.set_read_writes(
@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@staticmethod
def set_group_algorithm_for_combo_kernels(
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
) -> None:
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
custom_group_algorithm
@ -1975,9 +1975,9 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
self.get_donated_buffers()
)
self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
@ -2099,9 +2099,9 @@ class Scheduler:
node.log_details()
def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
assert (
node.get_origins() is not None
), "All nodes passed to scheduling must have an origin"
assert node.get_origins() is not None, (
"All nodes passed to scheduling must have an origin"
)
if node.is_no_op():
return NopKernelSchedulerNode(self, node)
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
@ -2260,9 +2260,9 @@ class Scheduler:
)
# if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses:
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node}"
)
if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name()))
@ -2310,9 +2310,9 @@ class Scheduler:
for alt_name in buf.get_mutations():
self.mutation_renames[rename(alt_name)] = buf.get_name()
self.mutation_renames[alt_name] = buf.get_name()
self.mutation_real_name[
buf.get_name()
] = self.mutation_real_name.get(alt_name, alt_name)
self.mutation_real_name[buf.get_name()] = (
self.mutation_real_name.get(alt_name, alt_name)
)
# make sure outputs aren't dead-code-eliminated
for buf_name in V.graph.get_output_names():
@ -2322,9 +2322,9 @@ class Scheduler:
# make sure unbacked symints aren't dead-code-eliminated
for out in V.graph.graph_outputs:
for s in out.get_unbacked_symbol_uses():
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
)
if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names():
log.debug(
@ -3304,15 +3304,15 @@ class Scheduler:
rhs_dep = node2_name2dep[buf_name]
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
reasons[
buf_name
] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
reasons[buf_name] = (
f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
)
continue
if lhs_dep.get_numel() != rhs_dep.get_numel():
reasons[
buf_name
] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
reasons[buf_name] = (
f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
)
continue
# same numel but different MemoryDep.size. Should be broadcasting
@ -3340,9 +3340,9 @@ class Scheduler:
layout_str = ""
if not isinstance(buf, ir.TorchBindObject):
layout_str = f"Layout: {buf.layout}"
reasons[
buf_name
] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
reasons[buf_name] = (
f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
)
return str(reasons)
@ -3903,9 +3903,9 @@ class Scheduler:
self.free_buffers()
def create_backend(self, device: torch.device) -> BaseScheduling:
assert (
not is_gpu(device.type) or device.index is not None
), f"{device} should have been normalized in lowering"
assert not is_gpu(device.type) or device.index is not None, (
f"{device} should have been normalized in lowering"
)
V.graph.add_device_info(device)
device_scheduling = get_scheduling_for_device(device.type)
@ -4135,9 +4135,9 @@ class Scheduler:
partitions, signatures = self.graph_partition()
for partition, signature in zip(partitions, signatures):
assert (
len(partition) >= 1
), f"Each partition must have at least one node but found {len(partition)}"
assert len(partition) >= 1, (
f"Each partition must have at least one node but found {len(partition)}"
)
if signature.skip_cudagraph:
self._codegen(partition)

View File

@ -168,9 +168,9 @@ class PartialRender:
)
else:
return
assert (
self.replacement_hooks[hook_key] is not None
), "hook_key can only be called once"
assert self.replacement_hooks[hook_key] is not None, (
"hook_key can only be called once"
)
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
self.replacement_hooks[hook_key] = None
@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
This is used by flex_attention's backwards grad for captured buffers, see
zeros_and_scatter lowering
"""
assert (
self.mask is not None
), "Mask is required for inner stores in modifications"
assert self.mask is not None, (
"Mask is required for inner stores in modifications"
)
assert mode == "atomic_add", "Only atomic_add is supported for inner stores"
buf_name = self._add_kernel_input(name)
@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel):
def _get_subgraph(self, subgraph_number: int):
assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list)
assert subgraph_number < len(
self.subgraphs
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
assert (
self.body.getvalue() == ""
), "Body should be clear before adding a modification"
assert subgraph_number < len(self.subgraphs), (
f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
)
assert self.body.getvalue() == "", (
"Body should be clear before adding a modification"
)
return self.subgraphs[subgraph_number]
def _handle_scatter_graph(self, scatter_graph):
@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel):
Args:
scatter_graph: The scatter graph to process
"""
assert isinstance(
scatter_graph, ir.ComputedBuffer
), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
assert isinstance(scatter_graph, ir.ComputedBuffer), (
f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
)
def contiguous_strides(x):
# We always create a fresh contiguous grad for scattering into
@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel):
x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
)
return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined]
return scatter_graph.data.store_output( # type: ignore[attr-defined]
scatter_graph.name, contiguous_strides, []
)
def modification(
self,
@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel):
self, subgraph_number, fixed_inputs, mask
)
with V.set_ops_handler(modification_handler):
assert isinstance(
subgraph, (ir.ComputedBuffer, list)
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
assert isinstance(subgraph, (ir.ComputedBuffer, list)), (
f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
)
# Handle scatter stores
if isinstance(subgraph, list):
for scatter_graph in subgraph:
@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate):
"subgraphs": subgraphs,
}
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(fake_out)
), V.graph.set_current_device(layout.device), TritonTemplateKernel(
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
V.graph.set_current_device(layout.device),
TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel:
) as kernel,
):
try:
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller):
def output_node(self):
if self.choice.use_fallback_kernel:
assert (
self.choice.op_overload is not None
), "Please provide an op_overload to use ir.FallbackKernel"
assert self.choice.op_overload is not None, (
"Please provide an op_overload to use ir.FallbackKernel"
)
inner = ir.FallbackKernel.create(
self.choice.op_overload, *self.input_nodes, **self.kwargs
)
@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {}
def get_inputs(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> AutotuneArgs:
# de-duplicate args
unique_example_inputs = {
@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings
def benchmark_in_sub_process(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
):
from . import autotune_process
@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache):
map(
str,
V.graph.sizevars.size_hints(
n.get_size(), fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
),
)
)
@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs):
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
if "return_multi_template" not in kwargs:
kwargs[
"return_multi_template"
] = torch._inductor.config.benchmark_epilogue_fusion
kwargs["return_multi_template"] = (
torch._inductor.config.benchmark_epilogue_fusion
)
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
def add_feedback_saver(
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:

View File

@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner, var_ranges: VarRanges) -> None:
super().__init__(inner)
self.name = "SimplifyIndexing"
self._simplify: Callable[
[Expr], Expr
] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
self._simplify: Callable[[Expr], Expr] = (
lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
)
def load(self, name: str, index: sympy.Expr):
return self._inner.load(name, self._simplify(index))

View File

@ -283,9 +283,9 @@ def ceildiv(
# TODO: There is a bug in a call to this function, to repro:
# python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
# --amp --only YituTechConvBert --dynamic-shapes
assert isinstance(numer, int) and isinstance(
denom, int
), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
assert isinstance(numer, int) and isinstance(denom, int), (
f"{numer}: {type(numer)}, {denom}: {type(denom)}"
)
return runtime_ceildiv(numer, denom)
@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:
def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]]
lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]:
"""
Gets the shape and stride of a tensor. For non-symbolic tensors, this is
@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True)
class CachedMethod(Protocol, Generic[P, RV]):
@staticmethod
def clear_cache(cache: Any) -> None:
...
def clear_cache(cache: Any) -> None: ...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str:
@functools.lru_cache(None)
def try_import_ck_lib() -> (
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
):
def try_import_ck_lib() -> tuple[
Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
]:
try:
import ck4inductor # type: ignore[import]
from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]
@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
return DummyModule()
with mock.patch.object(
with (
mock.patch.object(
GraphLowering, "compile_to_module", patched_compile_to_module
), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
),
mock.patch.object(GraphLowering, "save_output_code", save_output_code),
):
torch._dynamo.reset()
# Note the return here is None
_ = fn(*args, **kwargs)
@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
source_codes = get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]
def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
_, source_codes = run_and_get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]
@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
assert isinstance(
val, sympy.Expr
), "only support sympy.Expr as input to get_sympy_Expr_dtype"
assert isinstance(val, sympy.Expr), (
"only support sympy.Expr as input to get_sympy_Expr_dtype"
)
if val.is_integer: # type: ignore[attr-defined]
return torch.int64
else:
@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) ->
def is_output_of_multi_outputs_template(
input_buf: Optional[Union[Buffer, Operation]]
input_buf: Optional[Union[Buffer, Operation]],
) -> bool:
"""
Check if input buffer is a output of multi-outputs template buffer
@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing(
if node not in (EnableReduction, DisableReduction):
if node.node is not None:
V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [
origin.name for origin in node.node.origins # type: ignore[attr-defined]
origin.name
for origin in node.node.origins # type: ignore[attr-defined]
]

View File

@ -314,9 +314,9 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler
set_ops_handler: Callable[
[OpsHandler[Any]], AbstractContextManager[None]
] = _ops._set_handler
set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = (
_ops._set_handler
)
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler

View File

@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes
class BenchmarkCallableType(Protocol):
def __call__(self, times: int, repeat: int) -> float:
...
def __call__(self, times: int, repeat: int) -> float: ...
_kernel_category_choices = [
@ -138,9 +137,9 @@ def benchmark_all_kernels(
)
else:
ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
assert (
len(triton_kernel.launchers) == 1
), "Autotuner should have selected the best config"
assert len(triton_kernel.launchers) == 1, (
"Autotuner should have selected the best config"
)
launcher = triton_kernel.launchers[0]
print(
get_info_str(
@ -256,9 +255,9 @@ def parse_profile_event_list(
"triton_unknown",
"unknown",
]
assert OrderedSet(all_events.keys()).issubset(
OrderedSet(category_list)
), f"{list(all_events.keys())}"
assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
f"{list(all_events.keys())}"
)
per_category_wall_time = {}
total_device_ms = 0.0