[BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel
Authored by Xuehai Pan on 2025-02-28 15:35:13 +08:00; committed by PyTorch MergeBot
commit 1cb4e2df65
parent 34d726011f
88 changed files with 1157 additions and 954 deletions
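
Most of the churn below comes from one stylistic difference between Black and ruff format: ruff keeps the asserted condition on the `assert` line and parenthesizes only the message. A minimal sketch of the before/after shape (the names here are placeholders, not code from this PR):

    value = 42

    # Black style (before): the condition is wrapped in parentheses
    assert (
        value is not None
    ), f"expected a value, got {value}"

    # ruff format style (after): the condition stays inline, only the message is parenthesized
    assert value is not None, (
        f"expected a value, got {value}"
    )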


@@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile(
         # torch/_[e-h]*/**
         "torch/_[e-h]*/**",
         # torch/_i*/**
-        "torch/_i*/**",
         # torch/_[j-z]*/**
         "torch/_[j-z]*/**",
         # torch/[a-c]*/**
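
Dropping "torch/_i*/**" from USE_BLACK_FILELIST means paths such as torch/_inductor/** no longer match any Black pattern and fall through to ruff format. A rough, self-contained approximation of that filtering with fnmatch (the real linter compiles the globs into a regex; the path below is only illustrative):

    from fnmatch import fnmatch

    # Globs still on the Black filelist after this change ("torch/_i*/**" removed).
    black_globs = ["torch/_[e-h]*/**", "torch/_[j-z]*/**"]

    path = "torch/_inductor/codecache.py"
    formatted_by_black = any(fnmatch(path, glob) for glob in black_globs)
    print(formatted_by_black)  # False -> torch/_inductor is now formatted by ruff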


@@ -66,7 +66,9 @@ def aoti_compile_and_package(
     .. code-block:: python

         ep = torch.export.export(M(), ...)
-        aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2")
+        aoti_file = torch._inductor.aoti_compile_and_package(
+            ep, package_path="my_package.pt2"
+        )
         compiled_model = torch._inductor.aoti_load_package("my_package.pt2")

     To compile and save multiple models into a single ``.pt2`` artifact, you can do

@@ -75,11 +77,16 @@ def aoti_compile_and_package(
     .. code-block:: python

         ep1 = torch.export.export(M1(), ...)
-        aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True})
+        aoti_file1 = torch._inductor.aot_compile(
+            ep1, ..., options={"aot_inductor.package": True}
+        )
         ep2 = torch.export.export(M2(), ...)
-        aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True})
+        aoti_file2 = torch._inductor.aot_compile(
+            ep2, ..., options={"aot_inductor.package": True}
+        )

         from torch._inductor.package import package_aoti, load_package

         package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2})
         compiled_model1 = load_package("my_package.pt2", "model1")

@@ -123,7 +130,9 @@ def aoti_compile_and_package(
             isinstance(package_path, (str, os.PathLike))
             and os.fspath(package_path).endswith(".pt2")
         )
-    ), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
+    ), (
+        f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
+    )

     inductor_configs = inductor_configs or {}
     inductor_configs["aot_inductor.package"] = True

@@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner(
     """
     if check_accuracy:
-        assert (
-            kwargs is None or len(kwargs) == 0
-        ), "when checking for accuracy, the inputs must have been flattened and kwargs is None"
+        assert kwargs is None or len(kwargs) == 0, (
+            "when checking for accuracy, the inputs must have been flattened and kwargs is None"
+        )

     from .package import package_aoti


@@ -156,8 +156,9 @@ def can_codegen_without_upcasts(
     low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops)

     # Need to turn off upcasting to do analysis of whether we can turn it off
-    with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler(
-        low_prec_analysis
+    with (
+        config.patch("triton.codegen_upcast_to_fp32", False),
+        V.set_ops_handler(low_prec_analysis),
     ):
         prologue._body(*prologue.get_ranges())
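
The hunk above shows another recurring rewrite in this commit: ruff format emits the parenthesized multi-context-manager form of `with` instead of wrapping one manager's argument list. A small self-contained sketch of that form with stand-in context managers (the names are made up for illustration):

    from contextlib import contextmanager

    @contextmanager
    def patch_config(name, value):
        print(f"patch {name}={value}")
        yield

    @contextmanager
    def set_ops_handler(handler):
        print(f"use handler {handler}")
        yield

    # Note: the parenthesized with-statement form needs a recent CPython (3.9+).
    with (
        patch_config("triton.codegen_upcast_to_fp32", False),
        set_ops_handler("low_prec_analysis"),
    ):
        print("run analysis body")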


@@ -245,8 +245,7 @@ class AsyncCompile:
     def use_process_pool(self):
         return (
-            get_compile_threads() > 1
-            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
+            get_compile_threads() > 1 and self.process_pool().ready_future.done()  # type: ignore[union-attr]
         )

     def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):


@@ -24,8 +24,7 @@ if TYPE_CHECKING:
     class Sortable(typing.Protocol):
         """Anything that can be used as a list.sort() key (int/tuple/etc)"""

-        def __lt__(self, other: typing.Self) -> bool:
-            ...
+        def __lt__(self, other: typing.Self) -> bool: ...


 class InductorChoices:

@@ -100,7 +99,9 @@ class InductorChoices:
         # to pick the faster one.
         if config.triton.multi_kernel:
             threshold *= 16
-        return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold)  # type: ignore[arg-types]
+        return V.graph.sizevars.statically_known_leq(
+            features.reduction_numel, threshold
+        )  # type: ignore[arg-types]

     @staticmethod
     def want_no_x_dim(features: SIMDKernelFeatures) -> bool:


@@ -417,9 +417,9 @@ def write_atomic(
 ) -> None:
     # Write into temporary file first to avoid conflicts between threads
     # Avoid using a named temporary file, as those have restricted permissions
-    assert isinstance(
-        content, (str, bytes)
-    ), "Only strings and byte arrays can be saved in the cache"
+    assert isinstance(content, (str, bytes)), (
+        "Only strings and byte arrays can be saved in the cache"
+    )
     path = Path(path_)
     if make_dirs:
         path.parent.mkdir(parents=True, exist_ok=True)

@@ -975,9 +975,9 @@ class FxGraphCache:
         symints = FxGraphCache._filter_backed_symints(example_inputs)
         hints = [hint_int(s) for s in symints]

-        def iterate_over_candidates() -> (
-            Generator[tuple[CompiledFxGraph, bytes], None, None]
-        ):
+        def iterate_over_candidates() -> Generator[
+            tuple[CompiledFxGraph, bytes], None, None
+        ]:
             if local:
                 subdir = FxGraphCache._get_tmp_dir_for_key(key)
                 if os.path.exists(subdir):

@@ -1123,9 +1123,9 @@ class FxGraphCache:
         """
         from .compile_fx import CompiledFxGraph

-        assert isinstance(
-            compiled_graph, CompiledFxGraph
-        ), f"serialization for {type(compiled_graph)} NYI"
+        assert isinstance(compiled_graph, CompiledFxGraph), (
+            f"serialization for {type(compiled_graph)} NYI"
+        )
         disk_compiled_graph = copy(compiled_graph)
         disk_compiled_graph.prepare_for_serialization()

@@ -1315,9 +1315,8 @@ class FxGraphCache:
                 "distributed_ephemeral_timeout_us", time_saved_ns // 1000
             )
             if (
-                ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
-                    time_saved_ns
-                )
+                ephemeral_increase
+                := add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
             ) != 0:
                 cache_info["ephemeral_timeout_increase"] = ephemeral_increase
         else:
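
The rewrite just above only moves the line break inside a walrus expression; within parentheses the assignment expression binds exactly the same way. A tiny sketch with a stand-in helper (the function body here is invented purely for illustration):

    def add_ephemeral_timeout_increase_for_distributed(time_saved_ns):
        # stand-in for the real helper; just returns something nonzero
        return time_saved_ns // 1000

    time_saved_ns = 5_000_000
    if (
        ephemeral_increase
        := add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
    ) != 0:
        print(ephemeral_increase)  # 5000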
@@ -1556,9 +1555,9 @@ class AotCodeCompiler:
                 cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json")
             )
             for k, v in config.aot_inductor.metadata.items():
-                assert isinstance(k, str) and isinstance(
-                    v, (str)
-                ), "Metadata must only contain strings"
+                assert isinstance(k, str) and isinstance(v, (str)), (
+                    "Metadata must only contain strings"
+                )

             with open(meta_json, "w") as f:
                 f.write(json.dumps(config.aot_inductor.metadata))


@@ -341,7 +341,7 @@ class BackendFeature(Enum):
 def get_backend_features(
-    device: Union[torch.device, str, None]
+    device: Union[torch.device, str, None],
 ) -> OrderedSet[BackendFeature]:
     if device is None:
         return OrderedSet()

@@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
             if cls._is_unimplemented(funcname):
                 setattr(cls, funcname, cls._unimplemented(funcname))
             else:
-                assert (
-                    funcname not in cls.__dict__
-                ), f"multiple definitions of {funcname} on {cls.__name__}"
+                assert funcname not in cls.__dict__, (
+                    f"multiple definitions of {funcname} on {cls.__name__}"
+                )
                 impl.__name__ = funcname
                 setattr(cls, funcname, staticmethod(impl))

@@ -2229,7 +2229,7 @@ class KernelTemplate:
     @staticmethod
     def _fake_get_dtype(
-        fake_outs: Union[list[Buffer], Buffer]
+        fake_outs: Union[list[Buffer], Buffer],
     ) -> Callable[[str], torch.dtype]:
         _get_dtype_real = V.graph.get_dtype
         if isinstance(fake_outs, (list, tuple)):


@@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
         outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
         outer_loop_fusion_depth,
     ):
-        self.outer_fused_nodes: list[
-            Union[FusedSchedulerNode, SchedulerNode]
-        ] = outer_fused_nodes
+        self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
+            outer_fused_nodes
+        )
         self.outer_loop_fusion_depth = outer_loop_fusion_depth
         flatten_snodes = []
         for _node in self.outer_fused_nodes:

@@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides):
     @staticmethod
     def remainder(a, b):
-        assert (
-            a.dtype == b.dtype
-        ), "remainder vec implementation expect the same inputs' dtype."
+        assert a.dtype == b.dtype, (
+            "remainder vec implementation expect the same inputs' dtype."
+        )
         return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"

     @staticmethod

@@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides):
     @staticmethod
     def floordiv(a, b):
         if is_float_dtype(a.dtype):
-            assert (
-                a.dtype == b.dtype
-            ), "div_floor_floating_vec implementation expect the same inputs' dtype."
+            assert a.dtype == b.dtype, (
+                "div_floor_floating_vec implementation expect the same inputs' dtype."
+            )
             return f"div_floor_floating_vec({a}, {b})"
         else:
             assert all(is_integer_dtype(item.dtype) for item in [a, b])

@@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides):
             assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
             body_vec_var.dtype = dtype
             other_vec_var.dtype = dtype
-            overrides: type[
-                Union[CppOverrides, CppVecOverrides]
-            ] = V.kernel.overrides  # type: ignore[has-type]
+            overrides: type[Union[CppOverrides, CppVecOverrides]] = (
+                V.kernel.overrides
+            )  # type: ignore[has-type]
             code.writeline(
                 f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
             )

@@ -2108,9 +2108,9 @@ class CppKernel(Kernel):
     def set_ranges(self, lengths, reduction_lengths):
         if self.call_ranges:
-            assert self.call_ranges == tuple(lengths) + tuple(
-                reduction_lengths
-            ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+            assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
+                f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+            )
             assert self.reduction_depth == len(lengths)
         else:
             self.call_ranges = tuple(lengths) + tuple(reduction_lengths)

@@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel):
             self.weight_recps_val = self.weight_recps_cse.generate(
                 self.compute, f"reduction {self.weight_recp_vec_range}", write=False
             )
-            self.weight_recps_cse.reduction_cache[
-                self.weight_recp_vec_range
-            ] = self.weight_recps_val
+            self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = (
+                self.weight_recps_val
+            )
             self.non_parallel_reduction_prefix.writeline(
                 self.welford_weight_reciprocal_vec(dtype)
             )

@@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling):
         ]
         counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
-        assert self.is_cpp_template(
-            template_node
-        ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        assert self.is_cpp_template(template_node), (
+            "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        )
         template_node = cast(SchedulerNode, template_node)
         _, (_, rnumel) = template_node.group
         assert rnumel == ()

@@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling):
         epilogue_ir_nodes: list[Optional[ir.Operation]] = [
             n.node for n in epilogue_nodes
         ]
-        assert all(
-            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
-        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )

         def template_buffer_has_other_users(
             template_buffer, outputs_by_name, epilogue_nodes

@@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling):
             if is_multi_outputs_template(template_node.node):
                 # For multi outputs template, allocate buffers for each output after the epilogue
                 # codegen to which determines if the buffer has been removed.
-                assert (
-                    len(template_node.outputs) == 1
-                ), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+                assert len(template_node.outputs) == 1, (
+                    "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+                )
                 for user in template_node.outputs[0].users:
-                    assert isinstance(
-                        user.node, ExternKernelSchedulerNode
-                    ), "Multi outputs template should be with ExternKernelSchedulerNode"
-                    assert isinstance(
-                        user.node.node, ir.MultiOutput
-                    ), "Multi outputs template has multi users with MultiOutput"
+                    assert isinstance(user.node, ExternKernelSchedulerNode), (
+                        "Multi outputs template should be with ExternKernelSchedulerNode"
+                    )
+                    assert isinstance(user.node.node, ir.MultiOutput), (
+                        "Multi outputs template has multi users with MultiOutput"
+                    )
                     user.node.mark_run()
             kernel.call_kernel(kernel_name, ctb)

@@ -5347,9 +5347,9 @@ class LoopNest:
         return self.loops is not None and self.loops[0].is_reduction

     def mark_parallel(self, par_depth):
-        assert (
-            par_depth <= self.max_parallel_depth()
-        ), "Parallel depth cannot exceed the maximal allowed parallel depth"
+        assert par_depth <= self.max_parallel_depth(), (
+            "Parallel depth cannot exceed the maximal allowed parallel depth"
+        )
         assert self.loops is not None
         assert len(self.loops) >= par_depth
         loop = self.loops[0]


@@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate):
             assert all(
                 mem.buffer_name in kernel_group.args.input_buffers
                 for mem in body.memory_usage[MemoryUsageType.LOAD]
-            ), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
+            ), (
+                "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
+            )
             bodies.append(body)
             var_sizes_list.append((var_sizes, ()))


@@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate):
             thread_block_m = math.ceil(m_blocks / m_factor)
             return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

-        assert (
-            not self.is_dynamic_M
-        ), "Unable to determine thread blocking for dynamic M."
+        assert not self.is_dynamic_M, (
+            "Unable to determine thread blocking for dynamic M."
+        )
         register_blocking = self.register_blocking
         m_blocks = math.ceil(self.m / register_blocking.block_m)
         n_blocks = math.ceil(self.n / register_blocking.block_n)

@@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate):
             L1_cache_size = (
                 torch._C._cpu._L1d_cache_size()
             )  # per core cache size in Bytes
-            assert (
-                L1_cache_size > 0
-            ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+            assert L1_cache_size > 0, (
+                f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+            )
             L1 = L1_cache_size * L1_limit_factor

             L2_cache_size = (
                 torch._C._cpu._L2_cache_size()
             )  # per core cache size in Bytes
-            assert (
-                L2_cache_size > 0
-            ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+            assert L2_cache_size > 0, (
+                f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+            )
             L2 = L2_cache_size * L2_limit_factor

             def get_num_byte(dtype):

@@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate):
             return Mc_blocks, Nc_blocks, Kc_blocks

-        assert (
-            not self.is_dynamic_M
-        ), "Unable to determine cache blocking for dynamic M."
+        assert not self.is_dynamic_M, (
+            "Unable to determine cache blocking for dynamic M."
+        )
         register_blocking = self.register_blocking
         thread_blocking = self.thread_blocking(num_threads)

@@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate):
                 LayoutType.VNNI4,
             ], f"We only support {layout_str} for now"
             vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
-            assert (
-                k % vnni_size == 0
-            ), f"k should be divisible by vnni_size for {layout_str} layout"
+            assert k % vnni_size == 0, (
+                f"k should be divisible by vnni_size for {layout_str} layout"
+            )
             vnni_view_size = list(new_size)
             vnni_view_size[-2] = k // vnni_size
             vnni_view_size.insert(-1, vnni_size)


@@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                 for W_node in W_nodes:
                     assert W_node.get_name() in V.graph.constants
                     W_tensor.append(V.graph.constants[W_node.get_name()])
-                new_input_nodes[
-                    wgt_start_idx : wgt_start_idx + gemm_grouped_num
-                ] = W_tensor  # type: ignore[assignment]
+                new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
+                    W_tensor  # type: ignore[assignment]
+                )
                 new_input_nodes, _ = pack_weight(
                     *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
                 )

@@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                     W_packed = new_input_nodes[idx]
                     assert isinstance(W_packed, torch.Tensor)
                     W_packed_constant = V.graph.add_tensor_constant(W_packed)
-                    template_buffer.inputs[
-                        idx
-                    ] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
+                    template_buffer.inputs[idx] = (
+                        ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
+                    )
             return output

         template = DataProcessorTemplateWrapper(

@@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                 ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
             )

-        assert (
-            not self.epilogue_creator
-        ), "epilogue_creator is not supported yet in Grouped GEMM Template"
+        assert not self.epilogue_creator, (
+            "epilogue_creator is not supported yet in Grouped GEMM Template"
+        )
         kernel_args: dict[str, Optional[ir.IRNode]] = {}
         for x_idx in range(wgt_start_idx):


@@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
 def register_micro_gemm(*configs):
     def inner(cls):
-        assert (
-            cls not in micro_gemm_configs
-        ), f"Duplicate micro_gemm registration for {cls}"
+        assert cls not in micro_gemm_configs, (
+            f"Duplicate micro_gemm registration for {cls}"
+        )
         assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
         micro_gemm_configs[cls] = list(configs)
         return cls


@@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate):
     def generate(self, **kwargs):
         kernel_name = f"cpp_{self.name}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
-            kernel_name=kernel_name, num_threads=self.num_threads
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            patch.object(ir.FlexibleLayout, "allow_indexing", True),
+            CppTemplateKernel(
+                kernel_name=kernel_name, num_threads=self.num_threads
+            ) as kernel,
+        ):
             code = kernel.render(self, **kwargs)
             _, call_args, _, _ = kernel.args.python_argdefs()
             log.debug("Generated Code:\n%s", code)


@@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel):
             )
             epilogue_nodes = scope.localize_nodes(epilogue_nodes)
             return self.store_pointwise_nodes(
-                dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type]
+                dst,
+                epilogue_nodes,  # type: ignore[arg-type]
+                offsets,
+                reindexers,
             )
         else:
             if dst.get_name() != src.get_name():


@@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
         Only valid when cuda == True.
         """
         assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
-        assert arg_types is not None and len(call_args) == len(
-            arg_types
-        ), "Mismatch call_args and arg_types in generate_kernel_call"
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call"
+        )
         new_args = []
         for idx, arg in enumerate(call_args):
             if "*" in arg_types[idx]:

@@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                     dtype = may_get_constant_buffer_dtype(
                         V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                     )
-                    assert (
-                        dtype is not None
-                    ), "Fails to get the dtype of the sympy.Expr"
+                    assert dtype is not None, (
+                        "Fails to get the dtype of the sympy.Expr"
+                    )
                     self.codegen_tensor_item(
                         dtype, f"inputs[{idx}]", input_key, self.prefix
                     )

@@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
     def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
         code.writeline(f"int32_t {name}_dtype;")
         code.writeline(
-            "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
-            f"({name}, &{name}_dtype));"
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
         )

     def codegen_input_size_var_decl(self, code: IndentedBuffer, name):

@@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
             # Tell compiler we need to link with the non-mangled symbols
             for kernel in self.initialized_kernels.values():
-                assert hasattr(
-                    kernel, "get_signature"
-                ), f"{kernel} must have get_signature implemented"
+                assert hasattr(kernel, "get_signature"), (
+                    f"{kernel} must have get_signature implemented"
+                )
                 signature = kernel.get_signature()
                 self.prefix.writeline(f'extern "C" {signature};')

@@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 )
             )
             for name, kernel in self.initialized_kernels.items():
-                assert hasattr(
-                    kernel, "get_signature"
-                ), f"{kernel} must have get_signature implemented"
+                assert hasattr(kernel, "get_signature"), (
+                    f"{kernel} must have get_signature implemented"
+                )
                 kernel_ptr = f"(*{name})"
                 signature = kernel.get_signature().replace(name, kernel_ptr)
                 self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")

@@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
         with self.prefix.indent():
             for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
-                assert not isinstance(
-                    inp, sympy.Expr
-                ), f"input {name=} cannot be symbolic"
+                assert not isinstance(inp, sympy.Expr), (
+                    f"input {name=} cannot be symbolic"
+                )
                 self.write_input_output_info("inputs_info_", idx, name)

             all_cuda = all(

@@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                     opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
                         tensor
                     )
-                    assert (
-                        opaque_metadata_tensor.dim() == 1
-                    ), "Expect opaque_metadata_tensor to be 1-D"
+                    assert opaque_metadata_tensor.dim() == 1, (
+                        "Expect opaque_metadata_tensor to be 1-D"
+                    )

                     opaque_metadata_list = opaque_metadata_tensor.tolist()
                     opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)

@@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 )

             for idx, output in enumerate(V.graph.graph_outputs):
-                assert not isinstance(
-                    output, sympy.Expr
-                ), f"output {name=} cannot be symbolic"
+                assert not isinstance(output, sympy.Expr), (
+                    f"output {name=} cannot be symbolic"
+                )
                 name = f"output{idx}"
                 self.write_input_output_info("outputs_info_", idx, name)

@@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
         for idx, (name, _) in enumerate(V.graph.constants.items()):
             if name in V.graph.const_output_index:
                 const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
-        assert (
-            None not in const_index_mapping
-        ), "Not all constant gets mapped for constant folding graph."
+        assert None not in const_index_mapping, (
+            "Not all constant gets mapped for constant folding graph."
+        )

         self.prefix.writeline(
             f"""

@@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                 name = f"{output.get_name()}"
                 output_handle_name = f"{name}_handle"
                 if output.indices:
-                    assert (
-                        output.indices[0][1] == idx
-                    ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                    assert output.indices[0][1] == idx, (
+                        f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                    )
                 self.writeline(f"AtenTensorHandle {output_handle_name};")
                 output_args.append(f"&{output_handle_name}")
                 output_raii_handles.append(

@@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
         args = args + output_args
         device = d.type if (d := fallback_kernel.get_device()) else self.device
         self.generate_c_shim_extern_kernel_call(
-            fallback_kernel.cpp_kernel_name, args, device  # type: ignore[arg-type]
+            fallback_kernel.cpp_kernel_name,  # type: ignore[arg-type]
+            args,
+            device,
         )
         for raii_handle in output_raii_handles:
             self.writeline(raii_handle)

@@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
             if reduce:
                 line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
             else:
-                assert (
-                    reduce is None
-                ), "Expect reduce to be None for aten.scatter_ with scalar src"
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
             line += ");"
         self.writeline(line)

@@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen):
             # Only treat int Scalar as dynamic
             is_int_type = [isinstance(a, int) for a in arg]
             if any(is_int_type):
-                assert all(
-                    is_int_type
-                ), "AOTInductor only supports int scalars of the same type"
+                assert all(is_int_type), (
+                    "AOTInductor only supports int scalars of the same type"
+                )
                 new_int_args.extend([str(a) for a in arg])
             else:
                 assert isinstance(
-                    arg_type.getElementType(), static_arg_types  # type: ignore[arg-type]
-                ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                    arg_type.getElementType(),
+                    static_arg_types,  # type: ignore[arg-type]
+                ), (
+                    f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                )
         else:
             assert isinstance(
-                arg_type, static_arg_types  # type: ignore[arg-type]
-            ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                arg_type,
+                static_arg_types,  # type: ignore[arg-type]
+            ), (
+                f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+            )

     for arg, arg_type in zip(raw_args, arg_types):
         if arg is not None:

@@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) {
             return f"&{var_name}"

         if isinstance(type_, torch.ListType):
-            assert isinstance(
-                val, (list, tuple)
-            ), f"{val} does not match with arg type {type_}"
+            assert isinstance(val, (list, tuple)), (
+                f"{val} does not match with arg type {type_}"
+            )
             element_type = type_.getElementType()
             var_name = f"var_array_{next(self.var_array_id)}"
             if len(val) == 0:


@@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         self.cached_output_id = count()
         self.scalar_to_tensor_id = count()
         self.custom_op_wrapper_loaded = False
-        self.allow_stack_allocation: Optional[
-            bool
-        ] = config.aot_inductor.allow_stack_allocation
+        self.allow_stack_allocation: Optional[bool] = (
+            config.aot_inductor.allow_stack_allocation
+        )
         self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}

     @staticmethod

@@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
         Otherwise it uses the CUDA language for codegen.
         Only valid when cuda == True.
         """
-        assert (
-            not gpu
-        ), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
-        assert arg_types is not None and len(call_args) == len(
-            arg_types
-        ), "Mismatch call_args and arg_types in generate_kernel_call"
+        assert not gpu, (
+            "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
+        )
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call"
+        )
         new_args = []
         for idx, arg in enumerate(call_args):
             if "*" in arg_types[idx]:

@@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
                     dtype = may_get_constant_buffer_dtype(
                         V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                     )
-                    assert (
-                        dtype is not None
-                    ), "Fails to get the dtype of the sympy.Expr"
+                    assert dtype is not None, (
+                        "Fails to get the dtype of the sympy.Expr"
+                    )
                     self.codegen_tensor_item(
                         dtype, f"inputs[{idx}]", input_key, self.prefix
                     )

@@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
             if reduce:
                 line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
             else:
-                assert (
-                    reduce is None
-                ), "Expect reduce to be None for aten.scatter_ with scalar src"
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
             line += ");"
         self.writeline(line)


@@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase):
             # MultiKernel will select one kernel after running the autotune block
             self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
         params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )
         for key in self.keys:
-            assert (
-                key in params
-            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
+            assert key in params, (
+                f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
+            )
             if key == get_cpp_wrapper_cubin_path_name():
                 assert os.path.exists(params[key]), f"{params[key]} does not exist"
                 self.additional_files.append(params[key])

@@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid:
         grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
         params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )
         return grid_fn(params["meta"])

@@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase):
             self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
         params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )

         if self.autotune_configs is not None:
             # This indicates the Triton kernel is a user-defined one.

@@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu):
         if V.graph.aot_mode and V.graph.inputs_to_check:
             for idx in V.graph.inputs_to_check:
                 input_name = V.graph.graph_input_names[idx]
-                assert (
-                    input_name in V.graph.graph_inputs
-                ), f"{input_name} not found in graph inputs"
+                assert input_name in V.graph.graph_inputs, (
+                    f"{input_name} not found in graph inputs"
+                )
                 value = V.graph.graph_inputs[input_name]
-                assert isinstance(
-                    value, TensorBox
-                ), f"{input_name} is expected to be tensor but found as {type(value)}"
+                assert isinstance(value, TensorBox), (
+                    f"{input_name} is expected to be tensor but found as {type(value)}"
+                )
                 warn_msg = (
                     f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
                     "but it is not aligned at run time. Copying to an aligned tensor "


@@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling):
         Codegen a CUDA template, possibly with fused epilogues
         """
         counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
-        assert self.is_cuda_cpp_template(
-            template_node
-        ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
+        assert self.is_cuda_cpp_template(template_node), (
+            "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
+        )
         template_node = cast(SchedulerNode, template_node)
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1


@@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller):
         make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
         bmreq: CUDABenchmarkRequest,
         template: "CUDATemplate",  # type: ignore[name-defined]
-        info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]],  # type: ignore[type-arg]
+        info_kwargs: Optional[
+            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
+        ],  # type: ignore[type-arg]
         description: str,
     ) -> None:
         super().__init__(name, input_nodes, layout, description)


@@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate):
             A CUDATemplateCaller object representing the generated CUDA template caller.
         """
         kernel_name = f"cuda_{self.name}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), CUDATemplateKernel(
-            kernel_name=kernel_name,
-            runtime_arg_info=self.get_runtime_arg_info(),
-            runtime_arg_values=self.get_runtime_arg_values(**kwargs),
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            CUDATemplateKernel(
+                kernel_name=kernel_name,
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
+            ) as kernel,
+        ):
             code = self.render(kernel=kernel, **kwargs)
             _, call_args, _, _ = kernel.args.python_argdefs()
             autotuning_log.debug("Generated Code:\n%s", code)


@@ -147,7 +147,9 @@ if try_import_cutlass():
                 "element_d": DataTypeTag[operation.D.element],  # type: ignore[name-defined]
                 "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
                 "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
-                "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined]  # noqa: B950
+                "opcode_class": OpcodeClassTag[  # type: ignore[name-defined]
+                    operation.tile_description.math_instruction.opcode_class
+                ],
                 "arch": f"cutlass::arch::Sm{operation.arch:d}",
                 "tile_shape_m": str(operation.tile_description.tile_shape[0]),
                 "tile_shape_n": str(operation.tile_description.tile_shape[1]),

@@ -168,7 +170,9 @@ if try_import_cutlass():
                     operation.tile_description.math_instruction.instruction_shape[2]
                 ),
                 "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),  # type: ignore[name-defined]
-                "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]),  # type: ignore[name-defined]
+                "epilogue_schedule": str(
+                    EpilogueScheduleTag[operation.epilogue_schedule]  # type: ignore[name-defined]
+                ),
                 "epilogue_functor": epilogue_functor,
                 "stages": stage_count_string,
                 "align_a": str(operation.A.alignment),


@@ -56,9 +56,9 @@ def try_import_cutlass() -> bool:
                 "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
             )
             cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
-            assert os.path.isdir(
-                cutlass_library_dir
-            ), f"{cutlass_library_dir} is not a directory"
+            assert os.path.isdir(cutlass_library_dir), (
+                f"{cutlass_library_dir} is not a directory"
+            )
             config.cuda.cutlass_dir = os.path.abspath(
                 os.path.join(
                     cutlass_library_dir,

@@ -86,9 +86,9 @@ def try_import_cutlass() -> bool:
     if os.path.isdir(cutlass_py_full_path):
         if tmp_cutlass_py_full_path not in sys.path:
            if os.path.exists(dst_link):
-                assert os.path.islink(
-                    dst_link
-                ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                assert os.path.islink(dst_link), (
+                    f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                )
                 assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
                     cutlass_py_full_path
                 ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"


@@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         import cutlass_library.gemm_operation as cutlass_gemm_op
         import cutlass_library.library as cutlass_lib

-        assert isinstance(
-            op, cutlass_gemm_op.GemmOperation
-        ), "op argument is required and has to be an instance of GemmOperation"
+        assert isinstance(op, cutlass_gemm_op.GemmOperation), (
+            "op argument is required and has to be an instance of GemmOperation"
+        )

         assert len(self.input_nodes) >= 2 and self.output_node is not None
         X, W = self.input_nodes[0], self.input_nodes[1]

@@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
         else:
             input_reorder = None
         kernel_call_signature = kernel.def_kernel(
-            inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder  # type: ignore[arg-type]
+            inputs=inputs,  # type: ignore[arg-type]
+            outputs=[Y],
+            names_str=names_str,
+            input_reorder=input_reorder,
         )
         test_call_statement = self.test_call_statement(kernel, inputs, names_str)
         # The layouts might have changed between autotuning and this call if they were FlexibleLayout


@@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter):
         val, n = expr.args
         val = self._print(val)
         n = int(n)
-        return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
+        return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"


 texpr = HalidePrinter().doprint

@@ -856,11 +856,11 @@ class HalideKernel(SIMDKernel):
             for sym, size in added_sym_size:
                 full_index += stride * sym
                 stride *= size
-            self.index_replacements[
-                node.symbol()
-            ] = V.graph.sizevars.simplify_with_ranges(
-                ModularIndexing(full_index, node.divisor, node.length),
-                self.halide_vars,  # type: ignore[arg-type]
+            self.index_replacements[node.symbol()] = (
+                V.graph.sizevars.simplify_with_ranges(
+                    ModularIndexing(full_index, node.divisor, node.length),
+                    self.halide_vars,  # type: ignore[arg-type]
+                )
             )

         # codegen the variable definitions

@@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel):
         if isinstance(value, tuple):
             assert reduction_type == "welford_combine"
-            self.cse.reduction_cache[
-                cache_key
-            ] = result_tuple = self.welford_combine_impl(*value)
+            self.cse.reduction_cache[cache_key] = result_tuple = (
+                self.welford_combine_impl(*value)
+            )
             return result_tuple

         assert isinstance(value, HalideCSEVariable) and value.used_dims is not None

@@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel):
         scan = f"{scan_dom}.x"
         self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")

-        assert (
-            len(self.reduction_renames) == 1
-        ), "multi-dimensional scan not implemented"
+        assert len(self.reduction_renames) == 1, (
+            "multi-dimensional scan not implemented"
+        )
         (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
         scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
         scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}


@@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol):
     get_size_hint: CachedMethod[[], int]
     get_symbolic_size: CachedMethod[[], sympy.Expr]

-    def _allocate(self, block: Allocation, is_last: bool) -> bool:
-        ...
+    def _allocate(self, block: Allocation, is_last: bool) -> bool: ...


 class ClearCacheOnAllocateMixin(MemorySplitProtocol):


@@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel):
         threads = [self.pexpr(v.numel) for v in self.active_range_trees()]  # type: ignore[misc]
         args += [f"threads=[{', '.join(threads)}]"]
         if self.inside_reduction:
-            threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()]  # type: ignore[misc]
+            threads = [
+                self.pexpr(v.numel) if v.is_reduction else "1"  # type: ignore[misc]
+                for v in self.active_range_trees()
+            ]
             args += [f"group_size=[{', '.join(threads)}]"]

         wrapper.generate_kernel_call(


@@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None):
     all_args = max(args_list, key=len)[:]
     arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
     for args in args_list:
-        assert OrderedSet(args).issubset(
-            OrderedSet(all_args)
-        ), f"{args} v.s. {all_args}"
+        assert OrderedSet(args).issubset(OrderedSet(all_args)), (
+            f"{args} v.s. {all_args}"
+        )
     return all_args, arg_types

@@ -149,7 +149,9 @@ class MultiKernel:
     Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
     The generated definition for the multi-kernel will looks like:
     ```
-    multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
+    multi_kernel_kernel1 = MultiKernelCall(
+        [kernel1, kernel2], multi_kernel_definition_code
+    )
     ```
     Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39


@@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate):
             template_params=(",\n" + 12 * " ").join(template_params),
         ), self._template_from_string(template_type).render(operation_name=op.name())

-    def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str:  # type: ignore[override, name-defined]
+    def render(  # type: ignore[override]
+        self,
+        kernel: ROCmTemplateKernel,
+        op: "CKGroupedConvFwdOp",  # type: ignore[name-defined]
+        **kwargs,
+    ) -> str:
         template_buffer_node = kwargs.get("template_buffer_node", None)
         if template_buffer_node is not None:
             self.output_node = template_buffer_node


@@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate):
             operation_name=operation_name
         )

-    def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str:  # type: ignore[override]
+    def render(  # type: ignore[override]
+        self,
+        kernel: ROCmTemplateKernel,
+        op: "CKGemmOperation",
+        **kwargs,
+    ) -> str:
         """
         The primary entry point for the code rendering process used in this template.
         """

@@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate):
         * Template instance {op}
         *
         * {torch.__version__=}
-        * torch.version.git_version={getattr(torch.version, 'git_version', 'None')}
+        * torch.version.git_version={getattr(torch.version, "git_version", "None")}
         */
         """
         epilogue = None


@@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling):
         """
         Codegen a ROCm template, possibly with fused epilogues
         """
-        assert self.is_rocm_cpp_template(
-            template_node
-        ), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
+        assert self.is_rocm_cpp_template(template_node), (
+            "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
+        )
         template_node = cast(SchedulerNode, template_node)
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1


@@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller):
         ],
         bmreq: ROCmBenchmarkRequest,
         template: "ROCmTemplate",  # type: ignore[name-defined]
-        info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]],  # type: ignore[type-arg]
+        info_kwargs: Optional[
+            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
+        ],  # type: ignore[type-arg]
     ) -> None:
         super().__init__(name, input_nodes, layout, description="")
         self.category = category


@@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate):
         """
         kernel_name = f"rocm_{self.name}"
         kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), ROCmTemplateKernel(
-            kernel_name=kernel_name,
-            runtime_arg_info=self.get_runtime_arg_info(),
-            runtime_arg_values=self.get_runtime_arg_values(**kwargs),
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            ROCmTemplateKernel(
+                kernel_name=kernel_name,
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
+            ) as kernel,
+        ):
             code = self.render(kernel=kernel, **kwargs)
             _, call_args, _, _ = kernel.args.python_argdefs()
             log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)


@@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                 continue

             while current_group < len(remaining) and sv.statically_known_equals(
-                remaining[current_group], 1  # type: ignore[arg-type]
+                remaining[current_group],
+                1,  # type: ignore[arg-type]
             ):
                 # scroll to next group with remaining elements
                 current_group += 1

@@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                 )
                 return_getters_groups.append(return_getters)

-        assert all(
-            V.graph.sizevars.size_hint(s) == 1 for s in remaining
-        ), f"failed to set ranges {remaining} {lengths}"
+        assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
+            f"failed to set ranges {remaining} {lengths}"
+        )

         return new_ranges, return_getters_groups

@@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                 replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
             if len(replacements) > 0:
                 self.range_tree_nodes[sym].expr = sympy_subs(  # type: ignore[index]
-                    self.range_tree_nodes[sym].expr, replacements  # type: ignore[index]
+                    self.range_tree_nodes[sym].expr,
+                    replacements,  # type: ignore[index]
                 )
             self.range_tree_nodes[sym].codegen()  # type: ignore[index]
         return expr

@@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling):
                 features=SIMDKernelFeatures(node_schedule, numel, rnumel),
             )
             self.codegen_node_schedule_with_kernel(node_schedule, kernel)
-            with config.patch(
-                "benchmark_kernel", benchmark_kernel
-            ), V.set_kernel_handler(kernel):
+            with (
+                config.patch("benchmark_kernel", benchmark_kernel),
+                V.set_kernel_handler(kernel),
+            ):
                 src_code = kernel.codegen_kernel()
         else:
             prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(


@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.block_ptr_id = itertools.count() self.block_ptr_id = itertools.count()
self.block_ptr_to_buffer = dict[str, str]() self.block_ptr_to_buffer = dict[str, str]()
self.helper_functions = HelperFunctions() self.helper_functions = HelperFunctions()
self.pointer_advancements: dict[ self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
SymT, dict[str, list[sympy.Expr]] collections.defaultdict(dict)
] = collections.defaultdict(dict) )
self._load_counts: collections.Counter[str] = collections.Counter() self._load_counts: collections.Counter[str] = collections.Counter()
# A set of autotuning hints to pass as part of triton_meta # A set of autotuning hints to pass as part of triton_meta
@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
continue continue
advancements = self.pointer_advancements[symt] advancements = self.pointer_advancements[symt]
assert ( assert block_ptr not in advancements, (
block_ptr not in advancements "duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'" )
advancements[block_ptr] = advance_offsets advancements[block_ptr] = advance_offsets
else: else:
block_ptr = indexing.format(var) block_ptr = indexing.format(var)
@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
buffer.splice( buffer.splice(
f"""\ f"""\
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim}) {result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
{result_var} = {self.reduction_resize(f'{result_var}_idx')} {result_var} = {self.reduction_resize(f"{result_var}_idx")}
""" """
) )
@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index( {accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index {accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
) )
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)} {accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)} {accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
""" """
) )
final_argreduce( final_argreduce(
@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
) )
self.compute.splice( self.compute.splice(
f"""\ f"""\
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)} {accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)} {accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
{accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)} {accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
""" """
) )
result_mean = result_var result_mean = result_var
@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
self.filter_masks(masks) self.filter_masks(masks)
masks = sorted(masks) masks = sorted(masks)
assert not self._load_mask, "ops.sort not supported inside ops.masked" assert not self._load_mask, "ops.sort not supported inside ops.masked"
assert ( assert self.persistent_reduction, (
self.persistent_reduction "ops.sort is only supported in persistent reductions"
), "ops.sort is only supported in persistent reductions" )
cse_compute = functools.partial(self.cse.generate, self.compute) cse_compute = functools.partial(self.cse.generate, self.compute)
dim = self.triton_tensor_ndim() - self.num_reduction_dims dim = self.triton_tensor_ndim() - self.num_reduction_dims
@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
{} {}
import torch import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
""".format( """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
) )
def _get_heuristic(self): def _get_heuristic(self):
@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
inductor_meta["profile_bandwidth"] = config.profile_bandwidth inductor_meta["profile_bandwidth"] = config.profile_bandwidth
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
inductor_meta[ inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
"profile_bandwidth_with_do_bench_using_profiling" config.profile_bandwidth_with_do_bench_using_profiling
] = config.profile_bandwidth_with_do_bench_using_profiling )
if config.coordinate_descent_tuning: if config.coordinate_descent_tuning:
inductor_meta[ inductor_meta["coordinate_descent_tuning"] = (
"coordinate_descent_tuning" config.coordinate_descent_tuning
] = config.coordinate_descent_tuning )
inductor_meta[ inductor_meta["coordinate_descent_search_radius"] = (
"coordinate_descent_search_radius" config.coordinate_descent_search_radius
] = config.coordinate_descent_search_radius )
inductor_meta[ inductor_meta["coordinate_descent_check_all_directions"] = (
"coordinate_descent_check_all_directions" config.coordinate_descent_check_all_directions
] = config.coordinate_descent_check_all_directions )
return inductor_meta return inductor_meta
def codegen_kernel(self, name=None): def codegen_kernel(self, name=None):
@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling):
) -> tuple[float, str]: ) -> tuple[float, str]:
"""Benchmark an already compiled module""" """Benchmark an already compiled module"""
device_interface = get_interface_for_device(V.graph.device_type) device_interface = get_interface_for_device(V.graph.device_type)
with preserve_rng_state(), device_interface.device( with (
V.graph.get_current_device_or_throw() preserve_rng_state(),
): # type: ignore[attr-defined] device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined]
):
ms = None ms = None
def cache_file_path(): def cache_file_path():
@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
device = node.get_device() device = node.get_device()
assert device is not None assert device is not None
backend = node.scheduler.get_backend(device) backend = node.scheduler.get_backend(device)
assert isinstance( assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
backend, (SIMDScheduling, CUDACombinedScheduling) f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}" )
with V.graph.set_current_device(device): with V.graph.set_current_device(device):
# Don't increment kernel count when generating debug string. # Don't increment kernel count when generating debug string.


@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition(
# rnumel > 2048 usually has long execution time # rnumel > 2048 usually has long execution time
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes # BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
long_reduction = [ long_reduction = [
n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type] n
for n in reduction
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
] ]
short_reduction = [n for n in reduction if n not in long_reduction] short_reduction = [n for n in reduction if n not in long_reduction]
if long_reduction: if long_reduction:
@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition(
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]], dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
], ],
list[list[BaseSchedulerNode]], list[list[BaseSchedulerNode]],
] ],
) -> None: ) -> None:
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions """Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
@ -593,9 +595,9 @@ class ComboKernel(Kernel):
num_persistent_reduction = len( num_persistent_reduction = len(
[e for e in heuristics_list if e == "persistent_reduction"] [e for e in heuristics_list if e == "persistent_reduction"]
) )
assert ( assert num_reduction == 0, (
num_reduction == 0 "combining pointwise and reduction are not supported yet."
), "combining pointwise and reduction are not supported yet." )
heuristics = ( heuristics = (
"pointwise_with_reduction" "pointwise_with_reduction"
if num_persistent_reduction > 0 if num_persistent_reduction > 0
@ -784,13 +786,13 @@ class ComboKernel(Kernel):
name, tree, suffix=str(num) name, tree, suffix=str(num)
) )
if not tree.is_reduction: if not tree.is_reduction:
assert isinstance( assert isinstance(grid[i][num], str), (
grid[i][num], str f"Grid {grid[i][num]} should be a dynamic shape."
), f"Grid {grid[i][num]} should be a dynamic shape." )
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert ( assert grid[i][num] == numel_sign + numel_name, (
grid[i][num] == numel_sign + numel_name f"numel args mismatch: {grid[i][num]} vs {numel_name}"
), f"numel args mismatch: {grid[i][num]} vs {numel_name}" )
grid[i][num] = -expr if numel_sign == "-" else expr grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction: if not tree.is_reduction or sub_kernel.inside_reduction:
@ -807,13 +809,13 @@ class ComboKernel(Kernel):
continue continue
expr = V.graph.sizevars.size_hint(tree.numel) expr = V.graph.sizevars.size_hint(tree.numel)
if not tree.is_reduction: if not tree.is_reduction:
assert isinstance( assert isinstance(grid[i][num], str), (
grid[i][num], str f"Grid {grid[i][num]} should be a dynamic shape."
), f"Grid {grid[i][num]} should be a dynamic shape." )
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else "" numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
assert ( assert grid[i][num] == numel_sign + numel_name, (
grid[i][num] == numel_sign + numel_name f"grid mismatch: {grid[i][num]} vs {numel_name}"
), f"grid mismatch: {grid[i][num]} vs {numel_name}" )
grid[i][num] = -expr if numel_sign == "-" else expr grid[i][num] = -expr if numel_sign == "-" else expr
if not tree.is_reduction or sub_kernel.inside_reduction: if not tree.is_reduction or sub_kernel.inside_reduction:
extra_args.append(expr) extra_args.append(expr)
@ -1015,9 +1017,7 @@ class ComboKernel(Kernel):
{} {}
import torch import torch
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
""".format( """.format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
)
) )
def uniquify_block_sizes( def uniquify_block_sizes(


@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel):
def initialize_range_tree(self, pid_cache): def initialize_range_tree(self, pid_cache):
prefixes = ["y", "x", "r0_"] prefixes = ["y", "x", "r0_"]
assert len(self.numels) <= len( assert len(self.numels) <= len(prefixes), (
prefixes "z dimension not supported for split scan"
), "z dimension not supported for split scan" )
active_prefixes = prefixes[len(prefixes) - len(self.numels) :] active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
grid_dims = {"r0_": 0, "x": 1, "y": 2} grid_dims = {"r0_": 0, "x": 1, "y": 2}


@ -184,7 +184,8 @@ def config_of(
if isinstance(x, TensorArg): if isinstance(x, TensorArg):
if include_tensor: if include_tensor:
offset_aligned = V.graph.sizevars.statically_known_multiple_of( offset_aligned = V.graph.sizevars.statically_known_multiple_of(
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type] x.offset * x.dtype.itemsize,
alignment, # type: ignore[arg-type]
) )
return offset_aligned and not is_unaligned_buffer(x) return offset_aligned and not is_unaligned_buffer(x)
else: else:


@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
# NB: this is symbolic so that we don't try to reuse a buffer # NB: this is symbolic so that we don't try to reuse a buffer
# for s0 for s1, just because they happen to share the same # for s0 for s1, just because they happen to share the same
# size hint # size hint
sympy_str(input_size) sympy_str(input_size) == sympy_str(output_size)
== sympy_str(output_size)
) or ( ) or (
# statically known that 0.95 * input_size <= output_size <= input_size # statically known that 0.95 * input_size <= output_size <= input_size
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size) V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str:
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type) container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
if len(container_match) == 1: if len(container_match) == 1:
contained_type = container_match[0] contained_type = container_match[0]
assert ( assert contained_type in PYTHON_TO_CPP, (
contained_type in PYTHON_TO_CPP f"unsupported {py_container} type in convert_arg_type: {contained_type}"
), f"unsupported {py_container} type in convert_arg_type: {contained_type}" )
cpp_contained_type = PYTHON_TO_CPP[contained_type] cpp_contained_type = PYTHON_TO_CPP[contained_type]
return f"{cpp_container}<{cpp_contained_type}>" return f"{cpp_container}<{cpp_contained_type}>"
@ -367,9 +366,9 @@ class SymbolicCallArg:
class MemoryPlanningState: class MemoryPlanningState:
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.reuse_pool: dict[ self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
ReuseKey, list[FreeIfNotReusedLine] collections.defaultdict(list)
] = collections.defaultdict(list) )
self.total_allocated_buffer_size: int = 0 self.total_allocated_buffer_size: int = 0
def __contains__(self, key: ReuseKey) -> bool: def __contains__(self, key: ReuseKey) -> bool:
@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);" f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
) )
else: else:
assert ( assert self.last_seen_device_guard_index == self.device_idx, (
self.last_seen_device_guard_index == self.device_idx "AOTInductor only supports running on one CUDA device"
), "AOTInductor only supports running on one CUDA device" )
else: else:
if self.last_seen_device_guard_index is None: if self.last_seen_device_guard_index is None:
code.writeline( code.writeline(
@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen):
equals_1 = isinstance( equals_1 = isinstance(
arg, (int, sympy.Integer) arg, (int, sympy.Integer)
) and V.graph.sizevars.statically_known_equals( ) and V.graph.sizevars.statically_known_equals(
arg, 1 # type: ignore[arg-type] arg,
1, # type: ignore[arg-type]
) )
add_arg(idx, SizeArg(key, arg), equals_1=equals_1) add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen):
buf_name = arg buf_name = arg
buf = V.graph.get_buffer(arg) buf = V.graph.get_buffer(arg)
else: else:
assert ( assert raw_arg is not None, (
raw_arg is not None "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time" )
buf_name = f"tmp_arg_{index}" buf_name = f"tmp_arg_{index}"
buf = raw_arg buf = raw_arg
@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen):
and kernel_name not in self.kernel_autotune_names and kernel_name not in self.kernel_autotune_names
): ):
# Create example args for autotune in a separate epilogue # Create example args for autotune in a separate epilogue
assert arg_types is not None and len(call_args) == len( assert arg_types is not None and len(call_args) == len(arg_types), (
arg_types "call_args and arg_types do not match"
), "call_args and arg_types do not match" )
tensor_args = {} tensor_args = {}
all_args = [] all_args = []
@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen):
# create a dummy raw_args for uniform behavior in the following loop # create a dummy raw_args for uniform behavior in the following loop
raw_args = [None] * len(call_args) raw_args = [None] * len(call_args)
else: else:
assert len(raw_args) == len( assert len(raw_args) == len(call_args), (
call_args "call_args and raw_args do not match"
), "call_args and raw_args do not match" )
for i, (arg, arg_type, raw_arg) in enumerate( for i, (arg, arg_type, raw_arg) in enumerate(
zip(call_args, arg_types, raw_args) zip(call_args, arg_types, raw_args)
@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen):
if isinstance(layout, ir.NoneLayout): if isinstance(layout, ir.NoneLayout):
return return
if isinstance(layout, ir.NonOwningLayout): if isinstance(layout, ir.NonOwningLayout):
assert isinstance( assert isinstance(layout.view, ir.ReinterpretView), (
layout.view, ir.ReinterpretView f"unexpected {type(layout.view)}: {layout.view}"
), f"unexpected {type(layout.view)}: {layout.view}" )
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data) assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data) assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
self.codegen_allocation(layout.view.data.data) self.codegen_allocation(layout.view.data.data)
@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen):
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs): def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
# All inputs of hops must be explicitly passed in. # All inputs of hops must be explicitly passed in.
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo. # Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
assert len(outer_inputs) == len( assert len(outer_inputs) == len(subgraph.graph.graph_input_names), (
subgraph.graph.graph_input_names f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}" )
for inner_input, outer_input in zip( for inner_input, outer_input in zip(
subgraph.graph.graph_input_names, outer_inputs subgraph.graph.graph_input_names, outer_inputs
): ):


@ -219,8 +219,7 @@ def _schedule_for_comm(
for snode, deps in unmet_deps.items(): for snode, deps in unmet_deps.items():
assert len(deps) == 0, ( assert len(deps) == 0, (
"Detected unscheduled nodes. " f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
f"Nodes with unmet dependencies: {unmet_deps}"
) )
return scheduled return scheduled
@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
node.op == "call_function" node.op == "call_function"
and node.target == torch.ops.inductor.resize_storage_bytes_.default and node.target == torch.ops.inductor.resize_storage_bytes_.default
): ):
assert ( assert node.args[0].op == "placeholder", f"""\
node.args[0].op == "placeholder"
), f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]} Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
""" """
graph_input = node.args[0] graph_input = node.args[0]
@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default: if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
fsdp_copy_node = node fsdp_copy_node = node
unsharded_param = node.args[0] unsharded_param = node.args[0]
assert ( assert unsharded_param.op == "placeholder", f"""
unsharded_param.op == "placeholder"
), f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true! Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph} Offending node: {unsharded_param}. Graph: {graph}
""" """


@ -281,9 +281,9 @@ def _unlift_graph(
elif node_name in graph_signature.inputs_to_buffers: elif node_name in graph_signature.inputs_to_buffers:
buffer_name = graph_signature.inputs_to_buffers[node_name] buffer_name = graph_signature.inputs_to_buffers[node_name]
lifted_inputs.append(buffer_name) lifted_inputs.append(buffer_name)
gm.meta[ gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
get_cloned_parameter_buffer_name(buffer_name) clone_preserve_strides(state_dict[buffer_name])
] = clone_preserve_strides(state_dict[buffer_name]) )
else: else:
assert node_name in graph_signature.user_inputs assert node_name in graph_signature.user_inputs
lifted_inputs.append(None) lifted_inputs.append(None)
@ -542,7 +542,7 @@ def fake_tensor_prop(
# pass config dict back to user # pass config dict back to user
def get_patched_config_dict( def get_patched_config_dict(
config_patches: Optional[Union[str, dict[str, Any]]] = None config_patches: Optional[Union[str, dict[str, Any]]] = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
with config.patch(config_patches): with config.patch(config_patches):
return config.get_config_copy() return config.get_config_copy()
@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol):
gm: GraphModule, gm: GraphModule,
example_inputs: Sequence[InputType], example_inputs: Sequence[InputType],
**kwargs: Unpack[_CompileFxKwargs], **kwargs: Unpack[_CompileFxKwargs],
) -> OutputCode: ) -> OutputCode: ...
...
def compile_fx_inner( def compile_fx_inner(
@ -662,9 +661,9 @@ def _compile_fx_inner(
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs) static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs) inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
assert isinstance( assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list) f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}" )
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None: if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs) graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
@ -679,9 +678,10 @@ def _compile_fx_inner(
fx_graph_remote_cache = should_use_remote_fx_graph_cache() fx_graph_remote_cache = should_use_remote_fx_graph_cache()
with _WaitCounter( with (
"pytorch.wait_counter.fx_codegen_and_compile" _WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(): _WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
):
use_cache = ( use_cache = (
not config.force_disable_caches not config.force_disable_caches
and (config.fx_graph_cache or fx_graph_remote_cache) and (config.fx_graph_cache or fx_graph_remote_cache)
@ -865,8 +865,7 @@ class FxCompile(ABC):
example_inputs: Sequence[InputType], example_inputs: Sequence[InputType],
inputs_to_check: Sequence[int], inputs_to_check: Sequence[int],
graph_kwargs: _CompileFxKwargs, graph_kwargs: _CompileFxKwargs,
) -> OutputCode: ) -> OutputCode: ...
...
class _InProcessFxCompile(FxCompile): class _InProcessFxCompile(FxCompile):
@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile):
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False) cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
aot_mode: bool = V.aot_compilation aot_mode: bool = V.aot_compilation
is_inference: bool = graph_kwargs.get("is_inference", False) is_inference: bool = graph_kwargs.get("is_inference", False)
extern_node_serializer: Optional[ extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
Callable[[list[ExternKernelNode]], Any] graph_kwargs.get("extern_node_serializer", None)
] = graph_kwargs.get("extern_node_serializer", None) )
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get( boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
"boxed_forward_device_index", None "boxed_forward_device_index", None
) )
with _WaitCounter( with (
"pytorch.wait_counter.actual_codegen_and_compile" _WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
).guard(), dynamo_utils.preserve_rng_state(): dynamo_utils.preserve_rng_state(),
):
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None: if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
import time import time
@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile):
# See details in vllm/compilation/pass_manager.py. # See details in vllm/compilation/pass_manager.py.
log.warning("failed to log pt2_configs") log.warning("failed to log pt2_configs")
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding( with (
example_inputs V.set_fake_mode(fake_mode),
), maybe_disable_graph_partition(cpp_wrapper, aot_mode): maybe_disable_comprehensive_padding(example_inputs),
maybe_disable_graph_partition(cpp_wrapper, aot_mode),
):
const_output_index = None const_output_index = None
const_graph = None const_graph = None
const_code = None const_code = None
@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile):
if graph.aot_mode: if graph.aot_mode:
from .codecache import AotCodeCompiler from .codecache import AotCodeCompiler
assert ( assert graph.cpp_wrapper, (
graph.cpp_wrapper "AOT mode only supports C++ wrapper"
), "AOT mode only supports C++ wrapper" )
code, linemap = graph.codegen_with_cpp_wrapper() code, linemap = graph.codegen_with_cpp_wrapper()
output_code_log.debug("Output code: \n%s", code) output_code_log.debug("Output code: \n%s", code)
@ -1509,10 +1511,13 @@ def cudagraphify(
def run(new_inputs: Sequence[InputType]) -> Any: def run(new_inputs: Sequence[InputType]) -> Any:
nonlocal compiled_fn nonlocal compiled_fn
if compiled_fn is None: if compiled_fn is None:
with dynamo_utils.dynamo_timed( with (
"cudagraphify", dynamo_utils.dynamo_timed(
log_pt2_compile_event=True, "cudagraphify",
), dynamo_utils.preserve_rng_state(): log_pt2_compile_event=True,
),
dynamo_utils.preserve_rng_state(),
):
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs) compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
return compiled_fn(new_inputs) return compiled_fn(new_inputs)
@ -1669,13 +1674,16 @@ def compile_fx_aot(
extern_node_serializer = config_patches.pop("extern_node_serializer", None) extern_node_serializer = config_patches.pop("extern_node_serializer", None)
saved_compile_id = model_.meta.get("dynamo_compile_id", None) saved_compile_id = model_.meta.get("dynamo_compile_id", None)
saved_compile_context = torch._guards.CompileContext(saved_compile_id) saved_compile_context = torch._guards.CompileContext(saved_compile_id)
with V.set_aot_compilation(True), torch._guards.compile_context( with (
saved_compile_context V.set_aot_compilation(True),
), chromium_event_timed( torch._guards.compile_context(saved_compile_context),
"compile_fx_aot", chromium_event_timed(
log_pt2_compile_event=True, "compile_fx_aot",
reset_event_log_on_exit=True, log_pt2_compile_event=True,
), get_metrics_context(): reset_event_log_on_exit=True,
),
get_metrics_context(),
):
compiled_artifacts = compile_fx( compiled_artifacts = compile_fx(
model_, model_,
example_inputs_, example_inputs_,
@ -1875,12 +1883,15 @@ def compile_fx(
# TODO: This probably shouldn't be a recursive call # TODO: This probably shouldn't be a recursive call
if config.cpp_wrapper: if config.cpp_wrapper:
with config.patch( with (
{ config.patch(
"cpp_wrapper": False, # reset to break recursive call to compile_fx {
**get_cpp_wrapper_config(), "cpp_wrapper": False, # reset to break recursive call to compile_fx
} **get_cpp_wrapper_config(),
), V.set_real_inputs(example_inputs_): }
),
V.set_real_inputs(example_inputs_),
):
inputs_: Sequence[InputType] = example_inputs_ inputs_: Sequence[InputType] = example_inputs_
if isinstance(model_, GraphModule): if isinstance(model_, GraphModule):
@ -1940,10 +1951,10 @@ def compile_fx(
# Do the actual work # Do the actual work
with _use_lazy_graph_module( with (
dynamo_config.use_lazy_graph_module _use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta( enable_python_dispatcher(),
config.trace.enabled torch.fx.traceback.preserve_node_meta(config.trace.enabled),
): ):
# Pre-grad passes cannot be run if we weren't given a GraphModule. # Pre-grad passes cannot be run if we weren't given a GraphModule.
# Dynamo will always produce a GraphModule, but this handles cases # Dynamo will always produce a GraphModule, but this handles cases
@ -2085,9 +2096,9 @@ def compile_fx(
boxed_forward_device_index=forward_device, boxed_forward_device_index=forward_device,
) )
fw_compiler: Callable[ fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
[GraphModule, Sequence[InputType]], OutputCode functools.partial(fw_compiler_base, is_inference=False)
] = functools.partial(fw_compiler_base, is_inference=False) )
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler) fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
if config.freezing and not torch.is_grad_enabled(): if config.freezing and not torch.is_grad_enabled():
@ -2124,9 +2135,10 @@ def compile_fx(
) -> OutputCode: ) -> OutputCode:
from torch._dynamo.convert_frame import compile_lock from torch._dynamo.convert_frame import compile_lock
with dynamo_utils.dynamo_timed( with (
"compile_fx.<locals>.bw_compiler" dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
), compile_lock: compile_lock,
):
model_outputs_node = output_node(gm) model_outputs_node = output_node(gm)
if config.bw_outputs_user_visible: if config.bw_outputs_user_visible:
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args) model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
@ -2194,10 +2206,11 @@ def compile_fx(
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context(): with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
return inference_compiler(unlifted_gm, example_inputs_) return inference_compiler(unlifted_gm, example_inputs_)
with V.set_fake_mode(fake_mode), torch._guards.tracing( with (
tracing_context V.set_fake_mode(fake_mode),
), compiled_autograd._disable(), functorch_config.patch( torch._guards.tracing(tracing_context),
unlift_effect_tokens=True compiled_autograd._disable(),
functorch_config.patch(unlift_effect_tokens=True),
): ):
try: try:
return aot_autograd( return aot_autograd(


@ -530,7 +530,8 @@ class CompilerBisector:
) )
if result: if result:
curr_subsystem = cls.get_subsystem_object( curr_subsystem = cls.get_subsystem_object(
curr_backend, cls.get_subsystem() # type: ignore[arg-type] curr_backend,
cls.get_subsystem(), # type: ignore[arg-type]
) )
if isinstance(curr_subsystem, BinarySubsystem): if isinstance(curr_subsystem, BinarySubsystem):


@ -80,9 +80,9 @@ fx_graph_cache: bool = Config(
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default() fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
# should we bundle triton caching into fx graph cache # should we bundle triton caching into fx graph cache
bundle_triton_into_fx_graph_cache: Optional[ bundle_triton_into_fx_graph_cache: Optional[bool] = (
bool bundle_triton_into_fx_graph_cache_default()
] = bundle_triton_into_fx_graph_cache_default() )
# Enable autotune local cache. # Enable autotune local cache.
# #
@ -1390,12 +1390,12 @@ class halide:
# Halide autoscheduler to use, choices are: # Halide autoscheduler to use, choices are:
# "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only) # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
scheduler_cuda: Literal[ scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016" "Anderson2021"
] = "Anderson2021" )
scheduler_cpu: Literal[ scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016" "Adams2019"
] = "Adams2019" )
# Controls `no_asserts` flag passed to Halide target (warning: can false positive) # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
asserts = False asserts = False
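The config hunk above shows how ruff format resolves over-long annotated assignments: the subscripted annotation stays on one line and the assigned value is wrapped in parentheses, the same right-hand-side parenthesization this diff applies to plain subscript assignments. A schematic sketch with invented config names (lines shortened so both layouts stay readable):

```python
from typing import Literal, Optional

# Old layout: the subscripted annotation absorbs the line break and the
# assigned value dangles after the closing bracket.
cache_mode: Literal[
    "local", "remote", "bundled", "disabled"
] = "local"

# New layout from ruff format: the annotation stays intact and the
# right-hand side is parenthesized instead.
cache_mode: Literal["local", "remote", "bundled", "disabled"] = (
    "local"
)

# The same right-hand-side parenthesization is applied to subscript
# assignments elsewhere in this diff (e.g. the inductor_meta[...] updates).
options: dict[str, Optional[bool]] = {}
options["bundle_triton_into_fx_graph_cache"] = (
    True  # placeholder for an expression that no longer fits on one line
)
```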


@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter):
and is_woq_int8_pattern(next(iter(node.users))) and is_woq_int8_pattern(next(iter(node.users)))
) )
) and is_const_source( ) and is_const_source(
node.args[0], self.lifted_constant_names # type: ignore[arg-type] node.args[0], # type: ignore[arg-type]
self.lifted_constant_names,
): ):
# Case 1: int8_weight -> dq -> bf16_weight # Case 1: int8_weight -> dq -> bf16_weight
# Case 2: int8_weight -> permute -> dq -> bf16_weight # Case 2: int8_weight -> permute -> dq -> bf16_weight


@ -1633,8 +1633,8 @@ class CppBuilder:
""" """
) )
assert os.path.exists( assert os.path.exists(cmake_path), (
cmake_path f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist" )
with open(cmake_path, "a") as f: with open(cmake_path, "a") as f:
f.write(contents) f.write(contents)


@ -119,6 +119,7 @@ from . import config
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class GraphID: class GraphID:
"Unique counter of a cuda graph recording" "Unique counter of a cuda graph recording"
id: int id: int
@ -622,11 +623,15 @@ class CUDAWarmupNode:
refs = list(self.path_live_weakrefs()) refs = list(self.path_live_weakrefs())
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs) check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
with torch.cuda.device( with (
self.device_index torch.cuda.device(self.device_index),
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager( disable_conv_cache_emptying(),
self.device_index, self.cuda_graphs_pool, self.stream clear_cublas_manager(),
), get_history_recording(): _use_cuda_memory_pool_manager(
self.device_index, self.cuda_graphs_pool, self.stream
),
get_history_recording(),
):
out = self.wrapped_function.model(new_inputs) out = self.wrapped_function.model(new_inputs)
# We need to know which outputs are allocated within the cudagraph pool # We need to know which outputs are allocated within the cudagraph pool
@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage()
class AliasesPriorGraphOutput(OutputAliasInfo): class AliasesPriorGraphOutput(OutputAliasInfo):
"Marks that the graph output aliases an output of a prior graph" "Marks that the graph output aliases an output of a prior graph"
__slots__ = ["index"] __slots__ = ["index"]
index: PathOutputIndex index: PathOutputIndex
@ -1200,14 +1206,18 @@ class CUDAGraphNode:
] ]
check_memory_pool(self.device, self.cuda_graphs_pool, memory) check_memory_pool(self.device, self.cuda_graphs_pool, memory)
with preserve_rng_state(), torch.cuda.device( with (
self.device preserve_rng_state(),
), clear_cublas_manager(), torch.cuda.graph( torch.cuda.device(self.device),
self.graph, clear_cublas_manager(),
stream=self.stream, torch.cuda.graph(
pool=self.cuda_graphs_pool, self.graph,
capture_error_mode="thread_local", stream=self.stream,
), get_history_recording(): pool=self.cuda_graphs_pool,
capture_error_mode="thread_local",
),
get_history_recording(),
):
static_outputs = model(inputs) static_outputs = model(inputs)
# running model should reclaim memory # running model should reclaim memory
@ -1247,11 +1257,13 @@ class CUDAGraphNode:
self.output_storage_alias.append(UnaliasedStorage) self.output_storage_alias.append(UnaliasedStorage)
continue continue
torch._check( (
o.is_cuda or o.untyped_storage().data_ptr() == 0, torch._check(
lambda: ( o.is_cuda or o.untyped_storage().data_ptr() == 0,
"Expected all cuda outputs in cuda graph recording. Non cuda output " lambda: (
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}" "Expected all cuda outputs in cuda graph recording. Non cuda output "
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
),
), ),
) )
@ -1291,9 +1303,9 @@ class CUDAGraphNode:
if self.stack_traces is None: if self.stack_traces is None:
self.stack_traces = [None for _ in range(len(outputs))] self.stack_traces = [None for _ in range(len(outputs))]
else: else:
assert len(self.stack_traces) == len( assert len(self.stack_traces) == len(outputs), (
outputs "Wrong number of stack traces passed in"
), "Wrong number of stack traces passed in" )
assert not self.outputs_weakrefs assert not self.outputs_weakrefs
for out, static_output_tensor in zip(outputs, self.static_output_tensors): for out, static_output_tensor in zip(outputs, self.static_output_tensors):
@ -1599,12 +1611,14 @@ class CUDAGraphNode:
self.stream.wait_stream(torch.cuda.current_stream()) self.stream.wait_stream(torch.cuda.current_stream())
recording_inputs: list[InputType] = [] recording_inputs: list[InputType] = []
with warnings.catch_warnings(record=True), torch.cuda.device( with (
self.device warnings.catch_warnings(record=True),
), _use_cuda_memory_pool_manager( torch.cuda.device(self.device),
self.device, _use_cuda_memory_pool_manager(
mem_pool=self.cuda_graphs_pool, self.device,
stream=self.stream, mem_pool=self.cuda_graphs_pool,
stream=self.stream,
),
): ):
for i, inp in enumerate(inputs): for i, inp in enumerate(inputs):
if not isinstance(inp, torch.Tensor): if not isinstance(inp, torch.Tensor):
@ -1736,12 +1750,8 @@ def check_memory_pool(
pool_id: tuple[int, int], pool_id: tuple[int, int],
live_storages_ptrs: list[StorageWeakRefWrapper], live_storages_ptrs: list[StorageWeakRefWrapper],
) -> None: ) -> None:
assert all( assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter
) # noqa: C419
unique_storages = {
stor.data_ptr() for stor in live_storages_ptrs if stor()
} # noqa: set_linter
# check if there is a divergence first, then do the expensive snapshot call after # check if there is a divergence first, then do the expensive snapshot call after
# we know it will error # we know it will error
@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager:
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph() self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle() self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
with warnings.catch_warnings(record=True), torch.cuda.graph( with (
self.graph, warnings.catch_warnings(record=True),
pool=self.cuda_graphs_thread_pool, torch.cuda.graph(
stream=self.stream, self.graph,
capture_error_mode="thread_local", pool=self.cuda_graphs_thread_pool,
stream=self.stream,
capture_error_mode="thread_local",
),
): ):
pass pass
@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager:
constants: tuple[torch.Tensor, ...], constants: tuple[torch.Tensor, ...],
placeholders: tuple[PlaceholderInfo, ...], placeholders: tuple[PlaceholderInfo, ...],
mutated_input_idxs: tuple[int, ...], mutated_input_idxs: tuple[int, ...],
) -> tuple[ModelType, OutputType,]: ) -> tuple[
ModelType,
OutputType,
]:
id = self.new_func_id() id = self.new_func_id()
self.ids_to_stack_traces[id] = stack_traces self.ids_to_stack_traces[id] = stack_traces
self.ids_to_funcs[id] = WrappedFunction( self.ids_to_funcs[id] = WrappedFunction(


@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType]
@dataclasses.dataclass(frozen=True) @dataclasses.dataclass(frozen=True)
class FunctionID: class FunctionID:
"Unique counter of a function wrapped in cudagraphify_impl" "Unique counter of a function wrapped in cudagraphify_impl"
id: int id: int
@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
def check_multiple_devices_or_any_cpu_nodes( def check_multiple_devices_or_any_cpu_nodes(
device_node_mapping: dict[torch.device, torch.fx.Node] device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]: ) -> Optional[str]:
if cpu_node := device_node_mapping.get(torch.device("cpu")): if cpu_node := device_node_mapping.get(torch.device("cpu")):
msg = f"cpu device ({cpu_node.name})" msg = f"cpu device ({cpu_node.name})"
@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes(
def check_lowering_disable_cudagraph( def check_lowering_disable_cudagraph(
device_node_mapping: dict[torch.device, torch.fx.Node] device_node_mapping: dict[torch.device, torch.fx.Node],
) -> Optional[str]: ) -> Optional[str]:
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping) return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -276,9 +277,9 @@ def log_data_ptr_mismatch(
Logs the mismatch between input data pointers and recorded data pointers. Logs the mismatch between input data pointers and recorded data pointers.
This checks only idxs in target_idxs. This checks only idxs in target_idxs.
""" """
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len( assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), (
placeholders "length mismatch between inputs, recorded_data_ptr, and placeholders"
), "length mismatch between inputs, recorded_data_ptr, and placeholders" )
t_tensors = [inputs[i] for i in target_idxs] t_tensors = [inputs[i] for i in target_idxs]
t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs] t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]


@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name(
def get_node_name_to_buf_meta( def get_node_name_to_buf_meta(
node_name_to_buf_name: dict[str, str] node_name_to_buf_name: dict[str, str],
) -> dict[str, BufMeta]: ) -> dict[str, BufMeta]:
buf_name_to_n_node = {} buf_name_to_n_node = {}
for node_name, buf_name in node_name_to_buf_name.items(): for node_name, buf_name in node_name_to_buf_name.items():


@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
def register_decomposition( def register_decomposition(
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]] ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined] for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
if op in decompositions: if op in decompositions:


@ -194,7 +194,9 @@ class MemoryDep(Dep):
) )
new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR
out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type] out = MemoryDep(
self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
) # type: ignore[arg-type]
return out return out
@property @property
@ -649,11 +651,16 @@ def extract_loop_body_with_args(
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type] inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
for entry in fn.memory_usage[MemoryUsageType.STORE]: for entry in fn.memory_usage[MemoryUsageType.STORE]:
inner.store( inner.store(
entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type] entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
entry.mode,
) )
for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]: for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
inner.store_reduction( inner.store_reduction(
entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type] entry.buffer_name,
name_to_index[entry.index_name],
None, # type: ignore[arg-type]
) )
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]: for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
inner.index_expr(name_to_index[entry.index_name], None) inner.index_expr(name_to_index[entry.index_name], None)
@ -661,7 +668,11 @@ def extract_loop_body_with_args(
# All that matters is that we record the buffer name, so place it in the # All that matters is that we record the buffer name, so place it in the
# "boundaries" name position to ensure that it's recorded. # "boundaries" name position to ensure that it's recorded.
inner.bucketize( inner.bucketize(
None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type] None,
(entry.buffer_name, None, None, None),
None,
None, # type: ignore[arg-type]
None, # type: ignore[arg-type]
) )
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
return inner return inner
@ -801,8 +812,9 @@ def extract_free_unbacked_symbols(
handler = FreeUnbackedSymbolsOpsHandler() handler = FreeUnbackedSymbolsOpsHandler()
# NB: I cargo culted the allow_indexing patch here, I don't understand why # NB: I cargo culted the allow_indexing patch here, I don't understand why
# people do this all over # people do this all over
with V.set_ops_handler(handler), patch.object( with (
FlexibleLayout, "allow_indexing", True V.set_ops_handler(handler),
patch.object(FlexibleLayout, "allow_indexing", True),
): ):
fn(*args) fn(*args)
return handler.symbols return handler.symbols


@ -19,8 +19,7 @@ T = TypeVar("T")
class DTypeVar(Protocol): class DTypeVar(Protocol):
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype: ...
...
DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue] DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]


@ -526,6 +526,7 @@ class ConfigFuzzer:
```python ```python
import torch._inductor.config as cfg import torch._inductor.config as cfg
def create_simple_test_model_gpu() -> FactoryOutputType: def create_simple_test_model_gpu() -> FactoryOutputType:
batch_size = 32 batch_size = 32
seq_length = 50 seq_length = 50
@ -539,6 +540,8 @@ class ConfigFuzzer:
return True return True
return test_fn return test_fn
fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2) fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2)
# Test every pair of configs: # Test every pair of configs:
@ -550,7 +553,9 @@ class ConfigFuzzer:
ret = fuzzer.bisect(num_attempts=10) ret = fuzzer.bisect(num_attempts=10)
# reproduce a failing config # reproduce a failing config
fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]) fuzzer.reproduce(
[{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]
)
``` ```
The list of known failures on inductor config are: The list of known failures on inductor config are:


@ -531,7 +531,11 @@ def tuned_b2b_gemm(
A.realize() A.realize()
B.realize() B.realize()
C.realize() C.realize()
layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]]) # type: ignore[index] layout = FixedLayout(
A.get_device_or_error(),
A.get_dtype(),
[A.shape[0], C.shape[1]], # type: ignore[index]
)
subgraph_buffer = build_subgraph_buffer( subgraph_buffer = build_subgraph_buffer(
[create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())], [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())],
subgraph, subgraph,


@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None:
node_indices = {node: i for i, node in enumerate(graph.nodes)} node_indices = {node: i for i, node in enumerate(graph.nodes)}
for allreduce in comm_blocks: for allreduce in comm_blocks:
# Find the earliest/first user -- target_node. # Find the earliest/first user -- target_node.
assert ( assert len(allreduce.outputs) >= 1, (
len(allreduce.outputs) >= 1 f"Found a allreduce that has zero outputs/users -- {allreduce}."
), f"Found a allreduce that has zero outputs/users -- {allreduce}." )
# Initialize the target node to avoid typing issues. # Initialize the target node to avoid typing issues.
target_node = next(iter(next(iter(allreduce.outputs)).users)) target_node = next(iter(next(iter(allreduce.outputs)).users))
target_node_index = 2**31 target_node_index = 2**31


@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
# argument. `graph.get_attr` and # argument. `graph.get_attr` and
# `graph.call_function` does not allow the `name` argument. # `graph.call_function` does not allow the `name` argument.
conv_get_node = graph.create_node( conv_get_node = graph.create_node(
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr] op="get_attr",
target=conv_node.target, # type: ignore[union-attr]
name="get_conv",
) )
bn_get_node = graph.create_node( bn_get_node = graph.create_node(
op="get_attr", target=bn_node.target, name="get_bn" op="get_attr", target=bn_node.target, name="get_bn"


@ -866,15 +866,18 @@ def _get_sfdp_patterns():
name += "_bs1" name += "_bs1"
training_name = name + "_training" training_name = name + "_training"
yield training_name, { yield (
"search_fn": pattern, training_name,
"replace_fn": replacement, {
"example_inputs": args, "search_fn": pattern,
"trace_fn": joint_fwd_bwd, "replace_fn": replacement,
"pass_dicts": patterns, "example_inputs": args,
"extra_check": extra_check, "trace_fn": joint_fwd_bwd,
"scalar_workaround": workaround, "pass_dicts": patterns,
} "extra_check": extra_check,
"scalar_workaround": workaround,
},
)
if workaround: if workaround:
assert len(workaround) == 1 and "dropout_p" in workaround assert len(workaround) == 1 and "dropout_p" in workaround
@ -886,18 +889,21 @@ def _get_sfdp_patterns():
workaround = {} workaround = {}
inference_name = name + "_inference" inference_name = name + "_inference"
yield inference_name, { yield (
"search_fn": pattern, inference_name,
"replace_fn": replacement, {
"example_inputs": args, "search_fn": pattern,
"trace_fn": fwd_only, "replace_fn": replacement,
"pass_dicts": patterns, "example_inputs": args,
"extra_check": extra_check, "trace_fn": fwd_only,
"scalar_workaround": workaround, "pass_dicts": patterns,
# with dropout turned into clone, we end up with a number of "extra_check": extra_check,
# semantically identical graphs "scalar_workaround": workaround,
"skip_duplicates": True, # with dropout turned into clone, we end up with a number of
} # semantically identical graphs
"skip_duplicates": True,
},
)
@functools.lru_cache(None) @functools.lru_cache(None)


@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion):
args=(batch_biases[i],), args=(batch_biases[i],),
kwargs={"size": broadcast_shape}, kwargs={"size": broadcast_shape},
) )
broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment] broadcast_bias.meta["val"] = aten.broadcast_to(
batch_biases_meta[i]["val"], broadcast_shape
) # type: ignore[assignment]
new_bias_add = graph.call_function( # type: ignore[operator] new_bias_add = graph.call_function( # type: ignore[operator]
aten.add.Tensor, args=((broadcast_bias, new_mm)) aten.add.Tensor, args=((broadcast_bias, new_mm))
) )
@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion):
group_biases = None # type: ignore[assignment] group_biases = None # type: ignore[assignment]
if all(weight is None for weight in group_weights): if all(weight is None for weight in group_weights):
group_weights = None # type: ignore[assignment] group_weights = None # type: ignore[assignment]
assert all( assert all(eps == group_epss[0] for eps in group_epss), (
eps == group_epss[0] for eps in group_epss "all epsilon values must be equal"
), "all epsilon values must be equal" )
with graph.inserting_before(subset[0]): # type: ignore[operator] with graph.inserting_before(subset[0]): # type: ignore[operator]
stack_input = graph.call_function( # type: ignore[operator] stack_input = graph.call_function( # type: ignore[operator]
@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
# for relu op, we also use the inplace to construct the key # for relu op, we also use the inplace to construct the key
# we batch the ops with same parent to enable followup split cat # we batch the ops with same parent to enable followup split cat
parent = node.args[0] parent = node.args[0]
parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr] parent = (
parent.target # type: ignore[union-attr]
if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
else ""
)
group_key = ( group_key = (
"batch_aten_" + self.op.__name__.lower().split(".")[0], "batch_aten_" + self.op.__name__.lower().split(".")[0],
str(input.meta["val"].shape), str(input.meta["val"].shape),
@ -1293,9 +1299,9 @@ def get_fusion_candidates(
""" """
q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque() q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()
candidate_dict: collections.defaultdict[ candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
Any, list[torch.fx.Node] collections.defaultdict(list)
] = collections.defaultdict(list) )
if root_node.target in SEARCH_EXCLUSIONS: if root_node.target in SEARCH_EXCLUSIONS:
return candidate_dict return candidate_dict

View File

@ -763,9 +763,7 @@ def _get_node_to_ancestors(
""" """
Compute the ancestors for all nodes in a graph. Compute the ancestors for all nodes in a graph.
""" """
node_to_ancestors = defaultdict( node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated]
OrderedSet[torch.fx.Node]
) # type: ignore[var-annotated]
for node in graph.nodes: for node in graph.nodes:
node_to_ancestors[node] = OrderedSet(node.all_input_nodes) node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
for dep in node.all_input_nodes: for dep in node.all_input_nodes:

View File

@ -558,9 +558,9 @@ if torch._C._has_mkldnn:
binary_nodes = filter_nodes(match.nodes, binary_op) binary_nodes = filter_nodes(match.nodes, binary_op)
def _get_compute_node(_binary_node, _other_index): def _get_compute_node(_binary_node, _other_index):
assert ( assert len(_binary_node.all_input_nodes) == 2, (
len(_binary_node.all_input_nodes) == 2 "Binary node should have 2 input nodes."
), "Binary node should have 2 input nodes." )
_compute_index = 1 if (_other_index == 0) else 0 _compute_index = 1 if (_other_index == 0) else 0
return _binary_node.args[_compute_index] return _binary_node.args[_compute_index]
@ -614,9 +614,9 @@ if torch._C._has_mkldnn:
else: else:
computation_args += [1.0, None, [], None] computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][ counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
"mkldnn_conv_binary_unary_fusion_matcher_nodes" len(match.nodes)
] += len(match.nodes) )
return L[fusion_op](*computation_args) return L[fusion_op](*computation_args)
return fn return fn
@ -659,9 +659,9 @@ if torch._C._has_mkldnn:
else: else:
computation_args += [1.0, None, [], None] computation_args += [1.0, None, [], None]
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1 counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
counters["inductor"][ counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
"mkldnn_conv_binary_unary_fusion_matcher_nodes" len(match.nodes)
] += len(match.nodes) )
# Make sure the other is not an alias or mutation(fx side doesn't has such info). # Make sure the other is not an alias or mutation(fx side doesn't has such info).
other.realize() other.realize()
if not _can_be_inplace(other) or other.data.shape != list( if not _can_be_inplace(other) or other.data.shape != list(
@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn:
) )
batch_size = input.meta.get("val").shape[0] batch_size = input.meta.get("val").shape[0]
if has_free_symbols(batch_size): if has_free_symbols(batch_size):
assert ( assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
is_lp_weight or mkldnn._is_mkldnn_acl_supported() f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}" )
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance. # For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
packed_weight_inputs = ( packed_weight_inputs = (
transpose_weight_node, transpose_weight_node,

View File

@ -437,7 +437,7 @@ def _should_pad_bench(
return False return False
def realize_symbols( def realize_symbols(
ds: Union[torch.Size, tuple[torch.SymInt, ...]] ds: Union[torch.Size, tuple[torch.SymInt, ...]],
) -> list[int]: ) -> list[int]:
return [d if isinstance(d, int) else d.node.hint for d in ds] return [d if isinstance(d, int) else d.node.hint for d in ds]

View File

@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
pattern_matcher_pass.apply pattern_matcher_pass.apply
) )
if not is_same_dict(counters["inductor"], inductor_before_change): if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[ optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
f"{pattern_matcher_pass.pass_name}_post_grad" upload_graph(gm.graph)
] = upload_graph(gm.graph) )
if config.b2b_gemm_pass: if config.b2b_gemm_pass:
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type] B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]

View File

@ -277,9 +277,9 @@ def pre_grad_passes(
for _ in range(counter): for _ in range(counter):
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type] pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
if not is_same_dict(counters["inductor"], inductor_before_change): if not is_same_dict(counters["inductor"], inductor_before_change):
optimus_scuba_log[ optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
f"{pattern_matcher_pass.pass_name}_pre_grad" upload_graph(gm.graph)
] = upload_graph(gm.graph) )
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too. # TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type] efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]

View File

@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering(
accum.realize() accum.realize()
from .mkldnn_fusion import _can_be_inplace from .mkldnn_fusion import _can_be_inplace
assert _can_be_inplace( assert _can_be_inplace(accum), (
accum "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
), "QConv Binary Inplace Fusion requires accum is not an alias or mutation." )
computation_args = ( computation_args = (
x, x,
@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
def clone_to_new_node(graph, source_node, user_node): def clone_to_new_node(graph, source_node, user_node):
# Clone the source_node to a new node # Clone the source_node to a new node
# Replace user_node's input from source_node to new_node # Replace user_node's input from source_node to new_node
assert ( assert source_node.op == "call_function", (
source_node.op == "call_function" "clone_to_new_node only support node.op call_function"
), "clone_to_new_node only support node.op call_function" )
with graph.inserting_before(user_node): with graph.inserting_before(user_node):
new_node = graph.call_function( new_node = graph.call_function(
source_node.target, source_node.target,
@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node # For a dequant pattern, we expect the start node is a dequantize_per_tensor node
return _node return _node
else: else:
assert ( assert len(_node.args) >= 1, (
len(_node.args) >= 1 "In in dequant pattern, each node should have more than 1 arg."
), "In in dequant pattern, each node should have more than 1 arg." )
return _find_first_node_in_dequant_pattern(_node.args[0]) return _find_first_node_in_dequant_pattern(_node.args[0])
dequant_pattern_start_node = _find_first_node_in_dequant_pattern( dequant_pattern_start_node = _find_first_node_in_dequant_pattern(

View File

@ -616,7 +616,8 @@ def merge_splits(
dim=first_split_dim, dim=first_split_dim,
) )
first_split_num_to_user = { first_split_num_to_user = {
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr] user.args[1]: user
for user in first_split.users.keys() # type: ignore[union-attr]
} }
new_split_num = 0 new_split_num = 0
@ -706,7 +707,11 @@ class SplitCatSimplifier:
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
) )
self.replace_cat( self.replace_cat(
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type] graph,
split_node,
next_users,
user_inputs_list_new,
transform_params_list, # type: ignore[arg-type]
) )
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type] self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
counters["inductor"]["unbind_stack_pass"] += 1 counters["inductor"]["unbind_stack_pass"] += 1
@ -913,7 +918,9 @@ class SplitCatSimplifier:
) )
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr] if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
new_split.meta["example_value"] = torch.split( new_split.meta["example_value"] = torch.split(
split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr] split_input.meta["example_value"], # type: ignore[union-attr]
[r[1] - r[0] for r in split_ranges],
dim=split_dim,
) )
counters["inductor"]["scmerge_split_added"] += 1 counters["inductor"]["scmerge_split_added"] += 1
split_items = [] split_items = []
@ -1005,7 +1012,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function( stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
) )
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
to_stack, to_stack_meta = [], [] to_stack, to_stack_meta = [], []
stack_dim = None stack_dim = None
user_inputs_new_transformed.append(stacked_input) user_inputs_new_transformed.append(stacked_input)
@ -1023,19 +1033,28 @@ class SplitCatSimplifier:
user_input_new = graph.call_function( user_input_new = graph.call_function(
torch.unflatten, args=(user_input_new, *unflatten_params) torch.unflatten, args=(user_input_new, *unflatten_params)
) )
user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*unflatten_params, # type: ignore[arg-type]
)
if movedim_params: if movedim_params:
user_input_new_meta = user_input_new.meta["example_value"] user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function( user_input_new = graph.call_function(
torch.movedim, args=(user_input_new, *movedim_params) torch.movedim, args=(user_input_new, *movedim_params)
) )
user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr] user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type]
user_input_new_meta, # type: ignore[arg-type]
*movedim_params, # type: ignore[arg-type]
)
if flatten_params: if flatten_params:
user_input_new_meta = user_input_new.meta["example_value"] user_input_new_meta = user_input_new.meta["example_value"]
user_input_new = graph.call_function( user_input_new = graph.call_function(
torch.flatten, args=(user_input_new, *flatten_params) torch.flatten, args=(user_input_new, *flatten_params)
) )
user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr] user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type]
user_input_new_meta,
*flatten_params, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(user_input_new) user_inputs_new_transformed.append(user_input_new)
user_inputs_new_transformed_meta.append( user_inputs_new_transformed_meta.append(
user_input_new.meta["example_value"] user_input_new.meta["example_value"]
@ -1044,7 +1063,10 @@ class SplitCatSimplifier:
stacked_input = graph.call_function( stacked_input = graph.call_function(
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim} torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
) )
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr] stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
to_stack_meta,
dim=stack_dim, # type: ignore[arg-type]
)
user_inputs_new_transformed.append(stacked_input) user_inputs_new_transformed.append(stacked_input)
user_inputs_new_transformed_meta.append( user_inputs_new_transformed_meta.append(
stacked_input.meta["example_value"] stacked_input.meta["example_value"]
@ -1058,14 +1080,15 @@ class SplitCatSimplifier:
kwargs={"dim": cat_dim}, kwargs={"dim": cat_dim},
) )
new_cat_node.meta["example_value"] = torch.cat( new_cat_node.meta["example_value"] = torch.cat(
user_inputs_new_transformed_meta, dim=cat_dim user_inputs_new_transformed_meta,
dim=cat_dim,
) )
counters["inductor"]["scmerge_cat_added"] += 1 counters["inductor"]["scmerge_cat_added"] += 1
else: else:
new_cat_node = user_inputs_new_transformed[-1] new_cat_node = user_inputs_new_transformed[-1]
new_cat_node.meta[ new_cat_node.meta["example_value"] = (
"example_value" user_inputs_new_transformed_meta[-1]
] = user_inputs_new_transformed_meta[-1] )
if ( if (
user_node.target == torch.cat user_node.target == torch.cat
@ -1077,7 +1100,11 @@ class SplitCatSimplifier:
new_cat_node = graph.call_function( new_cat_node = graph.call_function(
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1) torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
) )
new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr] new_cat_node.meta["example_value"] = torch.flatten(
new_cat_node_meta,
cat_dim,
cat_dim + 1,
)
user_node.replace_all_uses_with(new_cat_node) user_node.replace_all_uses_with(new_cat_node)
new_cats.append(new_cat_node) new_cats.append(new_cat_node)
@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier):
] ]
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type] if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
getitem_indices getitem_indices
) != len( ) != len(unbind_node.meta["example_value"]):
unbind_node.meta["example_value"]
):
return return
num_unbind = len(getitem_indices) num_unbind = len(getitem_indices)
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type] split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index] fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
# update the split sections # update the split sections
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index] split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
split_node, indices # type: ignore[arg-type] split_node,
indices, # type: ignore[arg-type]
) )
# padding others with zeros to keep the same dict size # padding others with zeros to keep the same dict size
for i in indices[1:]: for i in indices[1:]:
@ -1613,10 +1639,12 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type] elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
# check the split dim, and construct the slice tuple # check the split dim, and construct the slice tuple
start_fused_size = calculate_fused_tensor_size( start_fused_size = calculate_fused_tensor_size(
split_node, list(range(indices[0])) # type: ignore[arg-type] split_node,
list(range(indices[0])), # type: ignore[arg-type]
) )
end_fused_size = start_fused_size + calculate_fused_tensor_size( end_fused_size = start_fused_size + calculate_fused_tensor_size(
split_node, indices # type: ignore[arg-type] split_node,
indices, # type: ignore[arg-type]
) )
slice_list = [] slice_list = []
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr] for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
continue continue
# check the cat node has consecutive indices # check the cat node has consecutive indices
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr] indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type] if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
and len(getitem_nodes) != len(cat_inputs)
):
continue continue
# replace the users of the cat node to be the input of the split node # replace the users of the cat node to be the input of the split node
cat_node.replace_all_uses_with(split_input) cat_node.replace_all_uses_with(split_input)
@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
continue continue
# check the cat node has consecutive indices # check the cat node has consecutive indices
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr] indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type] if (
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
or len(select_nodes) != len(cat_inputs)
):
continue continue
# check all the select nodes can be merged to the cat node input # check all the select nodes can be merged to the cat node input
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr] if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
args=(new_cat_args,), args=(new_cat_args,),
kwargs={"dim": cat_dim}, kwargs={"dim": cat_dim},
) )
new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type] new_cat_node.meta["example_value"] = torch.cat(
new_cat_args_meta, dim=cat_dim
) # type: ignore[arg-type]
cat_node.replace_all_uses_with(new_cat_node) cat_node.replace_all_uses_with(new_cat_node)
new_cat_node.meta.update(cat_node.meta) new_cat_node.meta.update(cat_node.meta)
# remove inputs of cat_node if they have no users # remove inputs of cat_node if they have no users
@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack(
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type] args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
) )
reshape_node.meta["example_value"] = torch.Tensor.view( reshape_node.meta["example_value"] = torch.Tensor.view(
permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type] permute_node.meta["example_value"],
tuple(stack_node_shape), # type: ignore[arg-type]
) )
return reshape_node return reshape_node
@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
cat_inputs.append(decomposed_stack_node) cat_inputs.append(decomposed_stack_node)
# cat_arg must be the split input # cat_arg must be the split input
view_shape_list = get_view_shape_list(cat_arg, stack_dim) view_shape_list = get_view_shape_list(cat_arg, stack_dim)
stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr] stack_node_shape = torch.reshape(
cat_arg.meta["example_value"], tuple(view_shape_list)
).shape # type: ignore[union-attr]
cat_inputs.append( cat_inputs.append(
convert_reshape_cat_arg_to_stack( convert_reshape_cat_arg_to_stack(
graph, graph,

View File

@ -105,9 +105,9 @@ class FakeTensorUpdater:
if new is None: if new is None:
return old is None return old is None
if not isinstance(new, torch.Tensor): if not isinstance(new, torch.Tensor):
assert isinstance( assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
new, (torch.SymInt, torch.SymBool, torch.SymFloat) f"Unknown type {type(new)} in {self.graph}"
), f"Unknown type {type(new)} in {self.graph}" )
return ( return (
new.node.shape_env._maybe_evaluate_static( new.node.shape_env._maybe_evaluate_static(
sympy.Eq(new.node.expr, old.node.expr) sympy.Eq(new.node.expr, old.node.expr)

View File

@ -136,7 +136,9 @@ else:
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]: def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
assert isinstance( assert isinstance(
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer) constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer" ), (
"get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
)
if isinstance(constant_buffer, sympy.core.numbers.Integer): if isinstance(constant_buffer, sympy.core.numbers.Integer):
return torch.int64 return torch.int64
@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter):
self.reuse_shape_env = True self.reuse_shape_env = True
self._shape_env = shape_env self._shape_env = shape_env
# We're going to mutate ras_by_symbol as we finish generating them # We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: dict[ self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
Optional[sympy.Symbol], list[RuntimeAssert] shape_env.deferred_runtime_asserts.copy()
] = shape_env.deferred_runtime_asserts.copy() )
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]() self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env) self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: list[str] = [] self.graph_input_names: list[str] = []
@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter):
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: list[ self.cache_linemap: list[
tuple[int, str] tuple[int, str]
] = ( ] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run
[]
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported # Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs_reason: Optional[str] = None self.disable_cudagraphs_reason: Optional[str] = None
@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter):
) )
def placeholder( def placeholder(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]: ) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1 self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type] example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter):
return target(*args, **kwargs) return target(*args, **kwargs)
if target not in lowerings: if target not in lowerings:
assert isinstance( assert isinstance(target, torch._ops.OpOverload), (
target, torch._ops.OpOverload f"{target} is not an OpOverload"
), f"{target} is not an OpOverload" )
base_name = target.name().split(".")[0] base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST: if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target, warn=False, override_decomp=True) make_fallback(target, warn=False, override_decomp=True)
@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter):
return len(t.shape) == 1 and t.shape[0] <= 8 return len(t.shape) == 1 and t.shape[0] <= 8
def get_attr( def get_attr(
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[()], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]: ) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant # this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type] value = getattr_recursive(self.module, target) # type: ignore[arg-type]
@ -1241,7 +1247,10 @@ class GraphLowering(torch.fx.Interpreter):
raise AssertionError raise AssertionError
def output( def output(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override] self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> None: ) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type] result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)): if not isinstance(result, (tuple, list)):
@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter):
if is_call_function: if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n) args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs) origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node( with (
n ir.IRNode.current_origins(origins),
), V.set_current_node(n): self.set_current_node(n),
V.set_current_node(n),
):
if ( if (
n.op == "call_function" n.op == "call_function"
and n.target is not operator.getitem and n.target is not operator.getitem
@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter):
): ):
debug("fallback_handler") debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)( result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, **kwargs # type: ignore[possibly-undefined] *args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
) )
elif ( elif (
n.op == "call_function" n.op == "call_function"
@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter):
wrapper_code_gen_cls = get_wrapper_codegen_for_device( wrapper_code_gen_cls = get_wrapper_codegen_for_device(
self.device_type, self.cpp_wrapper self.device_type, self.cpp_wrapper
) )
assert ( assert wrapper_code_gen_cls is not None, (
wrapper_code_gen_cls is not None f"Device {self.device_type} not supported"
), f"Device {self.device_type} not supported" )
self.wrapper_code = wrapper_code_gen_cls.create( self.wrapper_code = wrapper_code_gen_cls.create(
is_subgraph, is_subgraph,
subgraph_name, subgraph_name,
@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter):
compiled = self.compile_to_module().call compiled = self.compile_to_module().call
def materialize( def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor] x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]: ) -> Union[int, float, torch.Tensor]:
if x is None: if x is None:
return None return None
@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter):
elif isinstance(x, FakeTensor): elif isinstance(x, FakeTensor):
return defake(x) return defake(x)
else: else:
assert isinstance( assert isinstance(x, torch.Tensor), (
x, torch.Tensor "Unknown type when creating real inputs" + str(type(x))
), "Unknown type when creating real inputs" + str(type(x)) )
return x return x
tracing_context = torch._guards.TracingContext.try_get() tracing_context = torch._guards.TracingContext.try_get()
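One hunk above also rewrites a multi-item with statement using a parenthesized list of context managers. A small sketch of the two spellings, using contextlib.nullcontext as a stand-in for the inductor context managers (the parenthesized form assumes Python 3.10 or newer):

from contextlib import nullcontext

# Old wrapping: continuation happens inside one manager's call parentheses.
with nullcontext() as a, nullcontext(
    "origins"
) as b:
    print(a, b)

# New wrapping: the whole list of context managers is parenthesized.
with (
    nullcontext() as a,
    nullcontext("origins") as b,
):
    print(a, b)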

View File

@ -5,21 +5,22 @@ propagation of sympy expressions downstream of ops.index_expr calls.
For example, say we have the IR: For example, say we have the IR:
tmp0 = ops.index_expr(x, torch.int32) tmp0 = ops.index_expr(x, torch.int32)
tmp1 = ops.constant(2, torch.int32) tmp1 = ops.constant(2, torch.int32)
tmp2 = ops.mul(tmp0, tmp1) tmp2 = ops.mul(tmp0, tmp1)
tmp3 = ops.indirect_indexing(tmp2, x_size) tmp3 = ops.indirect_indexing(tmp2, x_size)
tmp4 = ops.load("buf0", tmp3) tmp4 = ops.load("buf0", tmp3)
The underlying handler would just see: The underlying handler would just see:
ops.load("buf0", x * 2) ops.load("buf0", x * 2)
This is limited by the set of operators handled in the sympy expression This is limited by the set of operators handled in the sympy expression
printers. So simple operations like minimum and maximum cannot be translated to printers. So simple operations like minimum and maximum cannot be translated to
SymPy expressions yet, despite sympy.Min and sympy.Max existing. SymPy expressions yet, despite sympy.Min and sympy.Max existing.
""" """
import itertools import itertools
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass
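As the module docstring above describes, index arithmetic built from ops.* calls is folded into a single sympy expression before the underlying handler sees it. A simplified, standalone sketch of that folding using plain sympy (not the inductor handler itself):

import sympy

x = sympy.Symbol("x", integer=True)   # stands in for tmp0 = ops.index_expr(x, torch.int32)
two = sympy.Integer(2)                # stands in for tmp1 = ops.constant(2, torch.int32)
index = x * two                       # stands in for tmp2 = ops.mul(tmp0, tmp1)
print(index)                          # 2*x -- the expression ops.load("buf0", ...) ends up receiving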
@ -179,9 +180,9 @@ class IndexPropVar:
return IndexPropVar(expr, is_symbolic=True) return IndexPropVar(expr, is_symbolic=True)
def __post_init__(self): def __post_init__(self):
assert not self.is_symbolic or isinstance( assert not self.is_symbolic or isinstance(self.value, TypedExpr), (
self.value, TypedExpr "Symbolic IndexPropVar must contain a TypedExpr"
), "Symbolic IndexPropVar must contain a TypedExpr" )
IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]] IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler):
name: Literal["indirect_indexing"], name: Literal["indirect_indexing"],
args: Sequence[Any], args: Sequence[Any],
kwargs: dict[str, Any], kwargs: dict[str, Any],
) -> IndexPropVar: ) -> IndexPropVar: ...
...
@overload @overload
def fallback( def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any] self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult: ) -> IndexPropResult: ...
...
def fallback( def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any] self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler):
is_valid_expr = new_expr is not NotImplemented and ( is_valid_expr = new_expr is not NotImplemented and (
# Inductor doesn't expect floating point in sympy expressions, but # Inductor doesn't expect floating point in sympy expressions, but
# allow floating point constants to be propagated # allow floating point constants to be propagated
new_expr.is_constant() new_expr.is_constant() or new_expr.expr.is_integer
or new_expr.expr.is_integer
) )
if not is_valid_expr: if not is_valid_expr:
return self.fallback(name, args, kwargs) return self.fallback(name, args, kwargs)

View File

@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
int, int,
EffectfulKernel, EffectfulKernel,
), ),
), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]" ), (
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
)
# Be picky about the accepted data structure (don't use pytree here) # Be picky about the accepted data structure (don't use pytree here)
_check_tensorbox(node_or_nodes) _check_tensorbox(node_or_nodes)
@ -298,13 +300,11 @@ def get_stride_order(
@overload @overload
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...
...
@overload @overload
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
...
def ir_node_to_tensor( def ir_node_to_tensor(
@ -346,7 +346,7 @@ def may_convert_to_optional(
def get_device_type( def get_device_type(
x: Union[IRNode, OutputSpec, torch.device, None, str] x: Union[IRNode, OutputSpec, torch.device, None, str],
) -> Optional[str]: ) -> Optional[str]:
if isinstance(x, str) or x is None: if isinstance(x, str) or x is None:
return x return x
@ -698,8 +698,7 @@ class IRNode:
if TYPE_CHECKING: if TYPE_CHECKING:
@property @property
def dtype(self) -> torch.dtype: def dtype(self) -> torch.dtype: ...
...
@ir_dataclass(frozen=False) @ir_dataclass(frozen=False)
@ -839,8 +838,9 @@ class Loops(IRNode):
@cache_on_self @cache_on_self
def inner_fn_opcount(self) -> OpCountResult: def inner_fn_opcount(self) -> OpCountResult:
opcounter = OpCounterCSE(V.MockHandler()) opcounter = OpCounterCSE(V.MockHandler())
with V.set_ops_handler(opcounter), patch.object( with (
FlexibleLayout, "allow_indexing", True V.set_ops_handler(opcounter),
patch.object(FlexibleLayout, "allow_indexing", True),
): ):
self.inner_fn(*self.inner_fn_args()) self.inner_fn(*self.inner_fn_args())
return opcounter.getvalue() return opcounter.getvalue()
@ -1364,9 +1364,9 @@ class Reduction(Loops):
# "all" is desugared to `!any(!val)` # "all" is desugared to `!any(!val)`
} }
assert ( assert reduction_type in rtypes_to_inits.keys(), (
reduction_type in rtypes_to_inits.keys() f"{reduction_type} not supported for zero-dimension tensors!"
), f"{reduction_type} not supported for zero-dimension tensors!" )
def const_fn(index: int) -> OpsValue: def const_fn(index: int) -> OpsValue:
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype) return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
@ -1575,9 +1575,9 @@ class Reduction(Loops):
new_ranges: Sequence[Integer], new_ranges: Sequence[Integer],
new_reduction_ranges: Sequence[Integer], new_reduction_ranges: Sequence[Integer],
) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]: ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
assert all( assert all(r == 1 for r in original_ranges), (
r == 1 for r in original_ranges f"Only enabled for numel_hint == 1, found {original_ranges=}"
), f"Only enabled for numel_hint == 1, found {original_ranges=}" )
reindex = View.dynamic_reshape_indexer( reindex = View.dynamic_reshape_indexer(
original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges) original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
) )
@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction):
if reduction_numel == 1: if reduction_numel == 1:
def copy( def copy(
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue] loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
) -> TensorBox: ) -> TensorBox:
def inner_fn(idx: Sequence[Expr]) -> OpsValue: def inner_fn(idx: Sequence[Expr]) -> OpsValue:
reduction_index = [sympy.S.Zero for _ in reduction_ranges] reduction_index = [sympy.S.Zero for _ in reduction_ranges]
@ -2571,9 +2571,9 @@ class ExpandView(BaseView):
# NB: new_size[i] == old_size[i] is expected to already be # NB: new_size[i] == old_size[i] is expected to already be
# guarded because the meta formula was expected to have taught # guarded because the meta formula was expected to have taught
# us this equality. # us this equality.
assert ( assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0 "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}" )
return new_size return new_size
@classmethod @classmethod
@ -3382,9 +3382,9 @@ class Layout(OutputSpec):
) )
def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]: def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
assert ( assert FlexibleLayout.allow_indexing, (
FlexibleLayout.allow_indexing f"convert {type(self).__name__} to FixedLayout first"
), f"convert {type(self).__name__} to FixedLayout first" )
return self.as_fixed().make_indexer() return self.as_fixed().make_indexer()
def __eq__(self, other) -> bool: # type: ignore[no-untyped-def] def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout):
return target return target
result = unwrap_views(self.target) result = unwrap_views(self.target)
assert isinstance( assert isinstance(result, Buffer), (
result, Buffer "MutationLayoutSHOULDREMOVE must refer to a buffer"
), "MutationLayoutSHOULDREMOVE must refer to a buffer" )
return result return result
def real_layout(self): # type: ignore[no-untyped-def] def real_layout(self): # type: ignore[no-untyped-def]
@ -3803,7 +3803,9 @@ class Buffer(IRNode):
assert isinstance(self.layout, FlexibleLayout) assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_same_order(stride) self.layout = self.layout.as_same_order(stride)
def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None: # type: ignore[no-untyped-def] def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def]
self, exact_strides, allow_padding=False
) -> None:
assert isinstance(self.layout, FlexibleLayout) assert isinstance(self.layout, FlexibleLayout)
self.layout = self.layout.as_exact_strides( self.layout = self.layout.as_exact_strides(
exact_strides, allow_padding=allow_padding exact_strides, allow_padding=allow_padding
@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer):
torch.ops.higher_order.flex_attention_backward, torch.ops.higher_order.flex_attention_backward,
) )
current_node = V.graph.current_node.target current_node = V.graph.current_node.target
assert ( assert current_node in allowed_set, (
current_node in allowed_set f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}" )
device = self.inputs[0].get_device() device = self.inputs[0].get_device()
self.outputs += [ self.outputs += [
MutationOutput(NoneLayout(device=device), buf, self) MutationOutput(NoneLayout(device=device), buf, self)
@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel):
x_unwrap_view.freeze_layout() x_unwrap_view.freeze_layout()
index_args, var_ranges = dependencies.index_vars_squeeze( index_args, var_ranges = dependencies.index_vars_squeeze(
x.get_size(), prefix="r" # type: ignore[arg-type] x.get_size(),
prefix="r", # type: ignore[arg-type]
) )
range_vars = index_args[0] range_vars = index_args[0]
index = x.make_indexer()(range_vars) index = x.make_indexer()(range_vars)
@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel):
# pass in a list of const arg names for arg_properties lookup. # pass in a list of const arg names for arg_properties lookup.
name_to_arg_properties = None name_to_arg_properties = None
if names and self.arg_properties: if names and self.arg_properties:
assert len(self.constant_args) == len( assert len(self.constant_args) == len(names), (
names "names passed to codegen_const_args does not match self.constant_args"
), "names passed to codegen_const_args does not match self.constant_args" )
name_to_arg_properties = { name_to_arg_properties = {
arg.get("name"): arg for arg in self.arg_properties arg.get("name"): arg for arg in self.arg_properties
} }
@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel):
args = [] args = []
for i, x in enumerate(inputs): for i, x in enumerate(inputs):
if V.graph.cpp_wrapper: if V.graph.cpp_wrapper:
assert self.arg_properties and i < len( assert self.arg_properties and i < len(self.arg_properties), (
self.arg_properties "Invalid access to ExternKernel.arg_properties"
), "Invalid access to ExternKernel.arg_properties" )
type_ = self.arg_properties[i].get("type") type_ = self.arg_properties[i].get("type")
args.append(V.graph.wrapper_code.val_to_arg_str(x, type_)) args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
else: else:
@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel):
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]: def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
return OrderedSet() return OrderedSet()
def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None: # type: ignore[no-untyped-def] def __init__( # type: ignore[no-untyped-def]
self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args
) -> None:
inputs = [] inputs = []
kwargs = {} kwargs = {}
constant_args = [] constant_args = []
@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc):
elif isinstance(output, torch.SymInt): elif isinstance(output, torch.SymInt):
return output.node.expr return output.node.expr
else: else:
assert ( assert output is None, (
output is None f"FallbackKernel output type {type(output)} is not supported"
), f"FallbackKernel output type {type(output)} is not supported" )
return None return None
outputs = generate_output(example_output, []) outputs = generate_output(example_output, [])
@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel):
) )
self.codegen_size_asserts(wrapper) self.codegen_size_asserts(wrapper)
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def] def __init__( # type: ignore[no-untyped-def]
self,
layout: OutputSpec,
input,
indices: list[tuple[Any, ...]],
) -> None:
super().__init__(None, layout, [input], ()) super().__init__(None, layout, [input], ())
self.name = V.graph.register_buffer(self) self.name = V.graph.register_buffer(self)
V.graph.register_operation(self) V.graph.register_operation(self)
@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel):
assert p.get_dtype() == torch.bool, p assert p.get_dtype() == torch.bool, p
assert len(p.get_size()) == 0, p assert len(p.get_size()) == 0, p
assert ( assert len(all_inputs) > 0, (
len(all_inputs) > 0 "torch.while_loop is assumed to have at least one operand."
), "torch.while_loop is assumed to have at least one operand." )
device = all_inputs[0].get_device() device = all_inputs[0].get_device()
@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel):
# This is identical to FallbackKernel.set_cpp_kernel(), minus the # This is identical to FallbackKernel.set_cpp_kernel(), minus the
# part that checks against input aliasing and mutation. # part that checks against input aliasing and mutation.
def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None: def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
assert ( assert type(self.op_overload) is torch._ops.OpOverload, (
type(self.op_overload) is torch._ops.OpOverload "Setting cpp kernel needs a valid op_overload"
), "Setting cpp kernel needs a valid op_overload" )
kernel = self.op_overload kernel = self.op_overload
self.cpp_kernel_name = kernel._schema.name self.cpp_kernel_name = kernel._schema.name
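Several hunks above also collapse bare @overload bodies so the literal ... sits on the signature line. A hypothetical, self-contained example of the resulting style (names are illustrative only):

from typing import Union, overload

@overload
def as_number(x: int) -> int: ...
@overload
def as_number(x: str) -> float: ...

def as_number(x: Union[int, str]) -> Union[int, float]:
    return x if isinstance(x, int) else float(x)

print(as_number("2.5"))  # 2.5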

View File

@ -1,5 +1,5 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel""" """Triton Implementation of the flex_attention Kernel"""
import copy import copy
import logging import logging
@ -60,9 +60,9 @@ def construct_strides(
) -> Sequence[int]: ) -> Sequence[int]:
"""From a list of sizes and a fill order, construct the strides of the permuted tensor.""" """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
# Initialize strides # Initialize strides
assert len(sizes) == len( assert len(sizes) == len(fill_order), (
fill_order "Length of sizes must match the length of the fill order"
), "Length of sizes must match the length of the fill order" )
strides = [0] * len(sizes) strides = [0] * len(sizes)
# Start with stride 1 for the innermost dimension # Start with stride 1 for the innermost dimension
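The construct_strides docstring above describes building strides from a list of sizes plus a fill order, starting with stride 1 for the innermost dimension. A self-contained sketch of that idea (not the kernel's actual implementation):

def construct_strides_sketch(sizes, fill_order):
    # fill_order lists dimensions from innermost to outermost.
    assert len(sizes) == len(fill_order), (
        "Length of sizes must match the length of the fill order"
    )
    strides = [0] * len(sizes)
    current = 1
    for dim in fill_order:
        strides[dim] = current
        current *= sizes[dim]
    return strides

print(construct_strides_sketch([2, 4, 8], fill_order=[2, 1, 0]))  # [32, 8, 1]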
@ -1151,10 +1151,14 @@ def lower_cpu(
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE) SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE)) sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask." ), (
"Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
)
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE)) sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask." ), (
"KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
)
CppFlexAttentionTemplate.add_choices( CppFlexAttentionTemplate.add_choices(
choices=_choices, choices=_choices,
input_nodes=input_nodes, input_nodes=input_nodes,
@ -1364,15 +1368,15 @@ def flex_attention(
Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" )
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
sympy.Gt(seq_len_q, 0) "Query length must be greater than 0"
), "Query length must be greater than 0" )
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
sympy.Gt(seq_len_kv, 0) "Key length must be greater than 0"
), "Key length must be greater than 0" )
B = Bq B = Bq
@ -2291,9 +2295,9 @@ def process_joint_outputs(
JointOutputResult containing processed buffers and gradients JointOutputResult containing processed buffers and gradients
""" """
assert isinstance(all_joint_outputs, list) assert isinstance(all_joint_outputs, list)
assert ( assert all_joint_outputs[0] is not None, (
all_joint_outputs[0] is not None "joint_subgraph_buffer is None - this is a bug!"
), "joint_subgraph_buffer is None - this is a bug!" )
joint_buffer = all_joint_outputs[0] joint_buffer = all_joint_outputs[0]
other_grads = all_joint_outputs[num_placeholders - 1 :] other_grads = all_joint_outputs[num_placeholders - 1 :]
@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" )
kernel_options = dict(kernel_options) kernel_options = dict(kernel_options)
# Mark symbols in custom kernel options as static shapes and add guards. # Mark symbols in custom kernel options as static shapes and add guards.
@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs):
grad_key = broadcasted_grad_key grad_key = broadcasted_grad_key
grad_value = broadcasted_grad_value grad_value = broadcasted_grad_value
else: else:
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1) f"Bq and Bkv must broadcastable. "
), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" # noqa: B950 f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
)
grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True) grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True) grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs # mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)""" """Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
from typing import Any from typing import Any
import sympy import sympy
@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
Bq, Hq, seq_len_q, qk_head_dim = query.get_size() Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size() Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
assert V.graph.sizevars.evaluate_expr( assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1) f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}" )
B = Bq B = Bq
kernel_options = dict(kernel_options) kernel_options = dict(kernel_options)
@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
max( max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] seq_len_q,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
) )
* gqa_shared_heads * gqa_shared_heads
), ),

View File

@ -65,7 +65,8 @@ def filtered_configs(
m = max( m = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] m,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
) )
), ),
min_block_size, min_block_size,
@ -73,7 +74,8 @@ def filtered_configs(
n = max( n = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] n,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
) )
), ),
min_block_size, min_block_size,
@ -81,7 +83,8 @@ def filtered_configs(
k = max( k = max(
next_power_of_2( next_power_of_2(
V.graph.sizevars.size_hint( V.graph.sizevars.size_hint(
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type] k,
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
) )
), ),
min_block_size_k, min_block_size_k,
@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
""" """
even_k_symbolic = ( even_k_symbolic = (
# it isn't worth guarding on this # it isn't worth guarding on this
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
== config.kwargs["BLOCK_K"]
) )
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and ( allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
not inductor_config.force_same_precision not inductor_config.force_same_precision

View File

@ -194,11 +194,12 @@ class LoopBody:
# There is indeed an issue due to symbol name conflicting. # There is indeed an issue due to symbol name conflicting.
# y0 maybe reused for the y dimension later. # y0 maybe reused for the y dimension later.
( (
iter_vars, (
reduce_vars, iter_vars,
), var_ranges = dependencies.index_vars_no_squeeze( reduce_vars,
iter_sizes, reduce_sizes, prefix="t" ),
) var_ranges,
) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t")
new_body = LoopBody( new_body = LoopBody(
old_body, old_body,
[iter_reindex(iter_vars), reduce_reindex(reduce_vars)], [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
@ -234,7 +235,8 @@ class LoopBody:
new_sizes = (new_iter_size, reduce_size) new_sizes = (new_iter_size, reduce_size)
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze( (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="t" # type: ignore[arg-type] *new_sizes,
prefix="t", # type: ignore[arg-type]
) )
inverse_order = {b: a for a, b in enumerate(new_order)} inverse_order = {b: a for a, b in enumerate(new_order)}
@ -254,7 +256,8 @@ class LoopBody:
# use the original symbol prefix so we can do multiple round of reordering # use the original symbol prefix so we can do multiple round of reordering
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze( (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
*new_sizes, prefix="p" # type: ignore[arg-type] *new_sizes,
prefix="p", # type: ignore[arg-type]
) )
new_body = LoopBody( new_body = LoopBody(
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2 loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
@ -385,9 +388,9 @@ class LoopBody:
def indexing_from_args(self, indices): def indexing_from_args(self, indices):
index = [*itertools.chain.from_iterable(indices)] index = [*itertools.chain.from_iterable(indices)]
assert len(index) == len(self.var_ranges), (index, self.var_ranges) assert len(index) == len(self.var_ranges), (index, self.var_ranges)
assert all( assert all(v not in self.var_ranges for v in index), (
v not in self.var_ranges for v in index f"{self.var_ranges=}, {indices=}"
), f"{self.var_ranges=}, {indices=}" )
replacements = dict(zip(self.var_ranges.keys(), index)) replacements = dict(zip(self.var_ranges.keys(), index))
return { return {
name: sympy_subs(expr, replacements) name: sympy_subs(expr, replacements)

View File

@ -346,7 +346,8 @@ def transform_args(
# only consider tensor kwargs for promotion, for now # only consider tensor kwargs for promotion, for now
promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype")) promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
dtype = get_promoted_dtype( dtype = get_promoted_dtype(
*promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type] *promoting_args,
type_promotion_kind=type_promotion_kind, # type: ignore[arg-type]
) )
device = ( device = (
@ -448,9 +449,9 @@ def _register_lowering(
(fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
): ):
# explicitly assert for "out=" ops for better error messages # explicitly assert for "out=" ops for better error messages
assert not any( assert not any(x == "out" for x in kwargs.keys()), (
x == "out" for x in kwargs.keys() "out= ops aren't yet supported"
), "out= ops aren't yet supported" )
args, kwargs = transform_args( args, kwargs = transform_args(
args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b):
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None): def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
assert ( assert override_return_dtype is None or type_promotion_kind is None, (
override_return_dtype is None or type_promotion_kind is None "only one of override_return_dtype or type_promotion_kind may be given"
), "only one of override_return_dtype or type_promotion_kind may be given" )
if override_return_dtype is None and type_promotion_kind is None: if override_return_dtype is None and type_promotion_kind is None:
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
if isinstance(input, (list, tuple)): if isinstance(input, (list, tuple)):
a_list_input = input a_list_input = input
break break
assert ( assert a_list_input is not None, (
a_list_input is not None "at least one input must be a list to a foreach op"
), "at least one input must be a list to a foreach op" )
# broadcast scalar inputs to match length of list inputs # broadcast scalar inputs to match length of list inputs
broadcast_inputs = [] broadcast_inputs = []
@ -1321,12 +1322,12 @@ def quantized_decomposed_quantize_per_channel(
if input.get_dtype() == torch.bfloat16: if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32) input = to_dtype(input, torch.float32)
assert ( assert input.get_dtype() == torch.float32, (
input.get_dtype() == torch.float32 f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" )
assert axis < len( assert axis < len(input.get_size()), (
input.get_size() f"Expecting axis to be < {len(input.get_size())}"
), f"Expecting axis to be < {len(input.get_size())}" )
input_loader = input.make_loader() input_loader = input.make_loader()
scales_loader = scales.make_loader() scales_loader = scales.make_loader()
@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel(
) -> TensorBox: ) -> TensorBox:
assert len(scales.get_size()) == 1, "expect scales 1 dim" assert len(scales.get_size()) == 1, "expect scales 1 dim"
assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim" assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
assert ( assert input.get_dtype() == dtype, (
input.get_dtype() == dtype f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" )
assert axis < len( assert axis < len(input.get_size()), (
input.get_size() f"Expecting axis to be < {len(input.get_size())}"
), f"Expecting axis to be < {len(input.get_size())}" )
if out_dtype is None: if out_dtype is None:
out_dtype = torch.float32 out_dtype = torch.float32
@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default(
) -> TensorBox: ) -> TensorBox:
if input.get_dtype() == torch.bfloat16: if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32) input = to_dtype(input, torch.float32)
assert ( assert input.get_dtype() == torch.float32, (
input.get_dtype() == torch.float32 f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" )
input_loader = input.make_loader() input_loader = input.make_loader()
@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default(
*, *,
out_dtype: Optional[torch.dtype] = None, out_dtype: Optional[torch.dtype] = None,
) -> TensorBox: ) -> TensorBox:
assert ( assert input.get_dtype() == dtype, (
input.get_dtype() == dtype f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" )
if out_dtype is None: if out_dtype is None:
out_dtype = torch.float32 out_dtype = torch.float32
@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor(
) -> TensorBox: ) -> TensorBox:
if input.get_dtype() == torch.bfloat16: if input.get_dtype() == torch.bfloat16:
input = to_dtype(input, torch.float32) input = to_dtype(input, torch.float32)
assert ( assert input.get_dtype() == torch.float32, (
input.get_dtype() == torch.float32 f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}" )
assert len(scale.get_size()) == 0 or ( assert len(scale.get_size()) == 0 or (
len(scale.get_size()) == 1 and scale.get_size()[0] == 1 len(scale.get_size()) == 1 and scale.get_size()[0] == 1
), "expect scale as scalar tensor" ), "expect scale as scalar tensor"
@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor(
assert len(zero_point.get_size()) == 0 or ( assert len(zero_point.get_size()) == 0 or (
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1 len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
), "expect zero_point as scalar tensor" ), "expect zero_point as scalar tensor"
assert ( assert input.get_dtype() == dtype, (
input.get_dtype() == dtype f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}" )
if out_dtype is None: if out_dtype is None:
out_dtype = torch.float32 out_dtype = torch.float32
@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False): def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
assert ( assert op not in decompositions or override_decomp, (
op not in decompositions or override_decomp f"both a fallback and a decomp for same op: {op}"
), f"both a fallback and a decomp for same op: {op}" )
if ( if (
warn warn
and bool(os.getenv("CI")) and bool(os.getenv("CI"))
@ -2086,9 +2087,9 @@ def native_dropout(x, p, train):
@register_lowering(aten.bernoulli_, type_promotion_kind=None) @register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args): def bernoulli_(x, *args):
assert config.fallback_random or x.get_device() == torch.device( assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"cpu" "this should be handled in decomps unless config.fallback_random or the device is CPU"
), "this should be handled in decomps unless config.fallback_random or the device is CPU" )
x.realize() x.realize()
op_overload = ( op_overload = (
aten.bernoulli_.float aten.bernoulli_.float
@ -2101,9 +2102,9 @@ def bernoulli_(x, *args):
@register_lowering(aten.bernoulli.p, type_promotion_kind=None) @register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args): def bernoulli_p(x, *args):
assert config.fallback_random or x.get_device() == torch.device( assert config.fallback_random or x.get_device() == torch.device("cpu"), (
"cpu" "this should be handled in decomps unless config.fallback_random or the device is CPU"
), "this should be handled in decomps unless config.fallback_random or the device is CPU" )
return bernoulli_(clone(x), *args) return bernoulli_(clone(x), *args)
@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device):
i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8) i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
for i in indices for i in indices
if i is not None if i is not None
), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}" ), (
f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
)
if any( if any(
i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
): ):
@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
) )
result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs) result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
if isinstance( if isinstance(
result.data.data, Reduction # type: ignore[attr-defined] result.data.data, # type: ignore[attr-defined]
Reduction,
): # Only realize if reduction isn't unrolled ): # Only realize if reduction isn't unrolled
result.realize() result.realize()
return result return result
@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
return None return None
handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device()) handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
with V.set_ops_handler(handler), patch.object( with (
ir.FlexibleLayout, "allow_indexing", True V.set_ops_handler(handler),
patch.object(ir.FlexibleLayout, "allow_indexing", True),
): ):
out = x.inner_fn(*x.inner_fn_args()) out = x.inner_fn(*x.inner_fn_args())
@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload):
A context manager to force fallback an op. Used in unit test A context manager to force fallback an op. Used in unit test
for FallbackKernel. for FallbackKernel.
""" """
assert isinstance( assert isinstance(op, torch._ops.OpOverload), (
op, torch._ops.OpOverload "Only OpOverload to make the clean up easier"
), "Only OpOverload to make the clean up easier" )
old_handler = lowerings.get(op) old_handler = lowerings.get(op)
try: try:
register_lowering(op)(fallback_handler(op)) register_lowering(op)(fallback_handler(op))
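Several hunks here and in later files replace Black's wrapping of chained context managers with ruff format's parenthesized `with` statement, one manager per line. A small self-contained sketch of the same rewrite, assuming Python 3.10+ for the parenthesized form (the `Settings` class is illustrative):

    from unittest.mock import patch

    class Settings:
        debug = False
        level = 0

    # Black (old style): the chain stays on the `with` line and only the
    # overflowing call arguments get wrapped.
    with patch.object(Settings, "debug", True), patch.object(
        Settings, "level", 3
    ):
        assert Settings.debug and Settings.level == 3

    # ruff format (new style): all managers sit inside one parenthesized group.
    with (
        patch.object(Settings, "debug", True),
        patch.object(Settings, "level", 3),
    ):
        assert Settings.debug and Settings.level == 3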

View File

@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer:
class MemoryPlanningInfoForNode: class MemoryPlanningInfoForNode:
index: int = 0 index: int = 0
size: int = 0 size: int = 0
pred_buffers: OrderedSet[ pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
Union[SchedulerBuffer, FreeableInputBuffer] dataclasses.field(default_factory=OrderedSet)
] = dataclasses.field(default_factory=OrderedSet) )
pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field( pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
default_factory=OrderedSet default_factory=OrderedSet
) )
@ -87,9 +87,9 @@ def get_freeable_input_buf(
# get freeable input buffers' successor nodes and their sizes # get freeable input buffers' successor nodes and their sizes
# note that different deps can have the same name, so we use name as keys # note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[ dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
str, OrderedSet[BaseSchedulerNode] collections.defaultdict(OrderedSet)
] = collections.defaultdict(OrderedSet) )
dep_name_to_size: dict[str, int] = dict() dep_name_to_size: dict[str, int] = dict()
for node in nodes: for node in nodes:
for dep in node.read_writes.reads: for dep in node.read_writes.reads:
@ -112,7 +112,7 @@ def get_freeable_input_buf(
def compute_size_for_scheduler_buffer( def compute_size_for_scheduler_buffer(
name_to_buf: dict[str, SchedulerBuffer] name_to_buf: dict[str, SchedulerBuffer],
) -> dict[str, tuple[int, int]]: ) -> dict[str, tuple[int, int]]:
""" """
Compute the size of each scheduler buffer, including (1) memory allocated when Compute the size of each scheduler buffer, including (1) memory allocated when
@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers(
# get buffer's successor nodes # get buffer's successor nodes
# note that different deps can have the same name, so we use name as keys # note that different deps can have the same name, so we use name as keys
dep_name_to_succ_nodes: dict[ dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
str, OrderedSet[BaseSchedulerNode] collections.defaultdict(OrderedSet)
] = collections.defaultdict(OrderedSet) )
for node in nodes: for node in nodes:
for dep in node.unmet_dependencies: for dep in node.unmet_dependencies:
dep_name_to_succ_nodes[dep.name].add(node) dep_name_to_succ_nodes[dep.name].add(node)
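The memory-planning hunks show how the two formatters split an annotated assignment that no longer fits on one line: Black broke up the annotation's subscript, while ruff format keeps the annotation whole and parenthesizes the value. A sketch with made-up names mirroring the hunks above:

    import collections

    # Black (old style): the subscript of the annotation was split.
    #     dep_name_to_succ_node_names: dict[
    #         str, collections.OrderedDict[str, None]
    #     ] = collections.defaultdict(collections.OrderedDict)

    # ruff format (new style): the annotation stays intact; the value is wrapped.
    dep_name_to_succ_node_names: dict[str, collections.OrderedDict[str, None]] = (
        collections.defaultdict(collections.OrderedDict)
    )
    dep_name_to_succ_node_names["buf0"]["node1"] = None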

View File

@ -138,12 +138,12 @@ class MetricTable:
return return
row_dict = row_fn() row_dict = row_fn()
assert len(self.column_names) == len( assert len(self.column_names) == len(row_dict), (
row_dict f"{len(self.column_names)} v.s. {len(row_dict)}"
), f"{len(self.column_names)} v.s. {len(row_dict)}" )
assert OrderedSet(self.column_names) == OrderedSet( assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
row_dict.keys() f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}" )
bn = get_benchmark_name() bn = get_benchmark_name()
# assert bn is not None # assert bn is not None
@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
name = name.strip() name = name.strip()
if not name: if not name:
continue continue
assert ( assert name in REGISTERED_METRIC_TABLES, (
name in REGISTERED_METRIC_TABLES f"Metric table name {name} is not registered"
), f"Metric table name {name} is not registered" )
enabled.add(name) enabled.add(name)
return enabled return enabled

View File

@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
unary_algorithm, unary_algorithm,
] ]
assert ( assert binary_attr == "sum", (
binary_attr == "sum" "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E." )
V.graph.mark_buffer_mutated(qaccum.get_name()) V.graph.mark_buffer_mutated(qaccum.get_name())
packed = QConvPointWiseBinaryPT2E( packed = QConvPointWiseBinaryPT2E(

View File

@ -575,9 +575,9 @@ def register_onednn_fusion_ops():
algorithm, algorithm,
layout=None, layout=None,
): ):
assert ( assert packed_weight.get_dtype() is torch.int8, (
packed_weight.get_dtype() is torch.int8 "Only int8 weights are supported by oneDNN qlinear."
), "Only int8 weights are supported by oneDNN qlinear." )
x_size = x.get_size() x_size = x.get_size()
if len(x_size) > 2: if len(x_size) > 2:
# GEMM template needs 2D input, normalize input shape here # GEMM template needs 2D input, normalize input shape here
@ -928,9 +928,9 @@ def register_onednn_fusion_ops():
# we will do accum dtype conversion here. # we will do accum dtype conversion here.
x2 = to_dtype(x2, output_dtype) x2 = to_dtype(x2, output_dtype)
else: else:
assert ( assert x2.get_dtype() == output_dtype, (
x2.get_dtype() == output_dtype "dtype of accum for qlinear post op sum should be the same as output"
), "dtype of accum for qlinear post op sum should be the same as output" )
x2_dtype = x2.get_dtype() x2_dtype = x2.get_dtype()
bias_dtype = bias.get_dtype() if bias is not None else None bias_dtype = bias.get_dtype() if bias is not None else None
choices: list[ChoiceCaller] = [] choices: list[ChoiceCaller] = []

View File

@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]):
assert self_arg == "self" assert self_arg == "self"
code.write( code.write(
f""" f"""
def {target}(self, {', '.join(args)}): def {target}(self, {", ".join(args)}):
return self._default({target!r}, ({', '.join(args)}, ), {{}}) return self._default({target!r}, ({", ".join(args)}, ), {{}})
""".strip() """.strip()
) )
code.write("\n\n") code.write("\n\n")
@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler):
) )
formatter._output.writeline(f"{lhs} = {name}") formatter._output.writeline(f"{lhs} = {name}")
with V.set_ops_handler(formatter), patch.object( with (
FlexibleLayout, "allow_indexing", True V.set_ops_handler(formatter),
patch.object(FlexibleLayout, "allow_indexing", True),
): ):
result = ir_fn(*args) result = ir_fn(*args)
return formatter.getvalue(result) return formatter.getvalue(result)
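The other change in this file is the quote style of string literals nested inside f-strings: they now use double quotes to match the surrounding code, which is valid here because the outer f-string is triple-quoted. A tiny sketch with placeholder values:

    target = "load"
    args = ["a", "b"]

    # Old: {', '.join(args)} used single quotes inside the f-string.
    # New: {", ".join(args)} matches the project-wide double-quote style.
    src = f"""
    def {target}(self, {", ".join(args)}):
        return self._default({target!r}, ({", ".join(args)},), {{}})
    """.strip()
    print(src)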

View File

@ -188,7 +188,9 @@ def package_aoti(
) or ( ) or (
isinstance(archive_file, (str, os.PathLike)) isinstance(archive_file, (str, os.PathLike))
and os.fspath(archive_file).endswith(".pt2") and os.fspath(archive_file).endswith(".pt2")
), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}" ), (
f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
)
# Save using the PT2 packaging format # Save using the PT2 packaging format
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a) # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
@ -285,9 +287,9 @@ class AOTICompiledModel:
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg] def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
assert ( assert (
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable() isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
) or ( ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2") f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}" )
if isinstance(path, (io.IOBase, IO)): if isinstance(path, (io.IOBase, IO)):
with tempfile.NamedTemporaryFile(suffix=".pt2") as f: with tempfile.NamedTemporaryFile(suffix=".pt2") as f:

View File

@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node]
class SearchFn(Protocol): class SearchFn(Protocol):
__name__: str __name__: str
def __call__(self, *args: Any, **kwargs: Any) -> Any: def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
...
class ReplaceFn(Protocol): class ReplaceFn(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any: def __call__(self, *args: Any, **kwargs: Any) -> Any: ...
...
class TraceFn(Protocol): class TraceFn(Protocol):
def __call__( def __call__(
self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
) -> torch.fx.GraphModule: ) -> torch.fx.GraphModule: ...
...
T = TypeVar("T") T = TypeVar("T")
@ -365,8 +362,7 @@ class PatternExpr(ABC):
""" """
@abstractmethod @abstractmethod
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ...
...
def match(self, node: torch.fx.Node) -> MatchResult: def match(self, node: torch.fx.Node) -> MatchResult:
try: try:
@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr):
@property @property
@abstractmethod @abstractmethod
def op(self) -> str: def op(self) -> str: ...
...
def fns_repr(self) -> str: def fns_repr(self) -> str:
first_repr = self.fns[0] first_repr = self.fns[0]
@ -997,8 +992,9 @@ class PatternPrettyPrinter:
class _PassDictsType(Protocol): class _PassDictsType(Protocol):
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]: def __getitem__(
... self, k: tuple[str, torch.fx.node.Target]
) -> list[PatternEntry]: ...
@dataclasses.dataclass @dataclasses.dataclass
@ -1925,7 +1921,10 @@ def fx_to_pattern(
get_attr = _not_implemented get_attr = _not_implemented
def placeholder( def placeholder(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> Union[ExclusiveKeywordArg, KeywordArg]: ) -> Union[ExclusiveKeywordArg, KeywordArg]:
n = next(argnum) n = next(argnum)
if n < len(argnames): if n < len(argnames):
@ -1942,7 +1941,10 @@ def fx_to_pattern(
return KeywordArg(name) return KeywordArg(name)
def call_function( def call_function(
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override] self,
target: str, # type: ignore[override]
args: Sequence[Any],
kwargs: Mapping[str, Any],
) -> PatternExpr: ) -> PatternExpr:
process_arg_fn = process_arg process_arg_fn = process_arg
# Indexing is critical for matching getitem nodes, so we can't ignore int args here # Indexing is critical for matching getitem nodes, so we can't ignore int args here
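A recurring change in this file collapses stub bodies: Black kept the `...` on its own line under the signature, whereas ruff format joins it to the signature line. A self-contained sketch of the same shape (the protocol name is hypothetical):

    from typing import Any, Protocol

    class SearchFnSketch(Protocol):
        __name__: str

        # Black (old style):
        #     def __call__(self, *args: Any, **kwargs: Any) -> Any:
        #         ...
        # ruff format (new style): the ellipsis body joins the signature line.
        def __call__(self, *args: Any, **kwargs: Any) -> Any: ...

    def apply(fn: SearchFnSketch) -> Any:
        return fn(1, key=2)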

View File

@ -24,7 +24,7 @@ T = TypeVar("T")
def time_and_count( def time_and_count(
fn: Callable[Concatenate[Any, P], T] fn: Callable[Concatenate[Any, P], T],
) -> Callable[Concatenate[Any, P], T]: ) -> Callable[Concatenate[Any, P], T]:
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
counters. It is expected that `fn` is a method of `Benchmarker` or one of its counters. It is expected that `fn` is a method of `Benchmarker` or one of its

View File

@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None:
# right now, if a pre-hook is attached to the config, it will not be saved; # right now, if a pre-hook is attached to the config, it will not be saved;
# and then it won't be used when the config is loaded from cache. # and then it won't be used when the config is loaded from cache.
# So we assert - if we do get a pre_hook, it might get ignored after caching. # So we assert - if we do get a pre_hook, it might get ignored after caching.
assert ( assert getattr(cfg, "pre_hook", None) is None, (
getattr(cfg, "pre_hook", None) is None "triton configs with pre_hooks not supported"
), "triton configs with pre_hooks not supported" )
def create_bandwidth_info_str( def create_bandwidth_info_str(

View File

@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface):
self.launchers = [] self.launchers = []
def __getstate__(self) -> dict[str, Any]: def __getstate__(self) -> dict[str, Any]:
assert ( assert not self.launchers, (
not self.launchers "pickle should not be called after make_launchers()"
), "pickle should not be called after make_launchers()" )
return { return {
**self.__dict__, **self.__dict__,
"lock": None, "lock": None,
@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance( assert isinstance(
arg, arg,
torch.Tensor, torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names" ), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_() arg.zero_()
for name, arg in kwargs.items(): for name, arg in kwargs.items():
@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface):
assert isinstance( assert isinstance(
arg, arg,
torch.Tensor, torch.Tensor,
), "self.reset_to_zero_arg_names should only contain valid argument names" ), (
"self.reset_to_zero_arg_names should only contain valid argument names"
)
arg.zero_() arg.zero_()
def maybe_clone_args( def maybe_clone_args(
@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface):
assert not ( assert not (
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
and "R0_BLOCK" in launcher.config.kwargs and "R0_BLOCK" in launcher.config.kwargs
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK" ), (
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
)
start_time = time.time_ns() start_time = time.time_ns()
best_config = self.coordesc_tuner.autotune( best_config = self.coordesc_tuner.autotune(
benchmark_one_config, launcher.config, None benchmark_one_config, launcher.config, None
@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface):
) )
return config2launcher.get(best_config) return config2launcher.get(best_config)
def run( def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
self, *args, grid, stream, benchmark_run=False, **kwargs
): # type:ignore[override]
if self.triton_interpret: if self.triton_interpret:
return self.fn[grid]( return self.fn[grid](
*args, *args,
@ -1192,12 +1196,12 @@ class TritonCompileResult:
exec( exec(
f""" f"""
def launcher({', '.join(def_args)}, grid, stream): def launcher({", ".join(def_args)}, grid, stream):
if callable(grid): if callable(grid):
grid_0, grid_1, grid_2 = grid(grid_meta) grid_0, grid_1, grid_2 = grid(grid_meta)
else: else:
grid_0, grid_1, grid_2 = grid grid_0, grid_1, grid_2 = grid
runner({', '.join(runner_args)}) runner({", ".join(runner_args)})
return bin return bin
""".lstrip(), """.lstrip(),
scope, scope,
@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]):
if block_suffix in var: if block_suffix in var:
prefix = var.removesuffix(block_suffix) prefix = var.removesuffix(block_suffix)
max_block = TRITON_MAX_BLOCK[prefix] max_block = TRITON_MAX_BLOCK[prefix]
assert ( assert val <= max_block, (
val <= max_block f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}." )
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False): def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, in
prefix = f"r{idx}_" prefix = f"r{idx}_"
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()]) max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
dim = min(max_size, remaining) dim = min(max_size, remaining)
assert ( assert remaining % dim == 0, (
remaining % dim == 0 f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'" )
rnumels[prefix] = dim rnumels[prefix] = dim
remaining //= dim remaining //= dim
# Sanity check the results. # Sanity check the results.
final_numel = conditional_product(*rnumels.values()) final_numel = conditional_product(*rnumels.values())
assert ( assert r == final_numel, (
r == final_numel f"Expected ND reduction size ({rnumels}) to have {r} elements."
), f"Expected ND reduction size ({rnumels}) to have {r} elements." )
assert all( assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels f"rnumels exceed size_hints. {rnumels} > {size_hints}"
), f"rnumels exceed size_hints. {rnumels} > {size_hints}" )
return rnumels return rnumels
@ -1967,9 +1971,9 @@ def cooperative_reduction(
size_hints["x"] = 1 size_hints["x"] = 1
# Cooperative reductions currently only support a single reduction dimension. # Cooperative reductions currently only support a single reduction dimension.
assert ( assert len(size_hints) == 2, (
len(size_hints) == 2 "Cooperative reductions don't support tiling reduction dims"
), "Cooperative reductions don't support tiling reduction dims" )
xnumel, rnumel = size_hints["x"], size_hints["r0_"] xnumel, rnumel = size_hints["x"], size_hints["r0_"]
# TODO(jansel): we should base target on the SM count of the local GPU # TODO(jansel): we should base target on the SM count of the local GPU
@ -2274,9 +2278,9 @@ def grid_combo_kernels(
assert min_blocks_d is not None assert min_blocks_d is not None
min_blocks = min_blocks_d min_blocks = min_blocks_d
else: else:
assert ( assert min_blocks_d is None or min_blocks == min_blocks_d, (
min_blocks_d is None or min_blocks == min_blocks_d f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}" )
else: else:
# sequential dispatch # sequential dispatch
seq_numels = list(numels) seq_numels = list(numels)

View File

@ -200,9 +200,9 @@ class BaseSchedulerNode:
def __init__(self, scheduler: Scheduler) -> None: def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[ self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
[BaseSchedulerNode], list[str] lambda *args, **kwargs: []
] = lambda *args, **kwargs: [] )
def _init_from_node(self, node: ir.Operation) -> None: def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node self.node: Optional[ir.Operation] = node
@ -232,7 +232,7 @@ class BaseSchedulerNode:
buf = IndentedBuffer() buf = IndentedBuffer()
buf.splice( buf.splice(
f"""\ f"""\
{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__}) {name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
{name}.writes = {pformat(self.read_writes.writes)} {name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)} {name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)} {name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
@ -525,9 +525,9 @@ class BaseSchedulerNode:
V.kernel.mutations.add(input_buf.get_name()) V.kernel.mutations.add(input_buf.get_name())
V.kernel.mutations.add(buf.get_name()) V.kernel.mutations.add(buf.get_name())
V.kernel.inplace_update_buffers[ V.kernel.inplace_update_buffers[buf.get_name()] = (
buf.get_name() input_buf.get_name()
] = input_buf.get_name() )
break break
def codegen_originating_info( def codegen_originating_info(
@ -693,7 +693,7 @@ class BaseSchedulerNode:
continue continue
def get_buf_bytes( def get_buf_bytes(
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]] buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
) -> int: ) -> int:
if not buf: if not buf:
return 0 return 0
@ -794,12 +794,11 @@ class BaseSchedulerNode:
# runtime for that today # runtime for that today
return 0 return 0
with FakeTensorMode() as fake_mode, FlopCounterMode( with (
display=False FakeTensorMode() as fake_mode,
) as flop_counter_mode, V.set_current_node( FlopCounterMode(display=False) as flop_counter_mode,
self.node.fx_node V.set_current_node(self.node.fx_node),
), V.set_fake_mode( V.set_fake_mode(fake_mode),
fake_mode
): ):
from .ir import ir_node_to_tensor from .ir import ir_node_to_tensor
@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode):
return self._sizes return self._sizes
def is_reduction(self) -> bool: def is_reduction(self) -> bool:
assert isinstance( assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) f"{type(self.node)=}"
), f"{type(self.node)=}" )
return bool(self.node.get_reduction_type()) return bool(self.node.get_reduction_type())
def is_split_scan(self) -> bool: def is_split_scan(self) -> bool:
assert isinstance( assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
self.node, (ir.ComputedBuffer, ir.TemplateBuffer) f"{type(self.node)=}"
), f"{type(self.node)=}" )
return isinstance(self.node, ir.ComputedBuffer) and isinstance( return isinstance(self.node, ir.ComputedBuffer) and isinstance(
self.node.data, ir.SplitScan self.node.data, ir.SplitScan
) )
@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode):
def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None: def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
var_ranges = self.ranges_from_index_vars(index_vars) var_ranges = self.ranges_from_index_vars(index_vars)
try: try:
with V.set_ops_handler( with (
SimplifyIndexing(V.get_ops_handler(), var_ranges) V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
), V.kernel.set_current_node(self): V.kernel.set_current_node(self),
):
self._body(*index_vars) self._body(*index_vars)
except Exception: except Exception:
log.fatal("Error in codegen for %s", self.node) log.fatal("Error in codegen for %s", self.node)
@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode):
def refresh_group_node_dependencies( def refresh_group_node_dependencies(
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode] group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
) -> None: ) -> None:
snodes = group_snode.snodes snodes = group_snode.snodes
group_snode.set_read_writes( group_snode.set_read_writes(
@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):
@staticmethod @staticmethod
def set_group_algorithm_for_combo_kernels( def set_group_algorithm_for_combo_kernels(
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]] custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
) -> None: ) -> None:
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = ( ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
custom_group_algorithm custom_group_algorithm
@ -1975,9 +1975,9 @@ class Scheduler:
for node in self.nodes: for node in self.nodes:
node.prune_deps() node.prune_deps()
self.name_to_donated_buffer: dict[ self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
str, SchedulerDonatedBuffer self.get_donated_buffers()
] = self.get_donated_buffers() )
self.name_to_node: dict[str, BaseSchedulerNode] = { self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes n.get_name(): n for n in self.nodes
} }
@ -2099,9 +2099,9 @@ class Scheduler:
node.log_details() node.log_details()
def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode: def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
assert ( assert node.get_origins() is not None, (
node.get_origins() is not None "All nodes passed to scheduling must have an origin"
), "All nodes passed to scheduling must have an origin" )
if node.is_no_op(): if node.is_no_op():
return NopKernelSchedulerNode(self, node) return NopKernelSchedulerNode(self, node)
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)): elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
@ -2260,9 +2260,9 @@ class Scheduler:
) )
# if a kernel takes unbacked symints, register dependencies # if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses: for s in unbacked_symbol_uses:
assert ( assert s in unbacked_symbol_to_origin_node, (
s in unbacked_symbol_to_origin_node f"{s} not in {unbacked_symbol_to_origin_node}"
), f"{s} not in {unbacked_symbol_to_origin_node}" )
if (r := unbacked_symbol_to_origin_node[s]) is not None: if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs(): for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name())) node.add_fake_dep(StarDep(buf.get_name()))
@ -2310,9 +2310,9 @@ class Scheduler:
for alt_name in buf.get_mutations(): for alt_name in buf.get_mutations():
self.mutation_renames[rename(alt_name)] = buf.get_name() self.mutation_renames[rename(alt_name)] = buf.get_name()
self.mutation_renames[alt_name] = buf.get_name() self.mutation_renames[alt_name] = buf.get_name()
self.mutation_real_name[ self.mutation_real_name[buf.get_name()] = (
buf.get_name() self.mutation_real_name.get(alt_name, alt_name)
] = self.mutation_real_name.get(alt_name, alt_name) )
# make sure outputs aren't dead-code-eliminated # make sure outputs aren't dead-code-eliminated
for buf_name in V.graph.get_output_names(): for buf_name in V.graph.get_output_names():
@ -2322,9 +2322,9 @@ class Scheduler:
# make sure unbacked symints aren't dead-code-eliminated # make sure unbacked symints aren't dead-code-eliminated
for out in V.graph.graph_outputs: for out in V.graph.graph_outputs:
for s in out.get_unbacked_symbol_uses(): for s in out.get_unbacked_symbol_uses():
assert ( assert s in unbacked_symbol_to_origin_node, (
s in unbacked_symbol_to_origin_node f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}" )
if r := unbacked_symbol_to_origin_node[s]: if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names(): for buf_name in self.name_to_node[r].get_buffer_names():
log.debug( log.debug(
@ -3304,15 +3304,15 @@ class Scheduler:
rhs_dep = node2_name2dep[buf_name] rhs_dep = node2_name2dep[buf_name]
if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep): if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
reasons[ reasons[buf_name] = (
buf_name f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}" )
continue continue
if lhs_dep.get_numel() != rhs_dep.get_numel(): if lhs_dep.get_numel() != rhs_dep.get_numel():
reasons[ reasons[buf_name] = (
buf_name f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}" )
continue continue
# same numel but different MemoryDep.size. Should be broadcasting # same numel but different MemoryDep.size. Should be broadcasting
@ -3340,9 +3340,9 @@ class Scheduler:
layout_str = "" layout_str = ""
if not isinstance(buf, ir.TorchBindObject): if not isinstance(buf, ir.TorchBindObject):
layout_str = f"Layout: {buf.layout}" layout_str = f"Layout: {buf.layout}"
reasons[ reasons[buf_name] = (
buf_name f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}" )
return str(reasons) return str(reasons)
@ -3903,9 +3903,9 @@ class Scheduler:
self.free_buffers() self.free_buffers()
def create_backend(self, device: torch.device) -> BaseScheduling: def create_backend(self, device: torch.device) -> BaseScheduling:
assert ( assert not is_gpu(device.type) or device.index is not None, (
not is_gpu(device.type) or device.index is not None f"{device} should have been normalized in lowering"
), f"{device} should have been normalized in lowering" )
V.graph.add_device_info(device) V.graph.add_device_info(device)
device_scheduling = get_scheduling_for_device(device.type) device_scheduling = get_scheduling_for_device(device.type)
@ -4135,9 +4135,9 @@ class Scheduler:
partitions, signatures = self.graph_partition() partitions, signatures = self.graph_partition()
for partition, signature in zip(partitions, signatures): for partition, signature in zip(partitions, signatures):
assert ( assert len(partition) >= 1, (
len(partition) >= 1 f"Each partition must have at least one node but found {len(partition)}"
), f"Each partition must have at least one node but found {len(partition)}" )
if signature.skip_cudagraph: if signature.skip_cudagraph:
self._codegen(partition) self._codegen(partition)

View File

@ -168,9 +168,9 @@ class PartialRender:
) )
else: else:
return return
assert ( assert self.replacement_hooks[hook_key] is not None, (
self.replacement_hooks[hook_key] is not None "hook_key can only be called once"
), "hook_key can only be called once" )
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]()) self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
self.replacement_hooks[hook_key] = None self.replacement_hooks[hook_key] = None
@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
This is used by flex_attention's backwards grad for captured buffers, see This is used by flex_attention's backwards grad for captured buffers, see
zeros_and_scatter lowering zeros_and_scatter lowering
""" """
assert ( assert self.mask is not None, (
self.mask is not None "Mask is required for inner stores in modifications"
), "Mask is required for inner stores in modifications" )
assert mode == "atomic_add", "Only atomic_add is supported for inner stores" assert mode == "atomic_add", "Only atomic_add is supported for inner stores"
buf_name = self._add_kernel_input(name) buf_name = self._add_kernel_input(name)
@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel):
def _get_subgraph(self, subgraph_number: int): def _get_subgraph(self, subgraph_number: int):
assert isinstance(subgraph_number, int) assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list) assert isinstance(self.subgraphs, list)
assert subgraph_number < len( assert subgraph_number < len(self.subgraphs), (
self.subgraphs f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}" )
assert ( assert self.body.getvalue() == "", (
self.body.getvalue() == "" "Body should be clear before adding a modification"
), "Body should be clear before adding a modification" )
return self.subgraphs[subgraph_number] return self.subgraphs[subgraph_number]
def _handle_scatter_graph(self, scatter_graph): def _handle_scatter_graph(self, scatter_graph):
@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel):
Args: Args:
scatter_graph: The scatter graph to process scatter_graph: The scatter graph to process
""" """
assert isinstance( assert isinstance(scatter_graph, ir.ComputedBuffer), (
scatter_graph, ir.ComputedBuffer f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}" )
def contiguous_strides(x): def contiguous_strides(x):
# We always create a fresh contiguous grad for scattering into # We always create a fresh contiguous grad for scattering into
@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel):
x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride()) x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
) )
return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined] return scatter_graph.data.store_output( # type: ignore[attr-defined]
scatter_graph.name, contiguous_strides, []
)
def modification( def modification(
self, self,
@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel):
self, subgraph_number, fixed_inputs, mask self, subgraph_number, fixed_inputs, mask
) )
with V.set_ops_handler(modification_handler): with V.set_ops_handler(modification_handler):
assert isinstance( assert isinstance(subgraph, (ir.ComputedBuffer, list)), (
subgraph, (ir.ComputedBuffer, list) f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}" )
# Handle scatter stores # Handle scatter stores
if isinstance(subgraph, list): if isinstance(subgraph, list):
for scatter_graph in subgraph: for scatter_graph in subgraph:
@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate):
"subgraphs": subgraphs, "subgraphs": subgraphs,
} }
with patch.object( with (
V.graph, "get_dtype", self._fake_get_dtype(fake_out) patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
), V.graph.set_current_device(layout.device), TritonTemplateKernel( V.graph.set_current_device(layout.device),
kernel_name=kernel_name, TritonTemplateKernel(
output_node=fake_out, kernel_name=kernel_name,
workspace_arg=workspace_arg, output_node=fake_out,
use_jit=False, workspace_arg=workspace_arg,
**kernel_options, use_jit=False,
) as kernel: **kernel_options,
) as kernel,
):
try: try:
template = kernel.render(self.template, kwargs) template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"): with kernel.set_subgraph_body("<STORE_OUTPUT>"):
@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller):
def output_node(self): def output_node(self):
if self.choice.use_fallback_kernel: if self.choice.use_fallback_kernel:
assert ( assert self.choice.op_overload is not None, (
self.choice.op_overload is not None "Please provide an op_overload to use ir.FallbackKernel"
), "Please provide an op_overload to use ir.FallbackKernel" )
inner = ir.FallbackKernel.create( inner = ir.FallbackKernel.create(
self.choice.op_overload, *self.input_nodes, **self.kwargs self.choice.op_overload, *self.input_nodes, **self.kwargs
) )
@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {} input_gen_fns = {}
def get_inputs( def get_inputs(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> AutotuneArgs: ) -> AutotuneArgs:
# de-duplicate args # de-duplicate args
unique_example_inputs = { unique_example_inputs = {
@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings return timings
def benchmark_in_sub_process( def benchmark_in_sub_process(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]] choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
): ):
from . import autotune_process from . import autotune_process
@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache):
map( map(
str, str,
V.graph.sizevars.size_hints( V.graph.sizevars.size_hints(
n.get_size(), fallback=config.unbacked_symint_fallback # type: ignore[arg-type] n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
), ),
) )
) )
@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs):
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache() _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()
if "return_multi_template" not in kwargs: if "return_multi_template" not in kwargs:
kwargs[ kwargs["return_multi_template"] = (
"return_multi_template" torch._inductor.config.benchmark_epilogue_fusion
] = torch._inductor.config.benchmark_epilogue_fusion )
return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs) return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)
def add_feedback_saver( def add_feedback_saver(
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None] fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
): ):
global _ALGORITHM_SELECTOR_CACHE global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None: if _ALGORITHM_SELECTOR_CACHE is None:

View File

@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner, var_ranges: VarRanges) -> None: def __init__(self, inner, var_ranges: VarRanges) -> None:
super().__init__(inner) super().__init__(inner)
self.name = "SimplifyIndexing" self.name = "SimplifyIndexing"
self._simplify: Callable[ self._simplify: Callable[[Expr], Expr] = (
[Expr], Expr lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges) )
def load(self, name: str, index: sympy.Expr): def load(self, name: str, index: sympy.Expr):
return self._inner.load(name, self._simplify(index)) return self._inner.load(name, self._simplify(index))

View File

@ -283,9 +283,9 @@ def ceildiv(
# TODO: There is a bug in a call to this function, to repro: # TODO: There is a bug in a call to this function, to repro:
# python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
# --amp --only YituTechConvBert --dynamic-shapes # --amp --only YituTechConvBert --dynamic-shapes
assert isinstance(numer, int) and isinstance( assert isinstance(numer, int) and isinstance(denom, int), (
denom, int f"{numer}: {type(numer)}, {denom}: {type(denom)}"
), f"{numer}: {type(numer)}, {denom}: {type(denom)}" )
return runtime_ceildiv(numer, denom) return runtime_ceildiv(numer, denom)
@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:
def convert_shape_to_inductor( def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]] lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]: ) -> list[sympy.Expr]:
""" """
Gets the shape and stride of a tensor. For non-symbolic tensors, this is Gets the shape and stride of a tensor. For non-symbolic tensors, this is
@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True)
class CachedMethod(Protocol, Generic[P, RV]): class CachedMethod(Protocol, Generic[P, RV]):
@staticmethod @staticmethod
def clear_cache(cache: Any) -> None: def clear_cache(cache: Any) -> None: ...
...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...
...
# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str:
@functools.lru_cache(None) @functools.lru_cache(None)
def try_import_ck_lib() -> ( def try_import_ck_lib() -> tuple[
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]] Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
): ]:
try: try:
import ck4inductor # type: ignore[import] import ck4inductor # type: ignore[import]
from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import] from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]
@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
return DummyModule() return DummyModule()
with mock.patch.object( with (
GraphLowering, "compile_to_module", patched_compile_to_module mock.patch.object(
), mock.patch.object(GraphLowering, "save_output_code", save_output_code): GraphLowering, "compile_to_module", patched_compile_to_module
),
mock.patch.object(GraphLowering, "save_output_code", save_output_code),
):
torch._dynamo.reset() torch._dynamo.reset()
# Note the return here is None # Note the return here is None
_ = fn(*args, **kwargs) _ = fn(*args, **kwargs)
@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str: def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
source_codes = get_code(fn, *args, **kwargs) source_codes = get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled # Can have two outputs if backwards was eagerly compiled
assert ( assert 1 <= len(source_codes) <= 2, (
1 <= len(source_codes) <= 2 f"expected one or two code outputs got {len(source_codes)}"
), f"expected one or two code outputs got {len(source_codes)}" )
return source_codes[0] return source_codes[0]
def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str: def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
_, source_codes = run_and_get_code(fn, *args, **kwargs) _, source_codes = run_and_get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled # Can have two outputs if backwards was eagerly compiled
assert ( assert 1 <= len(source_codes) <= 2, (
1 <= len(source_codes) <= 2 f"expected one or two code outputs got {len(source_codes)}"
), f"expected one or two code outputs got {len(source_codes)}" )
return source_codes[0] return source_codes[0]
@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:
def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype: def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
assert isinstance( assert isinstance(val, sympy.Expr), (
val, sympy.Expr "only support sympy.Expr as input to get_sympy_Expr_dtype"
), "only support sympy.Expr as input to get_sympy_Expr_dtype" )
if val.is_integer: # type: ignore[attr-defined] if val.is_integer: # type: ignore[attr-defined]
return torch.int64 return torch.int64
else: else:
@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) ->
def is_output_of_multi_outputs_template( def is_output_of_multi_outputs_template(
input_buf: Optional[Union[Buffer, Operation]] input_buf: Optional[Union[Buffer, Operation]],
) -> bool: ) -> bool:
""" """
Check if input buffer is a output of multi-outputs template buffer Check if input buffer is a output of multi-outputs template buffer
@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing(
if node not in (EnableReduction, DisableReduction): if node not in (EnableReduction, DisableReduction):
if node.node is not None: if node.node is not None:
V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [ V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [
origin.name for origin in node.node.origins # type: ignore[attr-defined] origin.name
for origin in node.node.origins # type: ignore[attr-defined]
] ]
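The signature hunks in this file (and in several files above) add a trailing comma to a lone parameter that is split onto its own line. A short sketch of the before/after shape (the function is illustrative, not from the diff):

    from collections.abc import Iterable
    from typing import Union

    # Black (old style): the single split parameter carried no trailing comma.
    #     def convert_shape(
    #         lst: Iterable[Union[int, float]]
    #     ) -> list[float]:
    # ruff format (new style): a trailing comma follows the lone parameter.
    def convert_shape(
        lst: Iterable[Union[int, float]],
    ) -> list[float]:
        return [float(x) for x in lst]

    print(convert_shape([1, 2.5]))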

View File

@ -314,9 +314,9 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler WrapperHandler = WrapperHandler
set_ops_handler: Callable[ set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = (
[OpsHandler[Any]], AbstractContextManager[None] _ops._set_handler
] = _ops._set_handler )
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler

View File

@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes
class BenchmarkCallableType(Protocol): class BenchmarkCallableType(Protocol):
def __call__(self, times: int, repeat: int) -> float: def __call__(self, times: int, repeat: int) -> float: ...
...
_kernel_category_choices = [ _kernel_category_choices = [
@ -138,9 +137,9 @@ def benchmark_all_kernels(
) )
else: else:
ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40) ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
assert ( assert len(triton_kernel.launchers) == 1, (
len(triton_kernel.launchers) == 1 "Autotuner should have selected the best config"
), "Autotuner should have selected the best config" )
launcher = triton_kernel.launchers[0] launcher = triton_kernel.launchers[0]
print( print(
get_info_str( get_info_str(
@ -256,9 +255,9 @@ def parse_profile_event_list(
"triton_unknown", "triton_unknown",
"unknown", "unknown",
] ]
assert OrderedSet(all_events.keys()).issubset( assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
OrderedSet(category_list) f"{list(all_events.keys())}"
), f"{list(all_events.keys())}" )
per_category_wall_time = {} per_category_wall_time = {}
total_device_ms = 0.0 total_device_ms = 0.0