[BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
Author: Xuehai Pan, 2025-02-28 15:35:13 +08:00, committed by PyTorch MergeBot
parent 34d726011f
commit 1cb4e2df65
88 changed files with 1157 additions and 954 deletions
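Almost all of the hunks below are mechanical reformatting rather than behavior changes. Two patterns dominate the diff: the reformatted code keeps an assert condition on one line and parenthesizes only the failure message (the previous black output wrapped the condition instead), and chained context managers are rewritten into parenthesized multi-item `with` blocks. A minimal, self-contained sketch of both styles; the values and names here are illustrative only and are not taken from any file in the diff:

from contextlib import nullcontext

value, limit = 3, 10

# Previous (black) wrapping of a long assert: the condition is parenthesized
# and the message trails after the closing parenthesis.
assert (
    value <= limit
), f"expected value <= {limit}, got {value}"

# New (ruff format) wrapping: the condition stays on one line and only the
# message is parenthesized.
assert value <= limit, (
    f"expected value <= {limit}, got {value}"
)

# Previous style chained the managers on the `with` line, wrapping inside a
# call's own parentheses when the line got too long.
with nullcontext("a") as a, nullcontext("b") as b:
    pass

# New style: parenthesized context managers (Python 3.10+ syntax).
with (
    nullcontext("a") as a,
    nullcontext("b") as b,
):
    pass

The remaining hunks follow the same theme: right-hand sides are parenthesized instead of splitting a subscripted target or annotation across lines, one-line `...` stub bodies are collapsed onto the `def` line, and quoted expressions nested inside f-strings switch from single to double quotes.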

View File

@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
"torch/_i*/**",
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
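This first hunk is the switch itself: dropping "torch/_i*/**" from USE_BLACK_FILELIST means paths under torch/_inductor (and any other torch/_i* package) no longer match the black exclusion list and therefore fall through to ruff format. A rough illustration of that matching, assuming the entries behave like fnmatch-style globs; the surrounding re.compile( suggests they are compiled into a regex, but the exact construction is not visible in this hunk:

# Hypothetical illustration of the filelist matching; the real linter adapter
# may compile these globs differently.
from fnmatch import fnmatch

removed_glob = "torch/_i*/**"
remaining_globs = ["torch/_[e-h]*/**", "torch/_[j-z]*/**"]

path = "torch/_inductor/codecache.py"
print(fnmatch(path, removed_glob))                     # True: previously kept on black
print(any(fnmatch(path, g) for g in remaining_globs))  # False: now formatted by ruff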

View File

@ -66,7 +66,9 @@ def aoti_compile_and_package(
.. code-block:: python
ep = torch.export.export(M(), ...)
aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2")
aoti_file = torch._inductor.aoti_compile_and_package(
ep, package_path="my_package.pt2"
)
compiled_model = torch._inductor.aoti_load_package("my_package.pt2")
To compile and save multiple models into a single ``.pt2`` artifact, you can do
@ -75,11 +77,16 @@ def aoti_compile_and_package(
.. code-block:: python
ep1 = torch.export.export(M1(), ...)
aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True})
aoti_file1 = torch._inductor.aot_compile(
ep1, ..., options={"aot_inductor.package": True}
)
ep2 = torch.export.export(M2(), ...)
aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True})
aoti_file2 = torch._inductor.aot_compile(
ep2, ..., options={"aot_inductor.package": True}
)
from torch._inductor.package import package_aoti, load_package
package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2})
compiled_model1 = load_package("my_package.pt2", "model1")
@ -123,7 +130,9 @@ def aoti_compile_and_package(
isinstance(package_path, (str, os.PathLike))
and os.fspath(package_path).endswith(".pt2")
)
), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
), (
f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
)
inductor_configs = inductor_configs or {}
inductor_configs["aot_inductor.package"] = True
@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner(
"""
if check_accuracy:
assert (
kwargs is None or len(kwargs) == 0
), "when checking for accuracy, the inputs must have been flattened and kwargs is None"
assert kwargs is None or len(kwargs) == 0, (
"when checking for accuracy, the inputs must have been flattened and kwargs is None"
)
from .package import package_aoti

View File

@ -156,8 +156,9 @@ def can_codegen_without_upcasts(
low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops)
# Need to turn off upcasting to do analysis of whether we can turn it off
with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler(
low_prec_analysis
with (
config.patch("triton.codegen_upcast_to_fp32", False),
V.set_ops_handler(low_prec_analysis),
):
prologue._body(*prologue.get_ranges())

View File

@ -245,8 +245,7 @@ class AsyncCompile:
def use_process_pool(self):
return (
get_compile_threads() > 1
and self.process_pool().ready_future.done() # type: ignore[union-attr]
get_compile_threads() > 1 and self.process_pool().ready_future.done() # type: ignore[union-attr]
)
def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):

View File

@ -24,8 +24,7 @@ if TYPE_CHECKING:
class Sortable(typing.Protocol):
"""Anything that can be used as a list.sort() key (int/tuple/etc)"""
def __lt__(self, other: typing.Self) -> bool:
...
def __lt__(self, other: typing.Self) -> bool: ...
class InductorChoices:
@ -100,7 +99,9 @@ class InductorChoices:
# to pick the faster one.
if config.triton.multi_kernel:
threshold *= 16
return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold) # type: ignore[arg-types]
return V.graph.sizevars.statically_known_leq(
features.reduction_numel, threshold
) # type: ignore[arg-types]
@staticmethod
def want_no_x_dim(features: SIMDKernelFeatures) -> bool:

View File

@ -417,9 +417,9 @@ def write_atomic(
) -> None:
# Write into temporary file first to avoid conflicts between threads
# Avoid using a named temporary file, as those have restricted permissions
assert isinstance(
content, (str, bytes)
), "Only strings and byte arrays can be saved in the cache"
assert isinstance(content, (str, bytes)), (
"Only strings and byte arrays can be saved in the cache"
)
path = Path(path_)
if make_dirs:
path.parent.mkdir(parents=True, exist_ok=True)
@ -975,9 +975,9 @@ class FxGraphCache:
symints = FxGraphCache._filter_backed_symints(example_inputs)
hints = [hint_int(s) for s in symints]
def iterate_over_candidates() -> (
Generator[tuple[CompiledFxGraph, bytes], None, None]
):
def iterate_over_candidates() -> Generator[
tuple[CompiledFxGraph, bytes], None, None
]:
if local:
subdir = FxGraphCache._get_tmp_dir_for_key(key)
if os.path.exists(subdir):
@ -1123,9 +1123,9 @@ class FxGraphCache:
"""
from .compile_fx import CompiledFxGraph
assert isinstance(
compiled_graph, CompiledFxGraph
), f"serialization for {type(compiled_graph)} NYI"
assert isinstance(compiled_graph, CompiledFxGraph), (
f"serialization for {type(compiled_graph)} NYI"
)
disk_compiled_graph = copy(compiled_graph)
disk_compiled_graph.prepare_for_serialization()
@ -1315,9 +1315,8 @@ class FxGraphCache:
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns
)
ephemeral_increase
:= add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
) != 0:
cache_info["ephemeral_timeout_increase"] = ephemeral_increase
else:
@ -1556,9 +1555,9 @@ class AotCodeCompiler:
cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json")
)
for k, v in config.aot_inductor.metadata.items():
assert isinstance(k, str) and isinstance(
v, (str)
), "Metadata must only contain strings"
assert isinstance(k, str) and isinstance(v, (str)), (
"Metadata must only contain strings"
)
with open(meta_json, "w") as f:
f.write(json.dumps(config.aot_inductor.metadata))

View File

@ -341,7 +341,7 @@ class BackendFeature(Enum):
def get_backend_features(
device: Union[torch.device, str, None]
device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
if device is None:
return OrderedSet()
@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
if cls._is_unimplemented(funcname):
setattr(cls, funcname, cls._unimplemented(funcname))
else:
assert (
funcname not in cls.__dict__
), f"multiple definitions of {funcname} on {cls.__name__}"
assert funcname not in cls.__dict__, (
f"multiple definitions of {funcname} on {cls.__name__}"
)
impl.__name__ = funcname
setattr(cls, funcname, staticmethod(impl))
@ -2229,7 +2229,7 @@ class KernelTemplate:
@staticmethod
def _fake_get_dtype(
fake_outs: Union[list[Buffer], Buffer]
fake_outs: Union[list[Buffer], Buffer],
) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
if isinstance(fake_outs, (list, tuple)):

View File

@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
outer_loop_fusion_depth,
):
self.outer_fused_nodes: list[
Union[FusedSchedulerNode, SchedulerNode]
] = outer_fused_nodes
self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
outer_fused_nodes
)
self.outer_loop_fusion_depth = outer_loop_fusion_depth
flatten_snodes = []
for _node in self.outer_fused_nodes:
@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def remainder(a, b):
assert (
a.dtype == b.dtype
), "remainder vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"remainder vec implementation expect the same inputs' dtype."
)
return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"
@staticmethod
@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def floordiv(a, b):
if is_float_dtype(a.dtype):
assert (
a.dtype == b.dtype
), "div_floor_floating_vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"div_floor_floating_vec implementation expect the same inputs' dtype."
)
return f"div_floor_floating_vec({a}, {b})"
else:
assert all(is_integer_dtype(item.dtype) for item in [a, b])
@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides):
assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[
Union[CppOverrides, CppVecOverrides]
] = V.kernel.overrides # type: ignore[has-type]
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
V.kernel.overrides
) # type: ignore[has-type]
code.writeline(
f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
)
@ -2108,9 +2108,9 @@ class CppKernel(Kernel):
def set_ranges(self, lengths, reduction_lengths):
if self.call_ranges:
assert self.call_ranges == tuple(lengths) + tuple(
reduction_lengths
), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
)
assert self.reduction_depth == len(lengths)
else:
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel):
self.weight_recps_val = self.weight_recps_cse.generate(
self.compute, f"reduction {self.weight_recp_vec_range}", write=False
)
self.weight_recps_cse.reduction_cache[
self.weight_recp_vec_range
] = self.weight_recps_val
self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = (
self.weight_recps_val
)
self.non_parallel_reduction_prefix.writeline(
self.welford_weight_reciprocal_vec(dtype)
)
@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling):
]
counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
assert self.is_cpp_template(template_node), (
"Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()
@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling):
epilogue_ir_nodes: list[Optional[ir.Operation]] = [
n.node for n in epilogue_nodes
]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
"Epilogue nodes must all be instances of ir.ComputedBuffer"
)
def template_buffer_has_other_users(
template_buffer, outputs_by_name, epilogue_nodes
@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling):
if is_multi_outputs_template(template_node.node):
# For multi outputs template, allocate buffers for each output after the epilogue
# codegen to which determines if the buffer has been removed.
assert (
len(template_node.outputs) == 1
), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
assert len(template_node.outputs) == 1, (
"Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
)
for user in template_node.outputs[0].users:
assert isinstance(
user.node, ExternKernelSchedulerNode
), "Multi outputs template should be with ExternKernelSchedulerNode"
assert isinstance(
user.node.node, ir.MultiOutput
), "Multi outputs template has multi users with MultiOutput"
assert isinstance(user.node, ExternKernelSchedulerNode), (
"Multi outputs template should be with ExternKernelSchedulerNode"
)
assert isinstance(user.node.node, ir.MultiOutput), (
"Multi outputs template has multi users with MultiOutput"
)
user.node.mark_run()
kernel.call_kernel(kernel_name, ctb)
@ -5347,9 +5347,9 @@ class LoopNest:
return self.loops is not None and self.loops[0].is_reduction
def mark_parallel(self, par_depth):
assert (
par_depth <= self.max_parallel_depth()
), "Parallel depth cannot exceed the maximal allowed parallel depth"
assert par_depth <= self.max_parallel_depth(), (
"Parallel depth cannot exceed the maximal allowed parallel depth"
)
assert self.loops is not None
assert len(self.loops) >= par_depth
loop = self.loops[0]

View File

@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate):
assert all(
mem.buffer_name in kernel_group.args.input_buffers
for mem in body.memory_usage[MemoryUsageType.LOAD]
), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
), (
"All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
)
bodies.append(body)
var_sizes_list.append((var_sizes, ()))

View File

@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate):
thread_block_m = math.ceil(m_blocks / m_factor)
return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
assert (
not self.is_dynamic_M
), "Unable to determine thread blocking for dynamic M."
assert not self.is_dynamic_M, (
"Unable to determine thread blocking for dynamic M."
)
register_blocking = self.register_blocking
m_blocks = math.ceil(self.m / register_blocking.block_m)
n_blocks = math.ceil(self.n / register_blocking.block_n)
@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate):
L1_cache_size = (
torch._C._cpu._L1d_cache_size()
) # per core cache size in Bytes
assert (
L1_cache_size > 0
), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
assert L1_cache_size > 0, (
f"Expect L1_cache_size > 0 but got {L1_cache_size}"
)
L1 = L1_cache_size * L1_limit_factor
L2_cache_size = (
torch._C._cpu._L2_cache_size()
) # per core cache size in Bytes
assert (
L2_cache_size > 0
), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
assert L2_cache_size > 0, (
f"Expect L2_cache_size > 0 but got {L2_cache_size}"
)
L2 = L2_cache_size * L2_limit_factor
def get_num_byte(dtype):
@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate):
return Mc_blocks, Nc_blocks, Kc_blocks
assert (
not self.is_dynamic_M
), "Unable to determine cache blocking for dynamic M."
assert not self.is_dynamic_M, (
"Unable to determine cache blocking for dynamic M."
)
register_blocking = self.register_blocking
thread_blocking = self.thread_blocking(num_threads)
@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate):
LayoutType.VNNI4,
], f"We only support {layout_str} for now"
vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
assert (
k % vnni_size == 0
), f"k should be divisible by vnni_size for {layout_str} layout"
assert k % vnni_size == 0, (
f"k should be divisible by vnni_size for {layout_str} layout"
)
vnni_view_size = list(new_size)
vnni_view_size[-2] = k // vnni_size
vnni_view_size.insert(-1, vnni_size)

View File

@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
for W_node in W_nodes:
assert W_node.get_name() in V.graph.constants
W_tensor.append(V.graph.constants[W_node.get_name()])
new_input_nodes[
wgt_start_idx : wgt_start_idx + gemm_grouped_num
] = W_tensor # type: ignore[assignment]
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
W_tensor # type: ignore[assignment]
)
new_input_nodes, _ = pack_weight(
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
)
@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
W_packed = new_input_nodes[idx]
assert isinstance(W_packed, torch.Tensor)
W_packed_constant = V.graph.add_tensor_constant(W_packed)
template_buffer.inputs[
idx
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
template_buffer.inputs[idx] = (
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
)
return output
template = DataProcessorTemplateWrapper(
@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
)
assert (
not self.epilogue_creator
), "epilogue_creator is not supported yet in Grouped GEMM Template"
assert not self.epilogue_creator, (
"epilogue_creator is not supported yet in Grouped GEMM Template"
)
kernel_args: dict[str, Optional[ir.IRNode]] = {}
for x_idx in range(wgt_start_idx):

View File

@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
def register_micro_gemm(*configs):
def inner(cls):
assert (
cls not in micro_gemm_configs
), f"Duplicate micro_gemm registration for {cls}"
assert cls not in micro_gemm_configs, (
f"Duplicate micro_gemm registration for {cls}"
)
assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
micro_gemm_configs[cls] = list(configs)
return cls

View File

@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate):
def generate(self, **kwargs):
kernel_name = f"cpp_{self.name}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
kernel_name=kernel_name, num_threads=self.num_threads
) as kernel:
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
patch.object(ir.FlexibleLayout, "allow_indexing", True),
CppTemplateKernel(
kernel_name=kernel_name, num_threads=self.num_threads
) as kernel,
):
code = kernel.render(self, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
log.debug("Generated Code:\n%s", code)

View File

@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel):
)
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
return self.store_pointwise_nodes(
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
dst,
epilogue_nodes, # type: ignore[arg-type]
offsets,
reindexers,
)
else:
if dst.get_name() != src.get_name():

View File

@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
Only valid when cuda == True.
"""
assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
assert arg_types is not None and len(call_args) == len(
arg_types
), "Mismatch call_args and arg_types in generate_kernel_call"
assert arg_types is not None and len(call_args) == len(arg_types), (
"Mismatch call_args and arg_types in generate_kernel_call"
)
new_args = []
for idx, arg in enumerate(call_args):
if "*" in arg_types[idx]:
@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
assert dtype is not None, (
"Fails to get the dtype of the sympy.Expr"
)
self.codegen_tensor_item(
dtype, f"inputs[{idx}]", input_key, self.prefix
)
@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
code.writeline(f"int32_t {name}_dtype;")
code.writeline(
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
f"({name}, &{name}_dtype));"
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
)
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
# Tell compiler we need to link with the non-mangled symbols
for kernel in self.initialized_kernels.values():
assert hasattr(
kernel, "get_signature"
), f"{kernel} must have get_signature implemented"
assert hasattr(kernel, "get_signature"), (
f"{kernel} must have get_signature implemented"
)
signature = kernel.get_signature()
self.prefix.writeline(f'extern "C" {signature};')
@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
)
)
for name, kernel in self.initialized_kernels.items():
assert hasattr(
kernel, "get_signature"
), f"{kernel} must have get_signature implemented"
assert hasattr(kernel, "get_signature"), (
f"{kernel} must have get_signature implemented"
)
kernel_ptr = f"(*{name})"
signature = kernel.get_signature().replace(name, kernel_ptr)
self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")
@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
with self.prefix.indent():
for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
assert not isinstance(
inp, sympy.Expr
), f"input {name=} cannot be symbolic"
assert not isinstance(inp, sympy.Expr), (
f"input {name=} cannot be symbolic"
)
self.write_input_output_info("inputs_info_", idx, name)
all_cuda = all(
@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
tensor
)
assert (
opaque_metadata_tensor.dim() == 1
), "Expect opaque_metadata_tensor to be 1-D"
assert opaque_metadata_tensor.dim() == 1, (
"Expect opaque_metadata_tensor to be 1-D"
)
opaque_metadata_list = opaque_metadata_tensor.tolist()
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
)
for idx, output in enumerate(V.graph.graph_outputs):
assert not isinstance(
output, sympy.Expr
), f"output {name=} cannot be symbolic"
assert not isinstance(output, sympy.Expr), (
f"output {name=} cannot be symbolic"
)
name = f"output{idx}"
self.write_input_output_info("outputs_info_", idx, name)
@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
for idx, (name, _) in enumerate(V.graph.constants.items()):
if name in V.graph.const_output_index:
const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload]
assert (
None not in const_index_mapping
), "Not all constant gets mapped for constant folding graph."
assert None not in const_index_mapping, (
"Not all constant gets mapped for constant folding graph."
)
self.prefix.writeline(
f"""
@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
name = f"{output.get_name()}"
output_handle_name = f"{name}_handle"
if output.indices:
assert (
output.indices[0][1] == idx
), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
assert output.indices[0][1] == idx, (
f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
)
self.writeline(f"AtenTensorHandle {output_handle_name};")
output_args.append(f"&{output_handle_name}")
output_raii_handles.append(
@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
args = args + output_args
device = d.type if (d := fallback_kernel.get_device()) else self.device
self.generate_c_shim_extern_kernel_call(
fallback_kernel.cpp_kernel_name, args, device # type: ignore[arg-type]
fallback_kernel.cpp_kernel_name, # type: ignore[arg-type]
args,
device,
)
for raii_handle in output_raii_handles:
self.writeline(raii_handle)
@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
else:
assert (
reduce is None
), "Expect reduce to be None for aten.scatter_ with scalar src"
assert reduce is None, (
"Expect reduce to be None for aten.scatter_ with scalar src"
)
line += ");"
self.writeline(line)
@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen):
# Only treat int Scalar as dynamic
is_int_type = [isinstance(a, int) for a in arg]
if any(is_int_type):
assert all(
is_int_type
), "AOTInductor only supports int scalars of the same type"
assert all(is_int_type), (
"AOTInductor only supports int scalars of the same type"
)
new_int_args.extend([str(a) for a in arg])
else:
assert isinstance(
arg_type.getElementType(), static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
arg_type.getElementType(),
static_arg_types, # type: ignore[arg-type]
), (
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
)
else:
assert isinstance(
arg_type, static_arg_types # type: ignore[arg-type]
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
arg_type,
static_arg_types, # type: ignore[arg-type]
), (
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
)
for arg, arg_type in zip(raw_args, arg_types):
if arg is not None:
@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) {
return f"&{var_name}"
if isinstance(type_, torch.ListType):
assert isinstance(
val, (list, tuple)
), f"{val} does not match with arg type {type_}"
assert isinstance(val, (list, tuple)), (
f"{val} does not match with arg type {type_}"
)
element_type = type_.getElementType()
var_name = f"var_array_{next(self.var_array_id)}"
if len(val) == 0:

View File

@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
self.cached_output_id = count()
self.scalar_to_tensor_id = count()
self.custom_op_wrapper_loaded = False
self.allow_stack_allocation: Optional[
bool
] = config.aot_inductor.allow_stack_allocation
self.allow_stack_allocation: Optional[bool] = (
config.aot_inductor.allow_stack_allocation
)
self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
@staticmethod
@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
Otherwise it uses the CUDA language for codegen.
Only valid when cuda == True.
"""
assert (
not gpu
), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
assert arg_types is not None and len(call_args) == len(
arg_types
), "Mismatch call_args and arg_types in generate_kernel_call"
assert not gpu, (
"CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
)
assert arg_types is not None and len(call_args) == len(arg_types), (
"Mismatch call_args and arg_types in generate_kernel_call"
)
new_args = []
for idx, arg in enumerate(call_args):
if "*" in arg_types[idx]:
@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
dtype = may_get_constant_buffer_dtype(
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
)
assert (
dtype is not None
), "Fails to get the dtype of the sympy.Expr"
assert dtype is not None, (
"Fails to get the dtype of the sympy.Expr"
)
self.codegen_tensor_item(
dtype, f"inputs[{idx}]", input_key, self.prefix
)
@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
if reduce:
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
else:
assert (
reduce is None
), "Expect reduce to be None for aten.scatter_ with scalar src"
assert reduce is None, (
"Expect reduce to be None for aten.scatter_ with scalar src"
)
line += ");"
self.writeline(line)

View File

@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase):
# MultiKernel will select one kernel after running the autotune block
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
for key in self.keys:
assert (
key in params
), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
assert key in params, (
f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
)
if key == get_cpp_wrapper_cubin_path_name():
assert os.path.exists(params[key]), f"{params[key]} does not exist"
self.additional_files.append(params[key])
@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid:
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
return grid_fn(params["meta"])
@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase):
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
params = CudaKernelParamCache.get(self.kernel_name)
assert (
params is not None
), f"{self.kernel_name} not found in CudaKernelParamCache"
assert params is not None, (
f"{self.kernel_name} not found in CudaKernelParamCache"
)
if self.autotune_configs is not None:
# This indicates the Triton kernel is a user-defined one.
@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu):
if V.graph.aot_mode and V.graph.inputs_to_check:
for idx in V.graph.inputs_to_check:
input_name = V.graph.graph_input_names[idx]
assert (
input_name in V.graph.graph_inputs
), f"{input_name} not found in graph inputs"
assert input_name in V.graph.graph_inputs, (
f"{input_name} not found in graph inputs"
)
value = V.graph.graph_inputs[input_name]
assert isinstance(
value, TensorBox
), f"{input_name} is expected to be tensor but found as {type(value)}"
assert isinstance(value, TensorBox), (
f"{input_name} is expected to be tensor but found as {type(value)}"
)
warn_msg = (
f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
"but it is not aligned at run time. Copying to an aligned tensor "

View File

@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling):
Codegen a CUDA template, possibly with fused epilogues
"""
counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cuda_cpp_template(
template_node
), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
assert self.is_cuda_cpp_template(template_node), (
"Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_numel, rnumel) = template_node.group
assert rnumel == 1

View File

@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller):
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
bmreq: CUDABenchmarkRequest,
template: "CUDATemplate", # type: ignore[name-defined]
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
info_kwargs: Optional[
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
], # type: ignore[type-arg]
description: str,
) -> None:
super().__init__(name, input_nodes, layout, description)

View File

@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate):
A CUDATemplateCaller object representing the generated CUDA template caller.
"""
kernel_name = f"cuda_{self.name}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), CUDATemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel:
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
CUDATemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel,
):
code = self.render(kernel=kernel, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
autotuning_log.debug("Generated Code:\n%s", code)

View File

@ -147,7 +147,9 @@ if try_import_cutlass():
"element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined]
"layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined]
"element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined]
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950
"opcode_class": OpcodeClassTag[ # type: ignore[name-defined]
operation.tile_description.math_instruction.opcode_class
],
"arch": f"cutlass::arch::Sm{operation.arch:d}",
"tile_shape_m": str(operation.tile_description.tile_shape[0]),
"tile_shape_n": str(operation.tile_description.tile_shape[1]),
@ -168,7 +170,9 @@ if try_import_cutlass():
operation.tile_description.math_instruction.instruction_shape[2]
),
"kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined]
"epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined]
"epilogue_schedule": str(
EpilogueScheduleTag[operation.epilogue_schedule] # type: ignore[name-defined]
),
"epilogue_functor": epilogue_functor,
"stages": stage_count_string,
"align_a": str(operation.A.alignment),

View File

@ -56,9 +56,9 @@ def try_import_cutlass() -> bool:
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
)
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
assert os.path.isdir(
cutlass_library_dir
), f"{cutlass_library_dir} is not a directory"
assert os.path.isdir(cutlass_library_dir), (
f"{cutlass_library_dir} is not a directory"
)
config.cuda.cutlass_dir = os.path.abspath(
os.path.join(
cutlass_library_dir,
@ -86,9 +86,9 @@ def try_import_cutlass() -> bool:
if os.path.isdir(cutlass_py_full_path):
if tmp_cutlass_py_full_path not in sys.path:
if os.path.exists(dst_link):
assert os.path.islink(
dst_link
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
assert os.path.islink(dst_link), (
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
)
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
cutlass_py_full_path
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"

View File

@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
import cutlass_library.gemm_operation as cutlass_gemm_op
import cutlass_library.library as cutlass_lib
assert isinstance(
op, cutlass_gemm_op.GemmOperation
), "op argument is required and has to be an instance of GemmOperation"
assert isinstance(op, cutlass_gemm_op.GemmOperation), (
"op argument is required and has to be an instance of GemmOperation"
)
assert len(self.input_nodes) >= 2 and self.output_node is not None
X, W = self.input_nodes[0], self.input_nodes[1]
@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
else:
input_reorder = None
kernel_call_signature = kernel.def_kernel(
inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type]
inputs=inputs, # type: ignore[arg-type]
outputs=[Y],
names_str=names_str,
input_reorder=input_reorder,
)
test_call_statement = self.test_call_statement(kernel, inputs, names_str)
# The layouts might have changed between autotuning and this call if they were FlexibleLayout

View File

@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter):
val, n = expr.args
val = self._print(val)
n = int(n)
return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
texpr = HalidePrinter().doprint
@ -856,11 +856,11 @@ class HalideKernel(SIMDKernel):
for sym, size in added_sym_size:
full_index += stride * sym
stride *= size
self.index_replacements[
node.symbol()
] = V.graph.sizevars.simplify_with_ranges(
ModularIndexing(full_index, node.divisor, node.length),
self.halide_vars, # type: ignore[arg-type]
self.index_replacements[node.symbol()] = (
V.graph.sizevars.simplify_with_ranges(
ModularIndexing(full_index, node.divisor, node.length),
self.halide_vars, # type: ignore[arg-type]
)
)
# codegen the variable definitions
@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel):
if isinstance(value, tuple):
assert reduction_type == "welford_combine"
self.cse.reduction_cache[
cache_key
] = result_tuple = self.welford_combine_impl(*value)
self.cse.reduction_cache[cache_key] = result_tuple = (
self.welford_combine_impl(*value)
)
return result_tuple
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel):
scan = f"{scan_dom}.x"
self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
assert (
len(self.reduction_renames) == 1
), "multi-dimensional scan not implemented"
assert len(self.reduction_renames) == 1, (
"multi-dimensional scan not implemented"
)
(scan_var,) = [*self.reduction_renames] # type: ignore[misc]
scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}

View File

@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol):
get_size_hint: CachedMethod[[], int]
get_symbolic_size: CachedMethod[[], sympy.Expr]
def _allocate(self, block: Allocation, is_last: bool) -> bool:
...
def _allocate(self, block: Allocation, is_last: bool) -> bool: ...
class ClearCacheOnAllocateMixin(MemorySplitProtocol):

View File

@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel):
threads = [self.pexpr(v.numel) for v in self.active_range_trees()] # type: ignore[misc]
args += [f"threads=[{', '.join(threads)}]"]
if self.inside_reduction:
threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()] # type: ignore[misc]
threads = [
self.pexpr(v.numel) if v.is_reduction else "1" # type: ignore[misc]
for v in self.active_range_trees()
]
args += [f"group_size=[{', '.join(threads)}]"]
wrapper.generate_kernel_call(

View File

@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None):
all_args = max(args_list, key=len)[:]
arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
for args in args_list:
assert OrderedSet(args).issubset(
OrderedSet(all_args)
), f"{args} v.s. {all_args}"
assert OrderedSet(args).issubset(OrderedSet(all_args)), (
f"{args} v.s. {all_args}"
)
return all_args, arg_types
@ -149,7 +149,9 @@ class MultiKernel:
Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
The generated definition for the multi-kernel will looks like:
```
multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
multi_kernel_kernel1 = MultiKernelCall(
[kernel1, kernel2], multi_kernel_definition_code
)
```
Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39

View File

@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate):
template_params=(",\n" + 12 * " ").join(template_params),
), self._template_from_string(template_type).render(operation_name=op.name())
def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str: # type: ignore[override, name-defined]
def render( # type: ignore[override]
self,
kernel: ROCmTemplateKernel,
op: "CKGroupedConvFwdOp", # type: ignore[name-defined]
**kwargs,
) -> str:
template_buffer_node = kwargs.get("template_buffer_node", None)
if template_buffer_node is not None:
self.output_node = template_buffer_node

View File

@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate):
operation_name=operation_name
)
def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str: # type: ignore[override]
def render( # type: ignore[override]
self,
kernel: ROCmTemplateKernel,
op: "CKGemmOperation",
**kwargs,
) -> str:
"""
The primary entry point for the code rendering process used in this template.
"""
@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate):
* Template instance {op}
*
* {torch.__version__=}
* torch.version.git_version={getattr(torch.version, 'git_version', 'None')}
* torch.version.git_version={getattr(torch.version, "git_version", "None")}
*/
"""
epilogue = None

View File

@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling):
"""
Codegen a ROCm template, possibly with fused epilogues
"""
assert self.is_rocm_cpp_template(
template_node
), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
assert self.is_rocm_cpp_template(template_node), (
"Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_numel, rnumel) = template_node.group
assert rnumel == 1

View File

@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller):
],
bmreq: ROCmBenchmarkRequest,
template: "ROCmTemplate", # type: ignore[name-defined]
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
info_kwargs: Optional[
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
], # type: ignore[type-arg]
) -> None:
super().__init__(name, input_nodes, layout, description="")
self.category = category

View File

@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate):
"""
kernel_name = f"rocm_{self.name}"
kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
), ROCmTemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel:
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
ROCmTemplateKernel(
kernel_name=kernel_name,
runtime_arg_info=self.get_runtime_arg_info(),
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
) as kernel,
):
code = self.render(kernel=kernel, **kwargs)
_, call_args, _, _ = kernel.args.python_argdefs()
log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)

View File

@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
continue
while current_group < len(remaining) and sv.statically_known_equals(
remaining[current_group], 1 # type: ignore[arg-type]
remaining[current_group],
1, # type: ignore[arg-type]
):
# scroll to next group with remaining elements
current_group += 1
@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
)
return_getters_groups.append(return_getters)
assert all(
V.graph.sizevars.size_hint(s) == 1 for s in remaining
), f"failed to set ranges {remaining} {lengths}"
assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
f"failed to set ranges {remaining} {lengths}"
)
return new_ranges, return_getters_groups
@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
if len(replacements) > 0:
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
self.range_tree_nodes[sym].expr,
replacements, # type: ignore[index]
)
self.range_tree_nodes[sym].codegen() # type: ignore[index]
return expr
@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling):
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
)
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
with config.patch(
"benchmark_kernel", benchmark_kernel
), V.set_kernel_handler(kernel):
with (
config.patch("benchmark_kernel", benchmark_kernel),
V.set_kernel_handler(kernel),
):
src_code = kernel.codegen_kernel()
else:
prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(

View File

@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.block_ptr_id = itertools.count()
self.block_ptr_to_buffer = dict[str, str]()
self.helper_functions = HelperFunctions()
self.pointer_advancements: dict[
SymT, dict[str, list[sympy.Expr]]
] = collections.defaultdict(dict)
self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
collections.defaultdict(dict)
)
self._load_counts: collections.Counter[str] = collections.Counter()
# A set of autotuning hints to pass as part of triton_meta
@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
continue
advancements = self.pointer_advancements[symt]
assert (
block_ptr not in advancements
), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
assert block_ptr not in advancements, (
"duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
)
advancements[block_ptr] = advance_offsets
else:
block_ptr = indexing.format(var)
@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
buffer.splice(
f"""\
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {self.reduction_resize(f'{result_var}_idx')}
{result_var} = {self.reduction_resize(f"{result_var}_idx")}
"""
)
@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
)
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
"""
)
final_argreduce(
@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
)
self.compute.splice(
f"""\
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
{accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
{accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
"""
)
result_mean = result_var
@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.filter_masks(masks)
masks = sorted(masks)
assert not self._load_mask, "ops.sort not supported inside ops.masked"
assert (
self.persistent_reduction
), "ops.sort is only supported in persistent reductions"
assert self.persistent_reduction, (
"ops.sort is only supported in persistent reductions"
)
cse_compute = functools.partial(self.cse.generate, self.compute)
dim = self.triton_tensor_ndim() - self.num_reduction_dims
@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
def _get_heuristic(self):
@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
inductor_meta["profile_bandwidth"] = config.profile_bandwidth
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
inductor_meta[
"profile_bandwidth_with_do_bench_using_profiling"
] = config.profile_bandwidth_with_do_bench_using_profiling
inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
config.profile_bandwidth_with_do_bench_using_profiling
)
if config.coordinate_descent_tuning:
inductor_meta[
"coordinate_descent_tuning"
] = config.coordinate_descent_tuning
inductor_meta[
"coordinate_descent_search_radius"
] = config.coordinate_descent_search_radius
inductor_meta[
"coordinate_descent_check_all_directions"
] = config.coordinate_descent_check_all_directions
inductor_meta["coordinate_descent_tuning"] = (
config.coordinate_descent_tuning
)
inductor_meta["coordinate_descent_search_radius"] = (
config.coordinate_descent_search_radius
)
inductor_meta["coordinate_descent_check_all_directions"] = (
config.coordinate_descent_check_all_directions
)
return inductor_meta
def codegen_kernel(self, name=None):
@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling):
) -> tuple[float, str]:
"""Benchmark an already compiled module"""
device_interface = get_interface_for_device(V.graph.device_type)
with preserve_rng_state(), device_interface.device(
V.graph.get_current_device_or_throw()
): # type: ignore[attr-defined]
with (
preserve_rng_state(),
device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined]
):
ms = None
def cache_file_path():
@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
device = node.get_device()
assert device is not None
backend = node.scheduler.get_backend(device)
assert isinstance(
backend, (SIMDScheduling, CUDACombinedScheduling)
), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
)
with V.graph.set_current_device(device):
# Don't increment kernel count when generating debug string.

View File

@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition(
# rnumel > 2048 usually has long execution time
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
long_reduction = [
n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
n
for n in reduction
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
]
short_reduction = [n for n in reduction if n not in long_reduction]
if long_reduction:
@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition(
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
],
list[list[BaseSchedulerNode]],
]
],
) -> None:
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
@ -593,9 +595,9 @@ class ComboKernel(Kernel):
num_persistent_reduction = len(
[e for e in heuristics_list if e == "persistent_reduction"]
)
assert (
num_reduction == 0
), "combining pointwise and reduction are not supported yet."
assert num_reduction == 0, (
"combining pointwise and reduction are not supported yet."
)
heuristics = (
"pointwise_with_reduction"
if num_persistent_reduction > 0
@ -784,13 +786,13 @@ class ComboKernel(Kernel):
name, tree, suffix=str(num)
)
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert (
grid[i][num] == numel_sign + numel_name
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
assert grid[i][num] == numel_sign + numel_name, (
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
@ -807,13 +809,13 @@ class ComboKernel(Kernel):
continue
expr = V.graph.sizevars.size_hint(tree.numel)
if not tree.is_reduction:
assert isinstance(
grid[i][num], str
), f"Grid {grid[i][num]} should be a dynamic shape."
assert isinstance(grid[i][num], str), (
f"Grid {grid[i][num]} should be a dynamic shape."
)
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert (
grid[i][num] == numel_sign + numel_name
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
assert grid[i][num] == numel_sign + numel_name, (
f"grid mismatch: {grid[i][num]} vs {numel_name}"
)
grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr)
@ -1015,9 +1017,7 @@ class ComboKernel(Kernel):
{}
import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format(
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
)
def uniquify_block_sizes(

View File

@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel):
def initialize_range_tree(self, pid_cache):
prefixes = ["y", "x", "r0_"]
assert len(self.numels) <= len(
prefixes
), "z dimension not supported for split scan"
assert len(self.numels) <= len(prefixes), (
"z dimension not supported for split scan"
)
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
grid_dims = {"r0_": 0, "x": 1, "y": 2}

View File

@ -184,7 +184,8 @@ def config_of(
if isinstance(x, TensorArg):
if include_tensor:
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
x.offset * x.dtype.itemsize,
alignment, # type: ignore[arg-type]
)
return offset_aligned and not is_unaligned_buffer(x)
else:

View File

@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
# NB: this is symbolic so that we don't try to reuse a buffer
# for s0 for s1, just because they happen to share the same
# size hint
sympy_str(input_size)
== sympy_str(output_size)
sympy_str(input_size) == sympy_str(output_size)
) or (
# statically known that 0.95 * input_size <= output_size <= input_size
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str:
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
if len(container_match) == 1:
contained_type = container_match[0]
assert (
contained_type in PYTHON_TO_CPP
), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
assert contained_type in PYTHON_TO_CPP, (
f"unsupported {py_container} type in convert_arg_type: {contained_type}"
)
cpp_contained_type = PYTHON_TO_CPP[contained_type]
return f"{cpp_container}<{cpp_contained_type}>"
@ -367,9 +366,9 @@ class SymbolicCallArg:
class MemoryPlanningState:
def __init__(self):
super().__init__()
self.reuse_pool: dict[
ReuseKey, list[FreeIfNotReusedLine]
] = collections.defaultdict(list)
self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
collections.defaultdict(list)
)
self.total_allocated_buffer_size: int = 0
def __contains__(self, key: ReuseKey) -> bool:
@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
)
else:
assert (
self.last_seen_device_guard_index == self.device_idx
), "AOTInductor only supports running on one CUDA device"
assert self.last_seen_device_guard_index == self.device_idx, (
"AOTInductor only supports running on one CUDA device"
)
else:
if self.last_seen_device_guard_index is None:
code.writeline(
@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen):
equals_1 = isinstance(
arg, (int, sympy.Integer)
) and V.graph.sizevars.statically_known_equals(
arg, 1 # type: ignore[arg-type]
arg,
1, # type: ignore[arg-type]
)
add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen):
buf_name = arg
buf = V.graph.get_buffer(arg)
else:
assert (
raw_arg is not None
), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
assert raw_arg is not None, (
"V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
)
buf_name = f"tmp_arg_{index}"
buf = raw_arg
@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen):
and kernel_name not in self.kernel_autotune_names
):
# Create example args for autotune in a separate epilogue
assert arg_types is not None and len(call_args) == len(
arg_types
), "call_args and arg_types do not match"
assert arg_types is not None and len(call_args) == len(arg_types), (
"call_args and arg_types do not match"
)
tensor_args = {}
all_args = []
@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen):
# create a dummy raw_args for uniform behavior in the following loop
raw_args = [None] * len(call_args)
else:
assert len(raw_args) == len(
call_args
), "call_args and raw_args do not match"
assert len(raw_args) == len(call_args), (
"call_args and raw_args do not match"
)
for i, (arg, arg_type, raw_arg) in enumerate(
zip(call_args, arg_types, raw_args)
@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(layout, ir.NoneLayout):
return
if isinstance(layout, ir.NonOwningLayout):
assert isinstance(
layout.view, ir.ReinterpretView
), f"unexpected {type(layout.view)}: {layout.view}"
assert isinstance(layout.view, ir.ReinterpretView), (
f"unexpected {type(layout.view)}: {layout.view}"
)
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
self.codegen_allocation(layout.view.data.data)
@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen):
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
# All inputs of hops must be explicitly passed in.
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
assert len(outer_inputs) == len(
subgraph.graph.graph_input_names
), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
assert len(outer_inputs) == len(subgraph.graph.graph_input_names), (
f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
)
for inner_input, outer_input in zip(
subgraph.graph.graph_input_names, outer_inputs
):

View File

@ -219,8 +219,7 @@ def _schedule_for_comm(
for snode, deps in unmet_deps.items():
assert len(deps) == 0, (
"Detected unscheduled nodes. "
f"Nodes with unmet dependencies: {unmet_deps}"
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
)
return scheduled
@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default
):
assert (
node.args[0].op == "placeholder"
), f"""\
assert node.args[0].op == "placeholder", f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
"""
graph_input = node.args[0]
@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
fsdp_copy_node = node
unsharded_param = node.args[0]
assert (
unsharded_param.op == "placeholder"
), f"""
assert unsharded_param.op == "placeholder", f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph}
"""


@ -281,9 +281,9 @@ def _unlift_graph(
elif node_name in graph_signature.inputs_to_buffers:
buffer_name = graph_signature.inputs_to_buffers[node_name]
lifted_inputs.append(buffer_name)
gm.meta[
get_cloned_parameter_buffer_name(buffer_name)
] = clone_preserve_strides(state_dict[buffer_name])
gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
clone_preserve_strides(state_dict[buffer_name])
)
else:
assert node_name in graph_signature.user_inputs
lifted_inputs.append(None)
@ -542,7 +542,7 @@ def fake_tensor_prop(
# pass config dict back to user
def get_patched_config_dict(
config_patches: Optional[Union[str, dict[str, Any]]] = None
config_patches: Optional[Union[str, dict[str, Any]]] = None,
) -> dict[str, Any]:
with config.patch(config_patches):
return config.get_config_copy()
@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol):
gm: GraphModule,
example_inputs: Sequence[InputType],
**kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode:
...
) -> OutputCode: ...
def compile_fx_inner(
@ -662,9 +661,9 @@ def _compile_fx_inner(
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
assert isinstance(
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
)
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
@ -679,9 +678,10 @@ def _compile_fx_inner(
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
with _WaitCounter(
"pytorch.wait_counter.fx_codegen_and_compile"
).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard():
with (
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
):
use_cache = (
not config.force_disable_caches
and (config.fx_graph_cache or fx_graph_remote_cache)
@ -865,8 +865,7 @@ class FxCompile(ABC):
example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs,
) -> OutputCode:
...
) -> OutputCode: ...
class _InProcessFxCompile(FxCompile):
@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile):
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
aot_mode: bool = V.aot_compilation
is_inference: bool = graph_kwargs.get("is_inference", False)
extern_node_serializer: Optional[
Callable[[list[ExternKernelNode]], Any]
] = graph_kwargs.get("extern_node_serializer", None)
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
graph_kwargs.get("extern_node_serializer", None)
)
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
"boxed_forward_device_index", None
)
with _WaitCounter(
"pytorch.wait_counter.actual_codegen_and_compile"
).guard(), dynamo_utils.preserve_rng_state():
with (
_WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
dynamo_utils.preserve_rng_state(),
):
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
import time
@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile):
# See details in vllm/compilation/pass_manager.py.
log.warning("failed to log pt2_configs")
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
example_inputs
), maybe_disable_graph_partition(cpp_wrapper, aot_mode):
with (
V.set_fake_mode(fake_mode),
maybe_disable_comprehensive_padding(example_inputs),
maybe_disable_graph_partition(cpp_wrapper, aot_mode),
):
const_output_index = None
const_graph = None
const_code = None
@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile):
if graph.aot_mode:
from .codecache import AotCodeCompiler
assert (
graph.cpp_wrapper
), "AOT mode only supports C++ wrapper"
assert graph.cpp_wrapper, (
"AOT mode only supports C++ wrapper"
)
code, linemap = graph.codegen_with_cpp_wrapper()
output_code_log.debug("Output code: \n%s", code)
@ -1509,10 +1511,13 @@ def cudagraphify(
def run(new_inputs: Sequence[InputType]) -> Any:
nonlocal compiled_fn
if compiled_fn is None:
with dynamo_utils.dynamo_timed(
"cudagraphify",
log_pt2_compile_event=True,
), dynamo_utils.preserve_rng_state():
with (
dynamo_utils.dynamo_timed(
"cudagraphify",
log_pt2_compile_event=True,
),
dynamo_utils.preserve_rng_state(),
):
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
return compiled_fn(new_inputs)
@ -1669,13 +1674,16 @@ def compile_fx_aot(
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
with V.set_aot_compilation(True), torch._guards.compile_context(
saved_compile_context
), chromium_event_timed(
"compile_fx_aot",
log_pt2_compile_event=True,
reset_event_log_on_exit=True,
), get_metrics_context():
with (
V.set_aot_compilation(True),
torch._guards.compile_context(saved_compile_context),
chromium_event_timed(
"compile_fx_aot",
log_pt2_compile_event=True,
reset_event_log_on_exit=True,
),
get_metrics_context(),
):
compiled_artifacts = compile_fx(
model_,
example_inputs_,
@ -1875,12 +1883,15 @@ def compile_fx(
# TODO: This probably shouldn't be a recursive call
if config.cpp_wrapper:
with config.patch(
{
"cpp_wrapper": False, # reset to break recursive call to compile_fx
**get_cpp_wrapper_config(),
}
), V.set_real_inputs(example_inputs_):
with (
config.patch(
{
"cpp_wrapper": False, # reset to break recursive call to compile_fx
**get_cpp_wrapper_config(),
}
),
V.set_real_inputs(example_inputs_),
):
inputs_: Sequence[InputType] = example_inputs_
if isinstance(model_, GraphModule):
@ -1940,10 +1951,10 @@ def compile_fx(
# Do the actual work
with _use_lazy_graph_module(
dynamo_config.use_lazy_graph_module
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta(
config.trace.enabled
with (
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
enable_python_dispatcher(),
torch.fx.traceback.preserve_node_meta(config.trace.enabled),
):
# Pre-grad passes cannot be run if we weren't given a GraphModule.
# Dynamo will always produce a GraphModule, but this handles cases
@ -2085,9 +2096,9 @@ def compile_fx(
boxed_forward_device_index=forward_device,
)
fw_compiler: Callable[
[GraphModule, Sequence[InputType]], OutputCode
] = functools.partial(fw_compiler_base, is_inference=False)
fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
functools.partial(fw_compiler_base, is_inference=False)
)
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
if config.freezing and not torch.is_grad_enabled():
@ -2124,9 +2135,10 @@ def compile_fx(
) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock
with dynamo_utils.dynamo_timed(
"compile_fx.<locals>.bw_compiler"
), compile_lock:
with (
dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
compile_lock,
):
model_outputs_node = output_node(gm)
if config.bw_outputs_user_visible:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
@ -2194,10 +2206,11 @@ def compile_fx(
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
return inference_compiler(unlifted_gm, example_inputs_)
with V.set_fake_mode(fake_mode), torch._guards.tracing(
tracing_context
), compiled_autograd._disable(), functorch_config.patch(
unlift_effect_tokens=True
with (
V.set_fake_mode(fake_mode),
torch._guards.tracing(tracing_context),
compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
):
try:
return aot_autograd(


@ -530,7 +530,8 @@ class CompilerBisector:
)
if result:
curr_subsystem = cls.get_subsystem_object(
curr_backend, cls.get_subsystem() # type: ignore[arg-type]
curr_backend,
cls.get_subsystem(), # type: ignore[arg-type]
)
if isinstance(curr_subsystem, BinarySubsystem):


@ -80,9 +80,9 @@ fx_graph_cache: bool = Config(
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
# should we bundle triton caching into fx graph cache
bundle_triton_into_fx_graph_cache: Optional[
bool
] = bundle_triton_into_fx_graph_cache_default()
bundle_triton_into_fx_graph_cache: Optional[bool] = (
bundle_triton_into_fx_graph_cache_default()
)
# Enable autotune local cache.
#
@ -1390,12 +1390,12 @@ class halide:
# Halide autoscheduler to use, choices are:
# "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
scheduler_cuda: Literal[
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
] = "Anderson2021"
scheduler_cpu: Literal[
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
] = "Adams2019"
scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Anderson2021"
)
scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Adams2019"
)
# Controls `no_asserts` flag passed to Halide target (warning: can false positive)
asserts = False


@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter):
and is_woq_int8_pattern(next(iter(node.users)))
)
) and is_const_source(
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
node.args[0], # type: ignore[arg-type]
self.lifted_constant_names,
):
# Case 1: int8_weight -> dq -> bf16_weight
# Case 2: int8_weight -> permute -> dq -> bf16_weight


@ -1633,8 +1633,8 @@ class CppBuilder:
"""
)
assert os.path.exists(
cmake_path
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
assert os.path.exists(cmake_path), (
f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
)
with open(cmake_path, "a") as f:
f.write(contents)


@ -119,6 +119,7 @@ from . import config
@dataclasses.dataclass(frozen=True)
class GraphID:
"Unique counter of a cuda graph recording"
id: int
@ -622,11 +623,15 @@ class CUDAWarmupNode:
refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
with torch.cuda.device(
self.device_index
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
), get_history_recording():
with (
torch.cuda.device(self.device_index),
disable_conv_cache_emptying(),
clear_cublas_manager(),
_use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
),
get_history_recording(),
):
out = self.wrapped_function.model(new_inputs)
# We need to know which outputs are allocated within the cudagraph pool
@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage()
class AliasesPriorGraphOutput(OutputAliasInfo):
"Marks that the graph output aliases an output of a prior graph"
__slots__ = ["index"]
index: PathOutputIndex
@ -1200,14 +1206,18 @@ class CUDAGraphNode:
]
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
with preserve_rng_state(), torch.cuda.device(
self.device
), clear_cublas_manager(), torch.cuda.graph(
self.graph,
stream=self.stream,
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
), get_history_recording():
with (
preserve_rng_state(),
torch.cuda.device(self.device),
clear_cublas_manager(),
torch.cuda.graph(
self.graph,
stream=self.stream,
pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
),
get_history_recording(),
):
static_outputs = model(inputs)
# running model should reclaim memory
@ -1247,11 +1257,13 @@ class CUDAGraphNode:
self.output_storage_alias.append(UnaliasedStorage)
continue
torch._check(
o.is_cuda or o.untyped_storage().data_ptr() == 0,
lambda: (
"Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
(
torch._check(
o.is_cuda or o.untyped_storage().data_ptr() == 0,
lambda: (
"Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
),
),
)
@ -1291,9 +1303,9 @@ class CUDAGraphNode:
if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))]
else:
assert len(self.stack_traces) == len(
outputs
), "Wrong number of stack traces passed in"
assert len(self.stack_traces) == len(outputs), (
"Wrong number of stack traces passed in"
)
assert not self.outputs_weakrefs
for out, static_output_tensor in zip(outputs, self.static_output_tensors):
@ -1599,12 +1611,14 @@ class CUDAGraphNode:
self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: list[InputType] = []
with warnings.catch_warnings(record=True), torch.cuda.device(
self.device
), _use_cuda_memory_pool_manager(
self.device,
mem_pool=self.cuda_graphs_pool,
stream=self.stream,
with (
warnings.catch_warnings(record=True),
torch.cuda.device(self.device),
_use_cuda_memory_pool_manager(
self.device,
mem_pool=self.cuda_graphs_pool,
stream=self.stream,
),
):
for i, inp in enumerate(inputs):
if not isinstance(inp, torch.Tensor):
@ -1736,12 +1750,8 @@ def check_memory_pool(
pool_id: tuple[int, int],
live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None:
assert all(
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
) # noqa: C419
unique_storages = {
stor.data_ptr() for stor in live_storages_ptrs if stor()
} # noqa: set_linter
assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419
unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter
# check if there is a divergence first, then do the expensive snapshot call after
# we know it will error
@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager:
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
with warnings.catch_warnings(record=True), torch.cuda.graph(
self.graph,
pool=self.cuda_graphs_thread_pool,
stream=self.stream,
capture_error_mode="thread_local",
with (
warnings.catch_warnings(record=True),
torch.cuda.graph(
self.graph,
pool=self.cuda_graphs_thread_pool,
stream=self.stream,
capture_error_mode="thread_local",
),
):
pass
@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager:
constants: tuple[torch.Tensor, ...],
placeholders: tuple[PlaceholderInfo, ...],
mutated_input_idxs: tuple[int, ...],
) -> tuple[ModelType, OutputType,]:
) -> tuple[
ModelType,
OutputType,
]:
id = self.new_func_id()
self.ids_to_stack_traces[id] = stack_traces
self.ids_to_funcs[id] = WrappedFunction(


@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType]
@dataclasses.dataclass(frozen=True)
class FunctionID:
"Unique counter of a function wrapped in cudagraphify_impl"
id: int
@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})"
@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes(
def check_lowering_disable_cudagraph(
device_node_mapping: dict[torch.device, torch.fx.Node]
device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]:
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -276,9 +277,9 @@ def log_data_ptr_mismatch(
Logs the mismatch between input data pointers and recorded data pointers.
This checks only idxs in target_idxs.
"""
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
placeholders
), "length mismatch between inputs, recorded_data_ptr, and placeholders"
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), (
"length mismatch between inputs, recorded_data_ptr, and placeholders"
)
t_tensors = [inputs[i] for i in target_idxs]
t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]


@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name(
def get_node_name_to_buf_meta(
node_name_to_buf_name: dict[str, str]
node_name_to_buf_name: dict[str, str],
) -> dict[str, BufMeta]:
buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items():


@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition(
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions:


@ -194,7 +194,9 @@ class MemoryDep(Dep):
)
new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR
out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type]
out = MemoryDep(
self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
) # type: ignore[arg-type]
return out
@property
@ -649,11 +651,16 @@ def extract_loop_body_with_args(
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
for entry in fn.memory_usage[MemoryUsageType.STORE]:
inner.store(
entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type]
entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
entry.mode,
)
for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
inner.store_reduction(
entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type]
entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
)
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
inner.index_expr(name_to_index[entry.index_name], None)
@ -661,7 +668,11 @@ def extract_loop_body_with_args(
# All that matters is that we record the buffer name, so place it in the
# "boundaries" name position to ensure that it's recorded.
inner.bucketize(
None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type]
None,
(entry.buffer_name, None, None, None),
None,
None, # type: ignore[arg-type]
None, # type: ignore[arg-type]
)
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
return inner
@ -801,8 +812,9 @@ def extract_free_unbacked_symbols(
handler = FreeUnbackedSymbolsOpsHandler()
# NB: I cargo culted the allow_indexing patch here, I don't understand why
# people do this all over
with V.set_ops_handler(handler), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(handler),
patch.object(FlexibleLayout, "allow_indexing", True),
):
fn(*args)
return handler.symbols


@ -19,8 +19,7 @@ T = TypeVar("T")
class DTypeVar(Protocol):
@property
def dtype(self) -> torch.dtype:
...
def dtype(self) -> torch.dtype: ...
DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]


@ -526,6 +526,7 @@ class ConfigFuzzer:
```python
import torch._inductor.config as cfg
def create_simple_test_model_gpu() -> FactoryOutputType:
batch_size = 32
seq_length = 50
@ -539,6 +540,8 @@ class ConfigFuzzer:
return True
return test_fn
fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2)
# Test every pair of configs:
@ -550,7 +553,9 @@ class ConfigFuzzer:
ret = fuzzer.bisect(num_attempts=10)
# reproduce a failing config
fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}])
fuzzer.reproduce(
[{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]
)
```
The list of known failures on inductor config are:


@ -531,7 +531,11 @@ def tuned_b2b_gemm(
A.realize()
B.realize()
C.realize()
layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]]) # type: ignore[index]
layout = FixedLayout(
A.get_device_or_error(),
A.get_dtype(),
[A.shape[0], C.shape[1]], # type: ignore[index]
)
subgraph_buffer = build_subgraph_buffer(
[create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())],
subgraph,


@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None:
node_indices = {node: i for i, node in enumerate(graph.nodes)}
for allreduce in comm_blocks:
# Find the earliest/first user -- target_node.
assert (
len(allreduce.outputs) >= 1
), f"Found a allreduce that has zero outputs/users -- {allreduce}."
assert len(allreduce.outputs) >= 1, (
f"Found a allreduce that has zero outputs/users -- {allreduce}."
)
# Initialize the target node to avoid typing issues.
target_node = next(iter(next(iter(allreduce.outputs)).users))
target_node_index = 2**31


@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# argument. `graph.get_attr` and
# `graph.call_function` does not allow the `name` argument.
conv_get_node = graph.create_node(
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
op="get_attr",
target=conv_node.target, # type: ignore[union-attr]
name="get_conv",
)
bn_get_node = graph.create_node(
op="get_attr", target=bn_node.target, name="get_bn"


@ -866,15 +866,18 @@ def _get_sfdp_patterns():
name += "_bs1"
training_name = name + "_training"
yield training_name, {
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
"trace_fn": joint_fwd_bwd,
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
}
yield (
training_name,
{
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
"trace_fn": joint_fwd_bwd,
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
},
)
if workaround:
assert len(workaround) == 1 and "dropout_p" in workaround
@ -886,18 +889,21 @@ def _get_sfdp_patterns():
workaround = {}
inference_name = name + "_inference"
yield inference_name, {
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
"trace_fn": fwd_only,
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
# with dropout turned into clone, we end up with a number of
# semantically identical graphs
"skip_duplicates": True,
}
yield (
inference_name,
{
"search_fn": pattern,
"replace_fn": replacement,
"example_inputs": args,
"trace_fn": fwd_only,
"pass_dicts": patterns,
"extra_check": extra_check,
"scalar_workaround": workaround,
# with dropout turned into clone, we end up with a number of
# semantically identical graphs
"skip_duplicates": True,
},
)
@functools.lru_cache(None)


@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion):
args=(batch_biases[i],),
kwargs={"size": broadcast_shape},
)
broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment]
broadcast_bias.meta["val"] = aten.broadcast_to(
batch_biases_meta[i]["val"], broadcast_shape
) # type: ignore[assignment]
new_bias_add = graph.call_function( # type: ignore[operator]
aten.add.Tensor, args=((broadcast_bias, new_mm))
)
@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion):
group_biases = None # type: ignore[assignment]
if all(weight is None for weight in group_weights):
group_weights = None # type: ignore[assignment]
assert all(
eps == group_epss[0] for eps in group_epss
), "all epsilon values must be equal"
assert all(eps == group_epss[0] for eps in group_epss), (
"all epsilon values must be equal"
)
with graph.inserting_before(subset[0]): # type: ignore[operator]
stack_input = graph.call_function( # type: ignore[operator]
@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
# for relu op, we also use the inplace to construct the key
# we batch the ops with same parent to enable followup split cat
parent = node.args[0]
parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr]
parent = (
parent.target # type: ignore[union-attr]
if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
else ""
)
group_key = (
"batch_aten_" + self.op.__name__.lower().split(".")[0],
str(input.meta["val"].shape),
@ -1293,9 +1299,9 @@ def get_fusion_candidates(
"""
q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()
candidate_dict: collections.defaultdict[
Any, list[torch.fx.Node]
] = collections.defaultdict(list)
candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
collections.defaultdict(list)
)
if root_node.target in SEARCH_EXCLUSIONS:
return candidate_dict


@ -763,9 +763,7 @@ def _get_node_to_ancestors(
"""
Compute the ancestors for all nodes in a graph.
"""
node_to_ancestors = defaultdict(
OrderedSet[torch.fx.Node]
) # type: ignore[var-annotated]
node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated]
for node in graph.nodes:
node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
for dep in node.all_input_nodes:


@ -558,9 +558,9 @@ if torch._C._has_mkldnn:
binary_nodes = filter_nodes(match.nodes, binary_op)
def _get_compute_node(_binary_node, _other_index):
assert (
len(_binary_node.all_input_nodes) == 2
), "Binary node should have 2 input nodes."
assert len(_binary_node.all_input_nodes) == 2, (
"Binary node should have 2 input nodes."
)
_compute_index = 1 if (_other_index == 0) else 0
return _binary_node.args[_compute_index]
@ -614,9 +614,9 @@ if torch._C._has_mkldnn:
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
len(match.nodes)
)
return L[fusion_op](*computation_args)
return fn
@ -659,9 +659,9 @@ if torch._C._has_mkldnn:
else:
computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
] += len(match.nodes)
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
len(match.nodes)
)
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
other.realize()
if not _can_be_inplace(other) or other.data.shape != list(
@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn:
)
batch_size = input.meta.get("val").shape[0]
if has_free_symbols(batch_size):
assert (
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
)
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
packed_weight_inputs = (
transpose_weight_node,


@ -437,7 +437,7 @@ def _should_pad_bench(
return False
def realize_symbols(
ds: Union[torch.Size, tuple[torch.SymInt, ...]]
ds: Union[torch.Size, tuple[torch.SymInt, ...]],
) -> list[int]:
return [d if isinstance(d, int) else d.node.hint for d in ds]


@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
pattern_matcher_pass.apply
)
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_post_grad"
] = upload_graph(gm.graph)
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
upload_graph(gm.graph)
)
if config.b2b_gemm_pass:
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]


@ -277,9 +277,9 @@ def pre_grad_passes(
for _ in range(counter):
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[
f"{pattern_matcher_pass.pass_name}_pre_grad"
] = upload_graph(gm.graph)
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
upload_graph(gm.graph)
)
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]


@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering(
accum.realize()
from .mkldnn_fusion import _can_be_inplace
assert _can_be_inplace(
accum
), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
assert _can_be_inplace(accum), (
"QConv Binary Inplace Fusion requires accum is not an alias or mutation."
)
computation_args = (
x,
@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
def clone_to_new_node(graph, source_node, user_node):
# Clone the source_node to a new node
# Replace user_node's input from source_node to new_node
assert (
source_node.op == "call_function"
), "clone_to_new_node only support node.op call_function"
assert source_node.op == "call_function", (
"clone_to_new_node only support node.op call_function"
)
with graph.inserting_before(user_node):
new_node = graph.call_function(
source_node.target,
@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
return _node
else:
assert (
len(_node.args) >= 1
), "In in dequant pattern, each node should have more than 1 arg."
assert len(_node.args) >= 1, (
"In in dequant pattern, each node should have more than 1 arg."
)
return _find_first_node_in_dequant_pattern(_node.args[0])
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(


@ -616,7 +616,8 @@ def merge_splits(
dim=first_split_dim,
)
first_split_num_to_user = {
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
user.args[1]: user
for user in first_split.users.keys() # type: ignore[union-attr]
}
new_split_num = 0
@ -706,7 +707,11 @@ class SplitCatSimplifier:
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
)
self.replace_cat(
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
graph,
split_node,
next_users,
user_inputs_list_new,
transform_params_list, # type: ignore[arg-type]
)
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
counters["inductor"]["unbind_stack_pass"] += 1
@ -913,7 +918,9 @@ class SplitCatSimplifier:
)
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
new_split.meta["example_value"] = torch.split(
split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr]
split_input.meta["example_value"], # type: ignore[union-attr]
[r[1] - r[0] for r in split_ranges],
dim=split_dim,
)
counters["inductor"]["scmerge_split_added"] += 1
split_items = []
@ -1005,7 +1012,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
)
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
to_stack, to_stack_meta = [], []
stack_dim = None
user_inputs_new_transformed.append(stacked_input)
@ -1023,19 +1033,28 @@ class SplitCatSimplifier:
user_input_new = graph.call_function(
torch.unflatten, args=(user_input_new, *unflatten_params)
)
user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*unflatten_params, # type: ignore[arg-type]
)
if movedim_params:
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.movedim, args=(user_input_new, *movedim_params)
)
user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*movedim_params, # type: ignore[arg-type]
)
if flatten_params:
user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function(
torch.flatten, args=(user_input_new, *flatten_params)
)
user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type]
user_input_new_meta,
*flatten_params, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(user_input_new)
user_inputs_new_transformed_meta.append(
user_input_new.meta["example_value"]
@ -1044,7 +1063,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
)
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(stacked_input)
user_inputs_new_transformed_meta.append(
stacked_input.meta["example_value"]
@ -1058,14 +1080,15 @@ class SplitCatSimplifier:
kwargs={"dim": cat_dim},
)
new_cat_node.meta["example_value"] = torch.cat(
user_inputs_new_transformed_meta, dim=cat_dim
user_inputs_new_transformed_meta,
dim=cat_dim,
)
counters["inductor"]["scmerge_cat_added"] += 1
else:
new_cat_node = user_inputs_new_transformed[-1]
new_cat_node.meta[
"example_value"
] = user_inputs_new_transformed_meta[-1]
new_cat_node.meta["example_value"] = (
user_inputs_new_transformed_meta[-1]
)
if (
user_node.target == torch.cat
@ -1077,7 +1100,11 @@ class SplitCatSimplifier:
new_cat_node = graph.call_function(
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
)
new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr]
new_cat_node.meta["example_value"] = torch.flatten(
new_cat_node_meta,
cat_dim,
cat_dim + 1,
)
user_node.replace_all_uses_with(new_cat_node)
new_cats.append(new_cat_node)
@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier):
]
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
getitem_indices
) != len(
unbind_node.meta["example_value"]
):
) != len(unbind_node.meta["example_value"]):
return
num_unbind = len(getitem_indices)
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# update the split sections
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
split_node, indices # type: ignore[arg-type]
split_node,
indices, # type: ignore[arg-type]
)
# padding others with zeros to keep the same dict size
for i in indices[1:]:
@ -1613,10 +1639,12 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
# check the split dim, and construct the slice tuple
start_fused_size = calculate_fused_tensor_size(
split_node, list(range(indices[0])) # type: ignore[arg-type]
split_node,
list(range(indices[0])), # type: ignore[arg-type]
)
end_fused_size = start_fused_size + calculate_fused_tensor_size(
split_node, indices # type: ignore[arg-type]
split_node,
indices, # type: ignore[arg-type]
)
slice_list = []
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
continue
# check the cat node has consecutive indices
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
and len(getitem_nodes) != len(cat_inputs)
):
continue
# replace the users of the cat node to be the input of the split node
cat_node.replace_all_uses_with(split_input)
@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
continue
# check the cat node has consecutive indices
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type]
if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
or len(select_nodes) != len(cat_inputs)
):
continue
# check all the select nodes can be merged to the cat node input
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
args=(new_cat_args,),
kwargs={"dim": cat_dim},
)
new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type]
new_cat_node.meta["example_value"] = torch.cat(
new_cat_args_meta, dim=cat_dim
) # type: ignore[arg-type]
cat_node.replace_all_uses_with(new_cat_node)
new_cat_node.meta.update(cat_node.meta)
# remove inputs of cat_node if they have no users
@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack(
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
)
reshape_node.meta["example_value"] = torch.Tensor.view(
permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type]
permute_node.meta["example_value"],
tuple(stack_node_shape), # type: ignore[arg-type]
)
return reshape_node
@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
cat_inputs.append(decomposed_stack_node)
# cat_arg must be the split input
view_shape_list = get_view_shape_list(cat_arg, stack_dim)
stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr]
stack_node_shape = torch.reshape(
cat_arg.meta["example_value"], tuple(view_shape_list)
).shape # type: ignore[union-attr]
cat_inputs.append(
convert_reshape_cat_arg_to_stack(
graph,


@ -105,9 +105,9 @@ class FakeTensorUpdater:
if new is None:
return old is None
if not isinstance(new, torch.Tensor):
assert isinstance(
new, (torch.SymInt, torch.SymBool, torch.SymFloat)
), f"Unknown type {type(new)} in {self.graph}"
assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
f"Unknown type {type(new)} in {self.graph}"
)
return (
new.node.shape_env._maybe_evaluate_static(
sympy.Eq(new.node.expr, old.node.expr)


@ -136,7 +136,9 @@ else:
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
assert isinstance(
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
), (
"get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
)
if isinstance(constant_buffer, sympy.core.numbers.Integer):
return torch.int64
@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter):
self.reuse_shape_env = True
self._shape_env = shape_env
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: dict[
Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
shape_env.deferred_runtime_asserts.copy()
)
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: list[str] = []
@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter):
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: list[
tuple[int, str]
] = (
[]
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs_reason: Optional[str] = None
@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter):
)
def placeholder(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter):
return target(*args, **kwargs)
if target not in lowerings:
assert isinstance(
target, torch._ops.OpOverload
), f"{target} is not an OpOverload"
assert isinstance(target, torch._ops.OpOverload), (
f"{target} is not an OpOverload"
)
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target, warn=False, override_decomp=True)
@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr(
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[()], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1241,7 +1247,10 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError
def output(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter):
if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node(
n
), V.set_current_node(n):
with (
ir.IRNode.current_origins(origins),
self.set_current_node(n),
V.set_current_node(n),
):
if (
n.op == "call_function"
and n.target is not operator.getitem
@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter):
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, **kwargs # type: ignore[possibly-undefined]
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter):
wrapper_code_gen_cls = get_wrapper_codegen_for_device(
self.device_type, self.cpp_wrapper
)
assert (
wrapper_code_gen_cls is not None
), f"Device {self.device_type} not supported"
assert wrapper_code_gen_cls is not None, (
f"Device {self.device_type} not supported"
)
self.wrapper_code = wrapper_code_gen_cls.create(
is_subgraph,
subgraph_name,
@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter):
compiled = self.compile_to_module().call
def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]:
if x is None:
return None
@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter):
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(
x, torch.Tensor
), "Unknown type when creating real inputs" + str(type(x))
assert isinstance(x, torch.Tensor), (
"Unknown type when creating real inputs" + str(type(x))
)
return x
tracing_context = torch._guards.TracingContext.try_get()


@ -5,21 +5,22 @@ propagation of sympy expressions downstream of ops.index_expr calls.
For example, say we have the IR:
tmp0 = ops.index_expr(x, torch.int32)
tmp1 = ops.constant(2, torch.int32)
tmp2 = ops.mul(tmp0, tmp1)
tmp3 = ops.indirect_indexing(tmp2, x_size)
tmp4 = ops.load("buf0", tmp3)
tmp0 = ops.index_expr(x, torch.int32)
tmp1 = ops.constant(2, torch.int32)
tmp2 = ops.mul(tmp0, tmp1)
tmp3 = ops.indirect_indexing(tmp2, x_size)
tmp4 = ops.load("buf0", tmp3)
The underlying handler would just see:
ops.load("buf0", x * 2)
ops.load("buf0", x * 2)
This is limited by the set of operators handled in the sympy expression
printers. So simple operations like minimum and maximum cannot be translated to
SymPy expressions yet, despite sympy.Min and sympy.Max existing.
"""
import itertools
from collections.abc import Sequence
from dataclasses import dataclass
@ -179,9 +180,9 @@ class IndexPropVar:
return IndexPropVar(expr, is_symbolic=True)
def __post_init__(self):
assert not self.is_symbolic or isinstance(
self.value, TypedExpr
), "Symbolic IndexPropVar must contain a TypedExpr"
assert not self.is_symbolic or isinstance(self.value, TypedExpr), (
"Symbolic IndexPropVar must contain a TypedExpr"
)
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler):
name: Literal["indirect_indexing"],
args: Sequence[Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
) -> IndexPropVar: ...
@overload
def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
...
) -> IndexPropResult: ...
def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler):
is_valid_expr = new_expr is not NotImplemented and (
# Inductor doesn't expect floating point in sympy expressions, but
# allow floating point constants to be propagated
new_expr.is_constant()
or new_expr.expr.is_integer
new_expr.is_constant() or new_expr.expr.is_integer
)
if not is_valid_expr:
return self.fallback(name, args, kwargs)


@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
int,
EffectfulKernel,
),
), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
), (
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
)
# Be picky about the accepted data structure (don't use pytree here)
_check_tensorbox(node_or_nodes)
@ -298,13 +300,11 @@ def get_stride_order(
@overload
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None:
...
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...
@overload
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor:
...
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
def ir_node_to_tensor(
@ -346,7 +346,7 @@ def may_convert_to_optional(
def get_device_type(
x: Union[IRNode, OutputSpec, torch.device, None, str]
x: Union[IRNode, OutputSpec, torch.device, None, str],
) -> Optional[str]:
if isinstance(x, str) or x is None:
return x
@ -698,8 +698,7 @@ class IRNode:
if TYPE_CHECKING:
@property
def dtype(self) -> torch.dtype:
...
def dtype(self) -> torch.dtype: ...
@ir_dataclass(frozen=False)
@ -839,8 +838,9 @@ class Loops(IRNode):
@cache_on_self
def inner_fn_opcount(self) -> OpCountResult:
opcounter = OpCounterCSE(V.MockHandler())
with V.set_ops_handler(opcounter), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(opcounter),
patch.object(FlexibleLayout, "allow_indexing", True),
):
self.inner_fn(*self.inner_fn_args())
return opcounter.getvalue()
@ -1364,9 +1364,9 @@ class Reduction(Loops):
# "all" is desugared to `!any(!val)`
}
assert (
reduction_type in rtypes_to_inits.keys()
), f"{reduction_type} not supported for zero-dimension tensors!"
assert reduction_type in rtypes_to_inits.keys(), (
f"{reduction_type} not supported for zero-dimension tensors!"
)
def const_fn(index: int) -> OpsValue:
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
@ -1575,9 +1575,9 @@ class Reduction(Loops):
new_ranges: Sequence[Integer],
new_reduction_ranges: Sequence[Integer],
) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
assert all(
r == 1 for r in original_ranges
), f"Only enabled for numel_hint == 1, found {original_ranges=}"
assert all(r == 1 for r in original_ranges), (
f"Only enabled for numel_hint == 1, found {original_ranges=}"
)
reindex = View.dynamic_reshape_indexer(
original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
)
@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction):
if reduction_numel == 1:
def copy(
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
) -> TensorBox:
def inner_fn(idx: Sequence[Expr]) -> OpsValue:
reduction_index = [sympy.S.Zero for _ in reduction_ranges]
@ -2571,9 +2571,9 @@ class ExpandView(BaseView):
# NB: new_size[i] == old_size[i] is expected to already be
# guarded because the meta formula was expected to have taught
# us this equality.
assert (
sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
)
return new_size
@classmethod
@ -3382,9 +3382,9 @@ class Layout(OutputSpec):
)
def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
assert (
FlexibleLayout.allow_indexing
), f"convert {type(self).__name__} to FixedLayout first"
assert FlexibleLayout.allow_indexing, (
f"convert {type(self).__name__} to FixedLayout first"
)
return self.as_fixed().make_indexer()
def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout):
return target
result = unwrap_views(self.target)
assert isinstance(
result, Buffer
), "MutationLayoutSHOULDREMOVE must refer to a buffer"
assert isinstance(result, Buffer), (
"MutationLayoutSHOULDREMOVE must refer to a buffer"
)
return result
def real_layout(self): # type: ignore[no-untyped-def]
@ -3803,7 +3803,9 @@ class Buffer(IRNode):
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride)
def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None: # type: ignore[no-untyped-def]
def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def]
self, exact_strides, allow_padding=False
) -> None:
assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_exact_strides(
exact_strides, allow_padding=allow_padding
@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer):
torch.ops.higher_order.flex_attention_backward,
)
current_node = V.graph.current_node.target
assert (
current_node in allowed_set
), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
assert current_node in allowed_set, (
f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
)
device = self.inputs[0].get_device()
self.outputs += [
MutationOutput(NoneLayout(device=device), buf, self)
@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel):
x_unwrap_view.freeze_layout()
index_args, var_ranges = dependencies.index_vars_squeeze(
x.get_size(), prefix="r" # type: ignore[arg-type]
x.get_size(),
prefix="r", # type: ignore[arg-type]
)
range_vars = index_args[0]
index = x.make_indexer()(range_vars)
@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel):
# pass in a list of const arg names for arg_properties lookup.
name_to_arg_properties = None
if names and self.arg_properties:
assert len(self.constant_args) == len(
names
), "names passed to codegen_const_args does not match self.constant_args"
assert len(self.constant_args) == len(names), (
"names passed to codegen_const_args does not match self.constant_args"
)
name_to_arg_properties = {
arg.get("name"): arg for arg in self.arg_properties
}
@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel):
args = []
for i, x in enumerate(inputs):
if V.graph.cpp_wrapper:
assert self.arg_properties and i < len(
self.arg_properties
), "Invalid access to ExternKernel.arg_properties"
assert self.arg_properties and i < len(self.arg_properties), (
"Invalid access to ExternKernel.arg_properties"
)
type_ = self.arg_properties[i].get("type")
args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
else:
@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet()
def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None: # type: ignore[no-untyped-def]
def __init__( # type: ignore[no-untyped-def]
self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args
) -> None:
inputs = []
kwargs = {}
constant_args = []
@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc):
elif isinstance(output, torch.SymInt):
return output.node.expr
else:
assert (
output is None
), f"FallbackKernel output type {type(output)} is not supported"
assert output is None, (
f"FallbackKernel output type {type(output)} is not supported"
)
return None
outputs = generate_output(example_output, [])
@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel):
)
self.codegen_size_asserts(wrapper)
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
def __init__( # type: ignore[no-untyped-def]
self,
layout: OutputSpec,
input,
indices: list[tuple[Any, ...]],
) -> None:
super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self)
V.graph.register_operation(self)
@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel):
assert p.get_dtype() == torch.bool, p
assert len(p.get_size()) == 0, p
assert (
len(all_inputs) > 0
), "torch.while_loop is assumed to have at least one operand."
assert len(all_inputs) > 0, (
"torch.while_loop is assumed to have at least one operand."
)
device = all_inputs[0].get_device()
@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel):
# This is identical to FallbackKernel.set_cpp_kernel(), minus the
# part that checks against input aliasing and mutation.
def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
assert (
type(self.op_overload) is torch._ops.OpOverload
), "Setting cpp kernel needs a valid op_overload"
assert type(self.op_overload) is torch._ops.OpOverload, (
"Setting cpp kernel needs a valid op_overload"
)
kernel = self.op_overload
self.cpp_kernel_name = kernel._schema.name


@ -1,5 +1,5 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""
"""Triton Implementation of the flex_attention Kernel"""
import copy
import logging
@ -60,9 +60,9 @@ def construct_strides(
) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides
assert len(sizes) == len(
fill_order
), "Length of sizes must match the length of the fill order"
assert len(sizes) == len(fill_order), (
"Length of sizes must match the length of the fill order"
)
strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension
@ -1151,10 +1151,14 @@ def lower_cpu(
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
), (
"Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
)
assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
), (
"KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
)
CppFlexAttentionTemplate.add_choices(
choices=_choices,
input_nodes=input_nodes,
@ -1364,15 +1368,15 @@ def flex_attention(
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(seq_len_q, 0)
), "Query length must be greater than 0"
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(seq_len_kv, 0)
), "Key length must be greater than 0"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
"Query length must be greater than 0"
)
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
"Key length must be greater than 0"
)
B = Bq
@ -2291,9 +2295,9 @@ def process_joint_outputs(
JointOutputResult containing processed buffers and gradients
"""
assert isinstance(all_joint_outputs, list)
assert (
all_joint_outputs[0] is not None
), "joint_subgraph_buffer is None - this is a bug!"
assert all_joint_outputs[0] is not None, (
"joint_subgraph_buffer is None - this is a bug!"
)
joint_buffer = all_joint_outputs[0]
other_grads = all_joint_outputs[num_placeholders - 1 :]
@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards.
@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs):
grad_key = broadcasted_grad_key
grad_value = broadcasted_grad_value
else:
assert V.graph.sizevars.evaluate_expr(
sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" # noqa: B950
assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. "
f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
)
grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)


@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any
import sympy
@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr(
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
)
B = Bq
kernel_options = dict(kernel_options)
@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
max(
next_power_of_2(
V.graph.sizevars.size_hint(
seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
seq_len_q,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
* gqa_shared_heads
),


@ -65,7 +65,8 @@ def filtered_configs(
m = max(
next_power_of_2(
V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
m,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
@ -73,7 +74,8 @@ def filtered_configs(
n = max(
next_power_of_2(
V.graph.sizevars.size_hint(
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
n,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size,
@ -81,7 +83,8 @@ def filtered_configs(
k = max(
next_power_of_2(
V.graph.sizevars.size_hint(
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
k,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
)
),
min_block_size_k,
@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
"""
even_k_symbolic = (
# it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
== config.kwargs["BLOCK_K"]
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
)
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
not inductor_config.force_same_precision


@ -194,11 +194,12 @@ class LoopBody:
# There is indeed an issue due to symbol name conflicting.
# y0 maybe reused for the y dimension later.
(
iter_vars,
reduce_vars,
), var_ranges = dependencies.index_vars_no_squeeze(
iter_sizes, reduce_sizes, prefix="t"
)
(
iter_vars,
reduce_vars,
),
var_ranges,
) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t")
new_body = LoopBody(
old_body,
[iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
@ -234,7 +235,8 @@ class LoopBody:
new_sizes = (new_iter_size, reduce_size)
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="t" # type: ignore[arg-type]
*new_sizes,
prefix="t", # type: ignore[arg-type]
)
inverse_order = {b: a for a, b in enumerate(new_order)}
@ -254,7 +256,8 @@ class LoopBody:
# use the original symbol prefix so we can do multiple round of reordering
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="p" # type: ignore[arg-type]
*new_sizes,
prefix="p", # type: ignore[arg-type]
)
new_body = LoopBody(
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@ -385,9 +388,9 @@ class LoopBody:
def indexing_from_args(self, indices):
index = [*itertools.chain.from_iterable(indices)]
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
assert all(
v not in self.var_ranges for v in index
), f"{self.var_ranges=}, {indices=}"
assert all(v not in self.var_ranges for v in index), (
f"{self.var_ranges=}, {indices=}"
)
replacements = dict(zip(self.var_ranges.keys(), index))
return {
name: sympy_subs(expr, replacements)


@ -346,7 +346,8 @@ def transform_args(
# only consider tensor kwargs for promotion, for now
promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type]
*promoting_args,
type_promotion_kind=type_promotion_kind, # type: ignore[arg-type]
)
device = (
@ -448,9 +449,9 @@ def _register_lowering(
(fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
):
# explicitly assert for "out=" ops for better error messages
assert not any(
x == "out" for x in kwargs.keys()
), "out= ops aren't yet supported"
assert not any(x == "out" for x in kwargs.keys()), (
"out= ops aren't yet supported"
)
args, kwargs = transform_args(
args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b):
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
assert (
override_return_dtype is None or type_promotion_kind is None
), "only one of override_return_dtype or type_promotion_kind may be given"
assert override_return_dtype is None or type_promotion_kind is None, (
"only one of override_return_dtype or type_promotion_kind may be given"
)
if override_return_dtype is None and type_promotion_kind is None:
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
if isinstance(input, (list, tuple)):
a_list_input = input
break
assert (
a_list_input is not None
), "at least one input must be a list to a foreach op"
assert a_list_input is not None, (
"at least one input must be a list to a foreach op"
)
# broadcast scalar inputs to match length of list inputs
broadcast_inputs = []
@ -1321,12 +1322,12 @@ def quantized_decomposed_quantize_per_channel(
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert axis < len(
input.get_size()
), f"Expecting axis to be < {len(input.get_size())}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
assert axis < len(input.get_size()), (
f"Expecting axis to be < {len(input.get_size())}"
)
input_loader = input.make_loader()
scales_loader = scales.make_loader()
@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel(
) -> TensorBox:
assert len(scales.get_size()) == 1, "expect scales 1 dim"
assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert axis < len(
input.get_size()
), f"Expecting axis to be < {len(input.get_size())}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
assert axis < len(input.get_size()), (
f"Expecting axis to be < {len(input.get_size())}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default(
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
input_loader = input.make_loader()
@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default(
*,
out_dtype: Optional[torch.dtype] = None,
) -> TensorBox:
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor(
) -> TensorBox:
if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32)
assert (
input.get_dtype() == torch.float32
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == torch.float32, (
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
)
assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor"
@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor(
assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor"
assert (
input.get_dtype() == dtype
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
assert input.get_dtype() == dtype, (
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
)
if out_dtype is None:
out_dtype = torch.float32
@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
assert (
op not in decompositions or override_decomp
), f"both a fallback and a decomp for same op: {op}"
assert op not in decompositions or override_decomp, (
f"both a fallback and a decomp for same op: {op}"
)
if (
warn
and bool(os.getenv("CI"))
@ -2086,9 +2087,9 @@ def native_dropout(x, p, train):
@register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args):
assert config.fallback_random or x.get_device() == torch.device(
"cpu"
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"this should be handled in decomps unless config.fallback_random or the device is CPU"
)
x.realize()
op_overload = (
aten.bernoulli_.float
@ -2101,9 +2102,9 @@ def bernoulli_(x, *args):
@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args):
assert config.fallback_random or x.get_device() == torch.device(
"cpu"
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"this should be handled in decomps unless config.fallback_random or the device is CPU"
)
return bernoulli_(clone(x), *args)
@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device):
i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
for i in indices
if i is not None
), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
), (
f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
)
if any(
i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
):
@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
)
result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
if isinstance(
result.data.data, Reduction # type: ignore[attr-defined]
result.data.data, # type: ignore[attr-defined]
Reduction,
): # Only realize if reduction isn't unrolled
result.realize()
return result
@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
return None
handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
with V.set_ops_handler(handler), patch.object(
ir.FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(handler),
patch.object(ir.FlexibleLayout, "allow_indexing", True),
):
out = x.inner_fn(*x.inner_fn_args())
@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload):
A context manager to force fallback an op. Used in unit test
for FallbackKernel.
"""
assert isinstance(
op, torch._ops.OpOverload
), "Only OpOverload to make the clean up easier"
assert isinstance(op, torch._ops.OpOverload), (
"Only OpOverload to make the clean up easier"
)
old_handler = lowerings.get(op)
try:
register_lowering(op)(fallback_handler(op))


@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer:
class MemoryPlanningInfoForNode:
index: int = 0
size: int = 0
pred_buffers: OrderedSet[
Union[SchedulerBuffer, FreeableInputBuffer]
] = dataclasses.field(default_factory=OrderedSet)
pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
dataclasses.field(default_factory=OrderedSet)
)
pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
default_factory=OrderedSet
)
@ -87,9 +87,9 @@ def get_freeable_input_buf(
# get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
collections.defaultdict(OrderedSet)
)
dep_name_to_size: dict[str, int] = dict()
for node in nodes:
for dep in node.read_writes.reads:
@ -112,7 +112,7 @@ def get_freeable_input_buf(
def compute_size_for_scheduler_buffer(
name_to_buf: dict[str, SchedulerBuffer]
name_to_buf: dict[str, SchedulerBuffer],
) -> dict[str, tuple[int, int]]:
"""
Compute the size of each scheduler buffer, including (1) memory allocated when
@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers(
# get buffer's successor nodes
# note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[
str, OrderedSet[BaseSchedulerNode]
] = collections.defaultdict(OrderedSet)
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
collections.defaultdict(OrderedSet)
)
for node in nodes:
for dep in node.unmet_dependencies:
dep_name_to_succ_nodes[dep.name].add(node)


@ -138,12 +138,12 @@ class MetricTable:
return
row_dict = row_fn()
assert len(self.column_names) == len(
row_dict
), f"{len(self.column_names)} v.s. {len(row_dict)}"
assert OrderedSet(self.column_names) == OrderedSet(
row_dict.keys()
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
assert len(self.column_names) == len(row_dict), (
f"{len(self.column_names)} v.s. {len(row_dict)}"
)
assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
)
bn = get_benchmark_name()
# assert bn is not None
@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
name = name.strip()
if not name:
continue
assert (
name in REGISTERED_METRIC_TABLES
), f"Metric table name {name} is not registered"
assert name in REGISTERED_METRIC_TABLES, (
f"Metric table name {name} is not registered"
)
enabled.add(name)
return enabled


@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
unary_algorithm,
]
assert (
binary_attr == "sum"
), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
assert binary_attr == "sum", (
"For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
)
V.graph.mark_buffer_mutated(qaccum.get_name())
packed = QConvPointWiseBinaryPT2E(


@ -575,9 +575,9 @@ def register_onednn_fusion_ops():
algorithm,
layout=None,
):
assert (
packed_weight.get_dtype() is torch.int8
), "Only int8 weights are supported by oneDNN qlinear."
assert packed_weight.get_dtype() is torch.int8, (
"Only int8 weights are supported by oneDNN qlinear."
)
x_size = x.get_size()
if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here
@ -928,9 +928,9 @@ def register_onednn_fusion_ops():
# we will do accum dtype convertion here.
x2 = to_dtype(x2, output_dtype)
else:
assert (
x2.get_dtype() == output_dtype
), "dtype of accum for qlinear post op sum should be the same as output"
assert x2.get_dtype() == output_dtype, (
"dtype of accum for qlinear post op sum should be the same as output"
)
x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None
choices: list[ChoiceCaller] = []


@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]):
assert self_arg == "self"
code.write(
f"""
def {target}(self, {', '.join(args)}):
return self._default({target!r}, ({', '.join(args)}, ), {{}})
def {target}(self, {", ".join(args)}):
return self._default({target!r}, ({", ".join(args)}, ), {{}})
""".strip()
)
code.write("\n\n")
@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler):
)
formatter._output.writeline(f"{lhs} = {name}")
with V.set_ops_handler(formatter), patch.object(
FlexibleLayout, "allow_indexing", True
with (
V.set_ops_handler(formatter),
patch.object(FlexibleLayout, "allow_indexing", True),
):
result = ir_fn(*args)
return formatter.getvalue(result)


@ -188,7 +188,9 @@ def package_aoti(
) or (
isinstance(archive_file, (str, os.PathLike))
and os.fspath(archive_file).endswith(".pt2")
), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
), (
f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
)
# Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
@ -285,9 +287,9 @@ class AOTICompiledModel:
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
assert (
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
) or (
isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
)
if isinstance(path, (io.IOBase, IO)):
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:


@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node]
class SearchFn(Protocol):
__name__: str
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
class ReplaceFn(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
class TraceFn(Protocol):
def __call__(
self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
) -> torch.fx.GraphModule:
...
) -> torch.fx.GraphModule: ...
T = TypeVar("T")
@ -365,8 +362,7 @@ class PatternExpr(ABC):
"""
@abstractmethod
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
...
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ...
def match(self, node: torch.fx.Node) -> MatchResult:
try:
@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr):
@property
@abstractmethod
def op(self) -> str:
...
def op(self) -> str: ...
def fns_repr(self) -> str:
first_repr = self.fns[0]
@ -997,8 +992,9 @@ class PatternPrettyPrinter:
class _PassDictsType(Protocol):
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
...
def __getitem__(
self, k: tuple[str, torch.fx.node.Target]
) -> list[PatternEntry]: ...
@dataclasses.dataclass
@ -1925,7 +1921,10 @@ def fx_to_pattern(
get_attr = _not_implemented
def placeholder(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Union[ExclusiveKeywordArg, KeywordArg]:
n = next(argnum)
if n < len(argnames):
@ -1942,7 +1941,10 @@ def fx_to_pattern(
return KeywordArg(name)
def call_function(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> PatternExpr:
process_arg_fn = process_arg
# Indexing is critical for matching getitem nodes, so we can't ignore int args here


@ -24,7 +24,7 @@ T = TypeVar("T")
def time_and_count(
fn: Callable[Concatenate[Any, P], T]
fn: Callable[Concatenate[Any, P], T],
) -> Callable[Concatenate[Any, P], T]:
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
counters. It is expected that `fn` is a method of `Benchmarker` or one of its


@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None:
# right now, if a pre-hook is attached to the config, it will not be saved;
# and then it won't be used when the config is loaded from cache.
# So we assert - if we do get a pre_hook, it might get ignored after caching.
assert (
getattr(cfg, "pre_hook", None) is None
), "triton configs with pre_hooks not supported"
assert getattr(cfg, "pre_hook", None) is None, (
"triton configs with pre_hooks not supported"
)
def create_bandwidth_info_str(


@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface):
self.launchers = []
def __getstate__(self) -> dict[str, Any]:
assert (
not self.launchers
), "pickle should not be called with after make_launchers()"
assert not self.launchers, (
"pickle should not be called with after make_launchers()"
)
return {
**self.__dict__,
"lock": None,
@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance(
arg,
torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names"
), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_()
for name, arg in kwargs.items():
@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance(
arg,
torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names"
), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_()
def maybe_clone_args(
@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface):
assert not (
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
and "R0_BLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
), (
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
)
start_time = time.time_ns()
best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None
@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface):
)
return config2launcher.get(best_config)
def run(
self, *args, grid, stream, benchmark_run=False, **kwargs
): # type:ignore[override]
def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
if self.triton_interpret:
return self.fn[grid](
*args,
@ -1192,12 +1196,12 @@ class TritonCompileResult:
exec(
f"""
def launcher({', '.join(def_args)}, grid, stream):
def launcher({", ".join(def_args)}, grid, stream):
if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta)
else:
grid_0, grid_1, grid_2 = grid
runner({', '.join(runner_args)})
runner({", ".join(runner_args)})
return bin
""".lstrip(),
scope,
@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]):
if block_suffix in var:
prefix = var.removesuffix(block_suffix)
max_block = TRITON_MAX_BLOCK[prefix]
assert (
val <= max_block
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
assert val <= max_block, (
f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
)
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, in
prefix = f"r{idx}_"
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
dim = min(max_size, remaining)
assert (
remaining % dim == 0
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
assert remaining % dim == 0, (
f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
)
rnumels[prefix] = dim
remaining //= dim
# Sanity check the results.
final_numel = conditional_product(*rnumels.values())
assert (
r == final_numel
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
assert all(
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
assert r == final_numel, (
f"Expected ND reduction size ({rnumels}) to have {r} elements."
)
assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
f"rnumels exceed size_hints. {rnumels} > {size_hints}"
)
return rnumels
@ -1967,9 +1971,9 @@ def cooperative_reduction(
size_hints["x"] = 1
# Cooperative reductions currently only support a single reduction dimension.
assert (
len(size_hints) == 2
), "Cooperative reductions don't support tiling reduction dims"
assert len(size_hints) == 2, (
"Cooperative reductions don't support tiling reduction dims"
)
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
# TODO(jansel): we should base target on the SM count of the local GPU
@ -2274,9 +2278,9 @@ def grid_combo_kernels(
assert min_blocks_d is not None
min_blocks = min_blocks_d
else:
assert (
min_blocks_d is None or min_blocks == min_blocks_d
), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
assert min_blocks_d is None or min_blocks == min_blocks_d, (
f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
)
else:
# sequential dispatch
seq_numels = list(numels)


@ -200,9 +200,9 @@ class BaseSchedulerNode:
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[
[BaseSchedulerNode], list[str]
] = lambda *args, **kwargs: []
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)
def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node
@ -232,7 +232,7 @@ class BaseSchedulerNode:
buf = IndentedBuffer()
buf.splice(
f"""\
{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})
{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
{name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
@ -525,9 +525,9 @@ class BaseSchedulerNode:
V.kernel.mutations.add(input_buf.get_name())
V.kernel.mutations.add(buf.get_name())
V.kernel.inplace_update_buffers[
buf.get_name()
] = input_buf.get_name()
V.kernel.inplace_update_buffers[buf.get_name()] = (
input_buf.get_name()
)
break
def codegen_originating_info(
@ -693,7 +693,7 @@ class BaseSchedulerNode:
continue
def get_buf_bytes(
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]]
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
) -> int:
if not buf:
return 0
@ -794,12 +794,11 @@ class BaseSchedulerNode:
# runtime for that today
return 0
with FakeTensorMode() as fake_mode, FlopCounterMode(
display=False
) as flop_counter_mode, V.set_current_node(
self.node.fx_node
), V.set_fake_mode(
fake_mode
with (
FakeTensorMode() as fake_mode,
FlopCounterMode(display=False) as flop_counter_mode,
V.set_current_node(self.node.fx_node),
V.set_fake_mode(fake_mode),
):
from .ir import ir_node_to_tensor
@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode):
return self._sizes
def is_reduction(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return bool(self.node.get_reduction_type())
def is_split_scan(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return isinstance(self.node, ir.ComputedBuffer) and isinstance(
self.node.data, ir.SplitScan
)
@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode):
def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
var_ranges = self.ranges_from_index_vars(index_vars)
try:
with V.set_ops_handler(
SimplifyIndexing(V.get_ops_handler(), var_ranges)
), V.kernel.set_current_node(self):
with (
V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
V.kernel.set_current_node(self),
):
self._body(*index_vars)
except Exception:
log.fatal("Error in codegen for %s", self.node)
@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode):
def refresh_group_node_dependencies(
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode]
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
) -> None:
snodes = group_snode.snodes
group_snode.set_read_writes(
@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@staticmethod
def set_group_algorithm_for_combo_kernels(
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
) -> None:
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
custom_group_algorithm
@ -1975,9 +1975,9 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
self.get_donated_buffers()
)
self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}
@ -2099,9 +2099,9 @@ class Scheduler:
node.log_details()
def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
assert (
node.get_origins() is not None
), "All nodes passed to scheduling must have an origin"
assert node.get_origins() is not None, (
"All nodes passed to scheduling must have an origin"
)
if node.is_no_op():
return NopKernelSchedulerNode(self, node)
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
@ -2260,9 +2260,9 @@ class Scheduler:
)
# if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses:
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node}"
)
if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name()))
@ -2310,9 +2310,9 @@ class Scheduler:
for alt_name in buf.get_mutations():
self.mutation_renames[rename(alt_name)] = buf.get_name()
self.mutation_renames[alt_name] = buf.get_name()
self.mutation_real_name[
buf.get_name()
] = self.mutation_real_name.get(alt_name, alt_name)
self.mutation_real_name[buf.get_name()] = (
self.mutation_real_name.get(alt_name, alt_name)
)
# make sure outputs aren't dead-code-eliminated
for buf_name in V.graph.get_output_names():
@ -2322,9 +2322,9 @@ class Scheduler:
# make sure unbacked symints aren't dead-code-eliminated
for out in V.graph.graph_outputs:
for s in out.get_unbacked_symbol_uses():
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
)
if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names():
log.debug(
@ -3304,15 +3304,15 @@ class Scheduler:
rhs_dep = node2_name2dep[buf_name]
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
reasons[
buf_name
] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
reasons[buf_name] = (
f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
)
continue
if lhs_dep.get_numel() != rhs_dep.get_numel():
reasons[
buf_name
] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
reasons[buf_name] = (
f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
)
continue
# same numel but different MemoryDep.size. Should be broadcasting
@ -3340,9 +3340,9 @@ class Scheduler:
layout_str = ""
if not isinstance(buf, ir.TorchBindObject):
layout_str = f"Layout: {buf.layout}"
reasons[
buf_name
] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
reasons[buf_name] = (
f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
)
return str(reasons)
@ -3903,9 +3903,9 @@ class Scheduler:
self.free_buffers()
def create_backend(self, device: torch.device) -> BaseScheduling:
assert (
not is_gpu(device.type) or device.index is not None
), f"{device} should have been normalized in lowering"
assert not is_gpu(device.type) or device.index is not None, (
f"{device} should have been normalized in lowering"
)
V.graph.add_device_info(device)
device_scheduling = get_scheduling_for_device(device.type)
@ -4135,9 +4135,9 @@ class Scheduler:
partitions, signatures = self.graph_partition()
for partition, signature in zip(partitions, signatures):
assert (
len(partition) >= 1
), f"Each partition must have at least one node but found {len(partition)}"
assert len(partition) >= 1, (
f"Each partition must have at least one node but found {len(partition)}"
)
if signature.skip_cudagraph:
self._codegen(partition)


@ -168,9 +168,9 @@ class PartialRender:
)
else:
return
assert (
self.replacement_hooks[hook_key] is not None
), "hook_key can only be called once"
assert self.replacement_hooks[hook_key] is not None, (
"hook_key can only be called once"
)
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
self.replacement_hooks[hook_key] = None
@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
This is used by flex_attention's backwards grad for captured buffers, see
zeros_and_scatter lowering
"""
assert (
self.mask is not None
), "Mask is required for inner stores in modifications"
assert self.mask is not None, (
"Mask is required for inner stores in modifications"
)
assert mode == "atomic_add", "Only atomic_add is supported for inner stores"
buf_name = self._add_kernel_input(name)
@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel):
def _get_subgraph(self, subgraph_number: int):
assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list)
assert subgraph_number < len(
self.subgraphs
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
assert (
self.body.getvalue() == ""
), "Body should be clear before adding a modification"
assert subgraph_number < len(self.subgraphs), (
f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
)
assert self.body.getvalue() == "", (
"Body should be clear before adding a modification"
)
return self.subgraphs[subgraph_number]
def _handle_scatter_graph(self, scatter_graph):
@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel):
Args:
scatter_graph: The scatter graph to process
"""
assert isinstance(
scatter_graph, ir.ComputedBuffer
), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
assert isinstance(scatter_graph, ir.ComputedBuffer), (
f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
)
def contiguous_strides(x):
# We always create a fresh contiguous grad for scattering into
@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel):
x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
)
return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined]
return scatter_graph.data.store_output( # type: ignore[attr-defined]
scatter_graph.name, contiguous_strides, []
)
def modification(
self,
@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel):
self, subgraph_number, fixed_inputs, mask
)
with V.set_ops_handler(modification_handler):
assert isinstance(
subgraph, (ir.ComputedBuffer, list)
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
assert isinstance(subgraph, (ir.ComputedBuffer, list)), (
f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
)
# Handle scatter stores
if isinstance(subgraph, list):
for scatter_graph in subgraph:
@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate):
"subgraphs": subgraphs,
}
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(fake_out)
), V.graph.set_current_device(layout.device), TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel:
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
V.graph.set_current_device(layout.device),
TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel,
):
try:
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):
@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller):
def output_node(self):
if self.choice.use_fallback_kernel:
assert (
self.choice.op_overload is not None
), "Please provide an op_overload to use ir.FallbackKernel"
assert self.choice.op_overload is not None, (
"Please provide an op_overload to use ir.FallbackKernel"
)
inner = ir.FallbackKernel.create(
self.choice.op_overload, *self.input_nodes, **self.kwargs
)
@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {}
def get_inputs(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> AutotuneArgs:
# de-duplicate args
unique_example_inputs = {
@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings
def benchmark_in_sub_process(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
):
from . import autotune_process
@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache):
map(
str,
V.graph.sizevars.size_hints(
n.get_size(), fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
),
)
)
@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs):
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
if "return_multi_template" not in kwargs:
kwargs[
"return_multi_template"
] = torch._inductor.config.benchmark_epilogue_fusion
kwargs["return_multi_template"] = (
torch._inductor.config.benchmark_epilogue_fusion
)
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
def add_feedback_saver(
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:


@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner, var_ranges: VarRanges) -> None:
super().__init__(inner)
self.name = "SimplifyIndexing"
self._simplify: Callable[
[Expr], Expr
] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
self._simplify: Callable[[Expr], Expr] = (
lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
)
def load(self, name: str, index: sympy.Expr):
return self._inner.load(name, self._simplify(index))


@ -283,9 +283,9 @@ def ceildiv(
# TODO: There is a bug in a call to this function, to repro:
# python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
# --amp --only YituTechConvBert --dynamic-shapes
assert isinstance(numer, int) and isinstance(
denom, int
), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
assert isinstance(numer, int) and isinstance(denom, int), (
f"{numer}: {type(numer)}, {denom}: {type(denom)}"
)
return runtime_ceildiv(numer, denom)
@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:
def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]]
lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]:
"""
Gets the shape and stride of a tensor. For non-symbolic tensors, this is
@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True)
class CachedMethod(Protocol, Generic[P, RV]):
@staticmethod
def clear_cache(cache: Any) -> None:
...
def clear_cache(cache: Any) -> None: ...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str:
@functools.lru_cache(None)
def try_import_ck_lib() -> (
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
):
def try_import_ck_lib() -> tuple[
Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
]:
try:
import ck4inductor # type: ignore[import]
from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]
@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
return DummyModule()
with mock.patch.object(
GraphLowering, "compile_to_module", patched_compile_to_module
), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
with (
mock.patch.object(
GraphLowering, "compile_to_module", patched_compile_to_module
),
mock.patch.object(GraphLowering, "save_output_code", save_output_code),
):
torch._dynamo.reset()
# Note the return here is None
_ = fn(*args, **kwargs)
@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
source_codes = get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]
def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
_, source_codes = run_and_get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]
@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
assert isinstance(
val, sympy.Expr
), "only support sympy.Expr as input to get_sympy_Expr_dtype"
assert isinstance(val, sympy.Expr), (
"only support sympy.Expr as input to get_sympy_Expr_dtype"
)
if val.is_integer: # type: ignore[attr-defined]
return torch.int64
else:
@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) ->
def is_output_of_multi_outputs_template(
input_buf: Optional[Union[Buffer, Operation]]
input_buf: Optional[Union[Buffer, Operation]],
) -> bool:
"""
Check if input buffer is a output of multi-outputs template buffer
@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing(
if node not in (EnableReduction, DisableReduction):
if node.node is not None:
V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [
origin.name for origin in node.node.origins # type: ignore[attr-defined]
origin.name
for origin in node.node.origins # type: ignore[attr-defined]
]


@ -314,9 +314,9 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler
set_ops_handler: Callable[
[OpsHandler[Any]], AbstractContextManager[None]
] = _ops._set_handler
set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = (
_ops._set_handler
)
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler


@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes
class BenchmarkCallableType(Protocol):
def __call__(self, times: int, repeat: int) -> float:
...
def __call__(self, times: int, repeat: int) -> float: ...
_kernel_category_choices = [
@ -138,9 +137,9 @@ def benchmark_all_kernels(
)
else:
ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
assert (
len(triton_kernel.launchers) == 1
), "Autotuner should have selected the best config"
assert len(triton_kernel.launchers) == 1, (
"Autotuner should have selected the best config"
)
launcher = triton_kernel.launchers[0]
print(
get_info_str(
@ -256,9 +255,9 @@ def parse_profile_event_list(
"triton_unknown",
"unknown",
]
assert OrderedSet(all_events.keys()).issubset(
OrderedSet(category_list)
), f"{list(all_events.keys())}"
assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
f"{list(all_events.keys())}"
)
per_category_wall_time = {}
total_device_ms = 0.0