[BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel

parent 34d726011f
commit 1cb4e2df65
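Most hunks below apply the same mechanical restyle: where black split an over-long `assert` by parenthesizing the condition, ruff format keeps the condition on one line and parenthesizes the message instead. A minimal hand-written sketch of the before/after pattern (the function and variable names are illustrative, not taken from the diff):

```
def check_same_dtype(a: str, b: str) -> None:
    # Old style (black): the condition is wrapped in parentheses and split
    # across lines, with the message trailing the closing parenthesis.
    assert (
        a == b
    ), f"dtype mismatch: {a} vs {b}"

    # New style (ruff format): the condition stays on one line and the
    # message is wrapped in parentheses when the statement is too long.
    assert a == b, (
        f"dtype mismatch: {a} vs {b}"
    )


check_same_dtype("float32", "float32")
```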
@@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile(
# torch/_[e-h]*/**
"torch/_[e-h]*/**",
# torch/_i*/**
"torch/_i*/**",
# torch/_[j-z]*/**
"torch/_[j-z]*/**",
# torch/[a-c]*/**
@@ -66,7 +66,9 @@ def aoti_compile_and_package(
.. code-block:: python

ep = torch.export.export(M(), ...)
aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2")
aoti_file = torch._inductor.aoti_compile_and_package(
ep, package_path="my_package.pt2"
)
compiled_model = torch._inductor.aoti_load_package("my_package.pt2")

To compile and save multiple models into a single ``.pt2`` artifact, you can do

@@ -75,11 +77,16 @@ def aoti_compile_and_package(
.. code-block:: python

ep1 = torch.export.export(M1(), ...)
aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True})
aoti_file1 = torch._inductor.aot_compile(
ep1, ..., options={"aot_inductor.package": True}
)
ep2 = torch.export.export(M2(), ...)
aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True})
aoti_file2 = torch._inductor.aot_compile(
ep2, ..., options={"aot_inductor.package": True}
)

from torch._inductor.package import package_aoti, load_package

package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2})

compiled_model1 = load_package("my_package.pt2", "model1")

@@ -123,7 +130,9 @@ def aoti_compile_and_package(
isinstance(package_path, (str, os.PathLike))
and os.fspath(package_path).endswith(".pt2")
)
), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
), (
f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
)

inductor_configs = inductor_configs or {}
inductor_configs["aot_inductor.package"] = True

@@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner(
"""

if check_accuracy:
assert (
kwargs is None or len(kwargs) == 0
), "when checking for accuracy, the inputs must have been flattened and kwargs is None"
assert kwargs is None or len(kwargs) == 0, (
"when checking for accuracy, the inputs must have been flattened and kwargs is None"
)

from .package import package_aoti
@@ -156,8 +156,9 @@ def can_codegen_without_upcasts(
low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops)

# Need to turn off upcasting to do analysis of whether we can turn it off
with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler(
low_prec_analysis
with (
config.patch("triton.codegen_upcast_to_fp32", False),
V.set_ops_handler(low_prec_analysis),
):
prologue._body(*prologue.get_ranges())
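The hunk above is one of several in this commit that rewrite stacked context managers into the parenthesized `with` form ruff format emits. A self-contained sketch of the two spellings (the temporary-file managers are placeholders for the config-patch and ops-handler managers used in the real code; the parenthesized form requires Python 3.10 or newer):

```
import tempfile

# Old style: all context managers on one header line, broken inside one of
# the manager calls when the line grows too long.
with tempfile.TemporaryFile() as f1, tempfile.TemporaryFile() as f2:
    f1.write(b"a")
    f2.write(b"b")

# New style (ruff format, Python 3.10+): the manager list is wrapped in
# parentheses so each context manager sits on its own line.
with (
    tempfile.TemporaryFile() as f1,
    tempfile.TemporaryFile() as f2,
):
    f1.write(b"a")
    f2.write(b"b")
```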
@@ -245,8 +245,7 @@ class AsyncCompile:

def use_process_pool(self):
return (
get_compile_threads() > 1
and self.process_pool().ready_future.done()  # type: ignore[union-attr]
get_compile_threads() > 1 and self.process_pool().ready_future.done()  # type: ignore[union-attr]
)

def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
@@ -24,8 +24,7 @@ if TYPE_CHECKING:
class Sortable(typing.Protocol):
"""Anything that can be used as a list.sort() key (int/tuple/etc)"""

def __lt__(self, other: typing.Self) -> bool:
...
def __lt__(self, other: typing.Self) -> bool: ...


class InductorChoices:

@@ -100,7 +99,9 @@ class InductorChoices:
# to pick the faster one.
if config.triton.multi_kernel:
threshold *= 16
return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold)  # type: ignore[arg-types]
return V.graph.sizevars.statically_known_leq(
features.reduction_numel, threshold
)  # type: ignore[arg-types]

@staticmethod
def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
@@ -417,9 +417,9 @@ def write_atomic(
) -> None:
# Write into temporary file first to avoid conflicts between threads
# Avoid using a named temporary file, as those have restricted permissions
assert isinstance(
content, (str, bytes)
), "Only strings and byte arrays can be saved in the cache"
assert isinstance(content, (str, bytes)), (
"Only strings and byte arrays can be saved in the cache"
)
path = Path(path_)
if make_dirs:
path.parent.mkdir(parents=True, exist_ok=True)

@@ -975,9 +975,9 @@ class FxGraphCache:
symints = FxGraphCache._filter_backed_symints(example_inputs)
hints = [hint_int(s) for s in symints]

def iterate_over_candidates() -> (
Generator[tuple[CompiledFxGraph, bytes], None, None]
):
def iterate_over_candidates() -> Generator[
tuple[CompiledFxGraph, bytes], None, None
]:
if local:
subdir = FxGraphCache._get_tmp_dir_for_key(key)
if os.path.exists(subdir):

@@ -1123,9 +1123,9 @@ class FxGraphCache:
"""
from .compile_fx import CompiledFxGraph

assert isinstance(
compiled_graph, CompiledFxGraph
), f"serialization for {type(compiled_graph)} NYI"
assert isinstance(compiled_graph, CompiledFxGraph), (
f"serialization for {type(compiled_graph)} NYI"
)
disk_compiled_graph = copy(compiled_graph)
disk_compiled_graph.prepare_for_serialization()

@@ -1315,9 +1315,8 @@ class FxGraphCache:
"distributed_ephemeral_timeout_us", time_saved_ns // 1000
)
if (
ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
time_saved_ns
)
ephemeral_increase
:= add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
) != 0:
cache_info["ephemeral_timeout_increase"] = ephemeral_increase
else:

@@ -1556,9 +1555,9 @@ class AotCodeCompiler:
cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json")
)
for k, v in config.aot_inductor.metadata.items():
assert isinstance(k, str) and isinstance(
v, (str)
), "Metadata must only contain strings"
assert isinstance(k, str) and isinstance(v, (str)), (
"Metadata must only contain strings"
)

with open(meta_json, "w") as f:
f.write(json.dumps(config.aot_inductor.metadata))
@@ -341,7 +341,7 @@ class BackendFeature(Enum):


def get_backend_features(
device: Union[torch.device, str, None]
device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
if device is None:
return OrderedSet()

@@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
if cls._is_unimplemented(funcname):
setattr(cls, funcname, cls._unimplemented(funcname))
else:
assert (
funcname not in cls.__dict__
), f"multiple definitions of {funcname} on {cls.__name__}"
assert funcname not in cls.__dict__, (
f"multiple definitions of {funcname} on {cls.__name__}"
)
impl.__name__ = funcname
setattr(cls, funcname, staticmethod(impl))

@@ -2229,7 +2229,7 @@ class KernelTemplate:

@staticmethod
def _fake_get_dtype(
fake_outs: Union[list[Buffer], Buffer]
fake_outs: Union[list[Buffer], Buffer],
) -> Callable[[str], torch.dtype]:
_get_dtype_real = V.graph.get_dtype
if isinstance(fake_outs, (list, tuple)):
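The `get_backend_features` and `_fake_get_dtype` hunks above show another small convention change: when a signature is already split across lines, a trailing comma is added after the final (here, the only) parameter. A hedged sketch with made-up names:

```
from typing import Optional


# Old style (black): no trailing comma after the lone parameter.
def describe_device_old(
    device: Optional[str]
) -> str:
    return device or "cpu"


# New style (ruff format): a trailing comma marks the parameter list as
# already exploded, one parameter per line.
def describe_device(
    device: Optional[str],
) -> str:
    return device or "cpu"


print(describe_device("cuda"))
```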
@@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
outer_loop_fusion_depth,
):
self.outer_fused_nodes: list[
Union[FusedSchedulerNode, SchedulerNode]
] = outer_fused_nodes
self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
outer_fused_nodes
)
self.outer_loop_fusion_depth = outer_loop_fusion_depth
flatten_snodes = []
for _node in self.outer_fused_nodes:

@@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides):

@staticmethod
def remainder(a, b):
assert (
a.dtype == b.dtype
), "remainder vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"remainder vec implementation expect the same inputs' dtype."
)
return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"

@staticmethod

@@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides):
@staticmethod
def floordiv(a, b):
if is_float_dtype(a.dtype):
assert (
a.dtype == b.dtype
), "div_floor_floating_vec implementation expect the same inputs' dtype."
assert a.dtype == b.dtype, (
"div_floor_floating_vec implementation expect the same inputs' dtype."
)
return f"div_floor_floating_vec({a}, {b})"
else:
assert all(is_integer_dtype(item.dtype) for item in [a, b])

@@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides):
assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
body_vec_var.dtype = dtype
other_vec_var.dtype = dtype
overrides: type[
Union[CppOverrides, CppVecOverrides]
] = V.kernel.overrides  # type: ignore[has-type]
overrides: type[Union[CppOverrides, CppVecOverrides]] = (
V.kernel.overrides
)  # type: ignore[has-type]
code.writeline(
f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
)

@@ -2108,9 +2108,9 @@ class CppKernel(Kernel):

def set_ranges(self, lengths, reduction_lengths):
if self.call_ranges:
assert self.call_ranges == tuple(lengths) + tuple(
reduction_lengths
), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
)
assert self.reduction_depth == len(lengths)
else:
self.call_ranges = tuple(lengths) + tuple(reduction_lengths)

@@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel):
self.weight_recps_val = self.weight_recps_cse.generate(
self.compute, f"reduction {self.weight_recp_vec_range}", write=False
)
self.weight_recps_cse.reduction_cache[
self.weight_recp_vec_range
] = self.weight_recps_val
self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = (
self.weight_recps_val
)
self.non_parallel_reduction_prefix.writeline(
self.welford_weight_reciprocal_vec(dtype)
)

@@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling):
]

counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
assert self.is_cpp_template(
template_node
), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
assert self.is_cpp_template(template_node), (
"Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
)
template_node = cast(SchedulerNode, template_node)
_, (_, rnumel) = template_node.group
assert rnumel == ()

@@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling):
epilogue_ir_nodes: list[Optional[ir.Operation]] = [
n.node for n in epilogue_nodes
]
assert all(
isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
), "Epilogue nodes must all be instances of ir.ComputedBuffer"
assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
"Epilogue nodes must all be instances of ir.ComputedBuffer"
)

def template_buffer_has_other_users(
template_buffer, outputs_by_name, epilogue_nodes

@@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling):
if is_multi_outputs_template(template_node.node):
# For multi outputs template, allocate buffers for each output after the epilogue
# codegen to which determines if the buffer has been removed.
assert (
len(template_node.outputs) == 1
), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
assert len(template_node.outputs) == 1, (
"Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
)
for user in template_node.outputs[0].users:
assert isinstance(
user.node, ExternKernelSchedulerNode
), "Multi outputs template should be with ExternKernelSchedulerNode"
assert isinstance(
user.node.node, ir.MultiOutput
), "Multi outputs template has multi users with MultiOutput"
assert isinstance(user.node, ExternKernelSchedulerNode), (
"Multi outputs template should be with ExternKernelSchedulerNode"
)
assert isinstance(user.node.node, ir.MultiOutput), (
"Multi outputs template has multi users with MultiOutput"
)
user.node.mark_run()

kernel.call_kernel(kernel_name, ctb)

@@ -5347,9 +5347,9 @@ class LoopNest:
return self.loops is not None and self.loops[0].is_reduction

def mark_parallel(self, par_depth):
assert (
par_depth <= self.max_parallel_depth()
), "Parallel depth cannot exceed the maximal allowed parallel depth"
assert par_depth <= self.max_parallel_depth(), (
"Parallel depth cannot exceed the maximal allowed parallel depth"
)
assert self.loops is not None
assert len(self.loops) >= par_depth
loop = self.loops[0]
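Several hunks in the block above (for example the `outer_fused_nodes` and `reduction_cache` changes) follow the same shape: where black used to break inside the subscript on the left-hand side, ruff format keeps the assignment target and its annotation intact and wraps the right-hand side in parentheses. A small illustrative sketch (the class and attribute names are invented):

```
from typing import Union


class FusionGroup:
    def __init__(self, nodes: list[Union[int, str]]) -> None:
        # Old style (black): the annotation's subscript is split across lines
        # to make the statement fit.
        self.nodes_old: list[
            Union[int, str]
        ] = nodes

        # New style (ruff format): the left-hand side stays on one line and
        # the value is parenthesized instead when the statement is too long.
        self.nodes: list[Union[int, str]] = (
            nodes
        )


group = FusionGroup([1, "two"])
```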
@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate):
|
|||
assert all(
|
||||
mem.buffer_name in kernel_group.args.input_buffers
|
||||
for mem in body.memory_usage[MemoryUsageType.LOAD]
|
||||
), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
|
||||
), (
|
||||
"All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
|
||||
)
|
||||
|
||||
bodies.append(body)
|
||||
var_sizes_list.append((var_sizes, ()))
|
||||
|
|
|
|||
|
|
@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
thread_block_m = math.ceil(m_blocks / m_factor)
|
||||
return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)
|
||||
|
||||
assert (
|
||||
not self.is_dynamic_M
|
||||
), "Unable to determine thread blocking for dynamic M."
|
||||
assert not self.is_dynamic_M, (
|
||||
"Unable to determine thread blocking for dynamic M."
|
||||
)
|
||||
register_blocking = self.register_blocking
|
||||
m_blocks = math.ceil(self.m / register_blocking.block_m)
|
||||
n_blocks = math.ceil(self.n / register_blocking.block_n)
|
||||
|
|
@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate):
|
|||
L1_cache_size = (
|
||||
torch._C._cpu._L1d_cache_size()
|
||||
) # per core cache size in Bytes
|
||||
assert (
|
||||
L1_cache_size > 0
|
||||
), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
|
||||
assert L1_cache_size > 0, (
|
||||
f"Expect L1_cache_size > 0 but got {L1_cache_size}"
|
||||
)
|
||||
L1 = L1_cache_size * L1_limit_factor
|
||||
|
||||
L2_cache_size = (
|
||||
torch._C._cpu._L2_cache_size()
|
||||
) # per core cache size in Bytes
|
||||
assert (
|
||||
L2_cache_size > 0
|
||||
), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
|
||||
assert L2_cache_size > 0, (
|
||||
f"Expect L2_cache_size > 0 but got {L2_cache_size}"
|
||||
)
|
||||
L2 = L2_cache_size * L2_limit_factor
|
||||
|
||||
def get_num_byte(dtype):
|
||||
|
|
@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
|
||||
return Mc_blocks, Nc_blocks, Kc_blocks
|
||||
|
||||
assert (
|
||||
not self.is_dynamic_M
|
||||
), "Unable to determine cache blocking for dynamic M."
|
||||
assert not self.is_dynamic_M, (
|
||||
"Unable to determine cache blocking for dynamic M."
|
||||
)
|
||||
register_blocking = self.register_blocking
|
||||
thread_blocking = self.thread_blocking(num_threads)
|
||||
|
||||
|
|
@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate):
|
|||
LayoutType.VNNI4,
|
||||
], f"We only support {layout_str} for now"
|
||||
vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
|
||||
assert (
|
||||
k % vnni_size == 0
|
||||
), f"k should be divisible by vnni_size for {layout_str} layout"
|
||||
assert k % vnni_size == 0, (
|
||||
f"k should be divisible by vnni_size for {layout_str} layout"
|
||||
)
|
||||
vnni_view_size = list(new_size)
|
||||
vnni_view_size[-2] = k // vnni_size
|
||||
vnni_view_size.insert(-1, vnni_size)
|
||||
|
|
|
|||
|
|
@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
|||
for W_node in W_nodes:
|
||||
assert W_node.get_name() in V.graph.constants
|
||||
W_tensor.append(V.graph.constants[W_node.get_name()])
|
||||
new_input_nodes[
|
||||
wgt_start_idx : wgt_start_idx + gemm_grouped_num
|
||||
] = W_tensor # type: ignore[assignment]
|
||||
new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
|
||||
W_tensor # type: ignore[assignment]
|
||||
)
|
||||
new_input_nodes, _ = pack_weight(
|
||||
*normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
|
||||
)
|
||||
|
|
@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
|||
W_packed = new_input_nodes[idx]
|
||||
assert isinstance(W_packed, torch.Tensor)
|
||||
W_packed_constant = V.graph.add_tensor_constant(W_packed)
|
||||
template_buffer.inputs[
|
||||
idx
|
||||
] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
|
||||
template_buffer.inputs[idx] = (
|
||||
ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
|
||||
)
|
||||
return output
|
||||
|
||||
template = DataProcessorTemplateWrapper(
|
||||
|
|
@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
|
|||
ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
|
||||
)
|
||||
|
||||
assert (
|
||||
not self.epilogue_creator
|
||||
), "epilogue_creator is not supported yet in Grouped GEMM Template"
|
||||
assert not self.epilogue_creator, (
|
||||
"epilogue_creator is not supported yet in Grouped GEMM Template"
|
||||
)
|
||||
|
||||
kernel_args: dict[str, Optional[ir.IRNode]] = {}
|
||||
for x_idx in range(wgt_start_idx):
|
||||
|
|
|
|||
|
|
@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}
|
|||
|
||||
def register_micro_gemm(*configs):
|
||||
def inner(cls):
|
||||
assert (
|
||||
cls not in micro_gemm_configs
|
||||
), f"Duplicate micro_gemm registration for {cls}"
|
||||
assert cls not in micro_gemm_configs, (
|
||||
f"Duplicate micro_gemm registration for {cls}"
|
||||
)
|
||||
assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
|
||||
micro_gemm_configs[cls] = list(configs)
|
||||
return cls
|
||||
|
|
|
|||
|
|
@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate):
|
|||
|
||||
def generate(self, **kwargs):
|
||||
kernel_name = f"cpp_{self.name}"
|
||||
with patch.object(
|
||||
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
|
||||
), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
|
||||
kernel_name=kernel_name, num_threads=self.num_threads
|
||||
) as kernel:
|
||||
with (
|
||||
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
|
||||
patch.object(ir.FlexibleLayout, "allow_indexing", True),
|
||||
CppTemplateKernel(
|
||||
kernel_name=kernel_name, num_threads=self.num_threads
|
||||
) as kernel,
|
||||
):
|
||||
code = kernel.render(self, **kwargs)
|
||||
_, call_args, _, _ = kernel.args.python_argdefs()
|
||||
log.debug("Generated Code:\n%s", code)
|
||||
|
|
|
|||
|
|
@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel):
|
|||
)
|
||||
epilogue_nodes = scope.localize_nodes(epilogue_nodes)
|
||||
return self.store_pointwise_nodes(
|
||||
dst, epilogue_nodes, offsets, reindexers # type: ignore[arg-type]
|
||||
dst,
|
||||
epilogue_nodes, # type: ignore[arg-type]
|
||||
offsets,
|
||||
reindexers,
|
||||
)
|
||||
else:
|
||||
if dst.get_name() != src.get_name():
|
||||
|
|
|
|||
|
|
@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
Only valid when cuda == True.
|
||||
"""
|
||||
assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
|
||||
assert arg_types is not None and len(call_args) == len(
|
||||
arg_types
|
||||
), "Mismatch call_args and arg_types in generate_kernel_call"
|
||||
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||
"Mismatch call_args and arg_types in generate_kernel_call"
|
||||
)
|
||||
new_args = []
|
||||
for idx, arg in enumerate(call_args):
|
||||
if "*" in arg_types[idx]:
|
||||
|
|
@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
dtype = may_get_constant_buffer_dtype(
|
||||
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
|
||||
)
|
||||
assert (
|
||||
dtype is not None
|
||||
), "Fails to get the dtype of the sympy.Expr"
|
||||
assert dtype is not None, (
|
||||
"Fails to get the dtype of the sympy.Expr"
|
||||
)
|
||||
self.codegen_tensor_item(
|
||||
dtype, f"inputs[{idx}]", input_key, self.prefix
|
||||
)
|
||||
|
|
@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
|
||||
code.writeline(f"int32_t {name}_dtype;")
|
||||
code.writeline(
|
||||
"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
|
||||
f"({name}, &{name}_dtype));"
|
||||
f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
|
||||
)
|
||||
|
||||
def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
|
||||
|
|
@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
|
||||
# Tell compiler we need to link with the non-mangled symbols
|
||||
for kernel in self.initialized_kernels.values():
|
||||
assert hasattr(
|
||||
kernel, "get_signature"
|
||||
), f"{kernel} must have get_signature implemented"
|
||||
assert hasattr(kernel, "get_signature"), (
|
||||
f"{kernel} must have get_signature implemented"
|
||||
)
|
||||
signature = kernel.get_signature()
|
||||
self.prefix.writeline(f'extern "C" {signature};')
|
||||
|
||||
|
|
@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
)
|
||||
)
|
||||
for name, kernel in self.initialized_kernels.items():
|
||||
assert hasattr(
|
||||
kernel, "get_signature"
|
||||
), f"{kernel} must have get_signature implemented"
|
||||
assert hasattr(kernel, "get_signature"), (
|
||||
f"{kernel} must have get_signature implemented"
|
||||
)
|
||||
kernel_ptr = f"(*{name})"
|
||||
signature = kernel.get_signature().replace(name, kernel_ptr)
|
||||
self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")
|
||||
|
|
@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
|
||||
with self.prefix.indent():
|
||||
for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
|
||||
assert not isinstance(
|
||||
inp, sympy.Expr
|
||||
), f"input {name=} cannot be symbolic"
|
||||
assert not isinstance(inp, sympy.Expr), (
|
||||
f"input {name=} cannot be symbolic"
|
||||
)
|
||||
self.write_input_output_info("inputs_info_", idx, name)
|
||||
|
||||
all_cuda = all(
|
||||
|
|
@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
|
||||
tensor
|
||||
)
|
||||
assert (
|
||||
opaque_metadata_tensor.dim() == 1
|
||||
), "Expect opaque_metadata_tensor to be 1-D"
|
||||
assert opaque_metadata_tensor.dim() == 1, (
|
||||
"Expect opaque_metadata_tensor to be 1-D"
|
||||
)
|
||||
|
||||
opaque_metadata_list = opaque_metadata_tensor.tolist()
|
||||
opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
|
||||
|
|
@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
)
|
||||
|
||||
for idx, output in enumerate(V.graph.graph_outputs):
|
||||
assert not isinstance(
|
||||
output, sympy.Expr
|
||||
), f"output {name=} cannot be symbolic"
|
||||
assert not isinstance(output, sympy.Expr), (
|
||||
f"output {name=} cannot be symbolic"
|
||||
)
|
||||
name = f"output{idx}"
|
||||
self.write_input_output_info("outputs_info_", idx, name)
|
||||
|
||||
|
|
@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
for idx, (name, _) in enumerate(V.graph.constants.items()):
|
||||
if name in V.graph.const_output_index:
|
||||
const_index_mapping[V.graph.const_output_index[name]] = (idx, name) # type: ignore[call-overload]
|
||||
assert (
|
||||
None not in const_index_mapping
|
||||
), "Not all constant gets mapped for constant folding graph."
|
||||
assert None not in const_index_mapping, (
|
||||
"Not all constant gets mapped for constant folding graph."
|
||||
)
|
||||
|
||||
self.prefix.writeline(
|
||||
f"""
|
||||
|
|
@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
name = f"{output.get_name()}"
|
||||
output_handle_name = f"{name}_handle"
|
||||
if output.indices:
|
||||
assert (
|
||||
output.indices[0][1] == idx
|
||||
), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
|
||||
assert output.indices[0][1] == idx, (
|
||||
f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
|
||||
)
|
||||
self.writeline(f"AtenTensorHandle {output_handle_name};")
|
||||
output_args.append(f"&{output_handle_name}")
|
||||
output_raii_handles.append(
|
||||
|
|
@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
args = args + output_args
|
||||
device = d.type if (d := fallback_kernel.get_device()) else self.device
|
||||
self.generate_c_shim_extern_kernel_call(
|
||||
fallback_kernel.cpp_kernel_name, args, device # type: ignore[arg-type]
|
||||
fallback_kernel.cpp_kernel_name, # type: ignore[arg-type]
|
||||
args,
|
||||
device,
|
||||
)
|
||||
for raii_handle in output_raii_handles:
|
||||
self.writeline(raii_handle)
|
||||
|
|
@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
if reduce:
|
||||
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
|
||||
else:
|
||||
assert (
|
||||
reduce is None
|
||||
), "Expect reduce to be None for aten.scatter_ with scalar src"
|
||||
assert reduce is None, (
|
||||
"Expect reduce to be None for aten.scatter_ with scalar src"
|
||||
)
|
||||
line += ");"
|
||||
self.writeline(line)
|
||||
|
||||
|
|
@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen):
|
|||
# Only treat int Scalar as dynamic
|
||||
is_int_type = [isinstance(a, int) for a in arg]
|
||||
if any(is_int_type):
|
||||
assert all(
|
||||
is_int_type
|
||||
), "AOTInductor only supports int scalars of the same type"
|
||||
assert all(is_int_type), (
|
||||
"AOTInductor only supports int scalars of the same type"
|
||||
)
|
||||
new_int_args.extend([str(a) for a in arg])
|
||||
else:
|
||||
assert isinstance(
|
||||
arg_type.getElementType(), static_arg_types # type: ignore[arg-type]
|
||||
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
arg_type.getElementType(),
|
||||
static_arg_types, # type: ignore[arg-type]
|
||||
), (
|
||||
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
)
|
||||
else:
|
||||
assert isinstance(
|
||||
arg_type, static_arg_types # type: ignore[arg-type]
|
||||
), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
arg_type,
|
||||
static_arg_types, # type: ignore[arg-type]
|
||||
), (
|
||||
f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
|
||||
)
|
||||
|
||||
for arg, arg_type in zip(raw_args, arg_types):
|
||||
if arg is not None:
|
||||
|
|
@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) {
|
|||
return f"&{var_name}"
|
||||
|
||||
if isinstance(type_, torch.ListType):
|
||||
assert isinstance(
|
||||
val, (list, tuple)
|
||||
), f"{val} does not match with arg type {type_}"
|
||||
assert isinstance(val, (list, tuple)), (
|
||||
f"{val} does not match with arg type {type_}"
|
||||
)
|
||||
element_type = type_.getElementType()
|
||||
var_name = f"var_array_{next(self.var_array_id)}"
|
||||
if len(val) == 0:
|
||||
|
|
|
|||
|
|
@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
self.cached_output_id = count()
|
||||
self.scalar_to_tensor_id = count()
|
||||
self.custom_op_wrapper_loaded = False
|
||||
self.allow_stack_allocation: Optional[
|
||||
bool
|
||||
] = config.aot_inductor.allow_stack_allocation
|
||||
self.allow_stack_allocation: Optional[bool] = (
|
||||
config.aot_inductor.allow_stack_allocation
|
||||
)
|
||||
self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}
|
||||
|
||||
@staticmethod
|
||||
|
|
@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
Otherwise it uses the CUDA language for codegen.
|
||||
Only valid when cuda == True.
|
||||
"""
|
||||
assert (
|
||||
not gpu
|
||||
), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
|
||||
assert arg_types is not None and len(call_args) == len(
|
||||
arg_types
|
||||
), "Mismatch call_args and arg_types in generate_kernel_call"
|
||||
assert not gpu, (
|
||||
"CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
|
||||
)
|
||||
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||
"Mismatch call_args and arg_types in generate_kernel_call"
|
||||
)
|
||||
new_args = []
|
||||
for idx, arg in enumerate(call_args):
|
||||
if "*" in arg_types[idx]:
|
||||
|
|
@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
dtype = may_get_constant_buffer_dtype(
|
||||
V.graph.graph_inputs[input_key] # type: ignore[arg-type]
|
||||
)
|
||||
assert (
|
||||
dtype is not None
|
||||
), "Fails to get the dtype of the sympy.Expr"
|
||||
assert dtype is not None, (
|
||||
"Fails to get the dtype of the sympy.Expr"
|
||||
)
|
||||
self.codegen_tensor_item(
|
||||
dtype, f"inputs[{idx}]", input_key, self.prefix
|
||||
)
|
||||
|
|
@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
|
|||
if reduce:
|
||||
line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
|
||||
else:
|
||||
assert (
|
||||
reduce is None
|
||||
), "Expect reduce to be None for aten.scatter_ with scalar src"
|
||||
assert reduce is None, (
|
||||
"Expect reduce to be None for aten.scatter_ with scalar src"
|
||||
)
|
||||
line += ");"
|
||||
self.writeline(line)
|
||||
|
||||
|
|
|
|||
|
|
@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase):
|
|||
# MultiKernel will select one kernel after running the autotune block
|
||||
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert (
|
||||
params is not None
|
||||
), f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
)
|
||||
for key in self.keys:
|
||||
assert (
|
||||
key in params
|
||||
), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
|
||||
assert key in params, (
|
||||
f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
|
||||
)
|
||||
if key == get_cpp_wrapper_cubin_path_name():
|
||||
assert os.path.exists(params[key]), f"{params[key]} does not exist"
|
||||
self.additional_files.append(params[key])
|
||||
|
|
@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid:
|
|||
grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)
|
||||
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert (
|
||||
params is not None
|
||||
), f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
)
|
||||
return grid_fn(params["meta"])
|
||||
|
||||
|
||||
|
|
@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase):
|
|||
self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
|
||||
|
||||
params = CudaKernelParamCache.get(self.kernel_name)
|
||||
assert (
|
||||
params is not None
|
||||
), f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
assert params is not None, (
|
||||
f"{self.kernel_name} not found in CudaKernelParamCache"
|
||||
)
|
||||
|
||||
if self.autotune_configs is not None:
|
||||
# This indicates the Triton kernel is a user-defined one.
|
||||
|
|
@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu):
|
|||
if V.graph.aot_mode and V.graph.inputs_to_check:
|
||||
for idx in V.graph.inputs_to_check:
|
||||
input_name = V.graph.graph_input_names[idx]
|
||||
assert (
|
||||
input_name in V.graph.graph_inputs
|
||||
), f"{input_name} not found in graph inputs"
|
||||
assert input_name in V.graph.graph_inputs, (
|
||||
f"{input_name} not found in graph inputs"
|
||||
)
|
||||
value = V.graph.graph_inputs[input_name]
|
||||
assert isinstance(
|
||||
value, TensorBox
|
||||
), f"{input_name} is expected to be tensor but found as {type(value)}"
|
||||
assert isinstance(value, TensorBox), (
|
||||
f"{input_name} is expected to be tensor but found as {type(value)}"
|
||||
)
|
||||
warn_msg = (
|
||||
f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
|
||||
"but it is not aligned at run time. Copying to an aligned tensor "
|
||||
|
|
|
|||
|
|
@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling):
|
|||
Codegen a CUDA template, possibly with fused epilogues
|
||||
"""
|
||||
counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
|
||||
assert self.is_cuda_cpp_template(
|
||||
template_node
|
||||
), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
|
||||
assert self.is_cuda_cpp_template(template_node), (
|
||||
"Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
|
||||
)
|
||||
template_node = cast(SchedulerNode, template_node)
|
||||
_, (_numel, rnumel) = template_node.group
|
||||
assert rnumel == 1
|
||||
|
|
|
|||
|
|
@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller):
|
|||
make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
|
||||
bmreq: CUDABenchmarkRequest,
|
||||
template: "CUDATemplate", # type: ignore[name-defined]
|
||||
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
|
||||
info_kwargs: Optional[
|
||||
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
|
||||
], # type: ignore[type-arg]
|
||||
description: str,
|
||||
) -> None:
|
||||
super().__init__(name, input_nodes, layout, description)
|
||||
|
|
|
|||
|
|
@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate):
|
|||
A CUDATemplateCaller object representing the generated CUDA template caller.
|
||||
"""
|
||||
kernel_name = f"cuda_{self.name}"
|
||||
with patch.object(
|
||||
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
|
||||
), CUDATemplateKernel(
|
||||
kernel_name=kernel_name,
|
||||
runtime_arg_info=self.get_runtime_arg_info(),
|
||||
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
||||
) as kernel:
|
||||
with (
|
||||
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
|
||||
CUDATemplateKernel(
|
||||
kernel_name=kernel_name,
|
||||
runtime_arg_info=self.get_runtime_arg_info(),
|
||||
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
||||
) as kernel,
|
||||
):
|
||||
code = self.render(kernel=kernel, **kwargs)
|
||||
_, call_args, _, _ = kernel.args.python_argdefs()
|
||||
autotuning_log.debug("Generated Code:\n%s", code)
|
||||
|
|
|
|||
|
|
@ -147,7 +147,9 @@ if try_import_cutlass():
|
|||
"element_d": DataTypeTag[operation.D.element], # type: ignore[name-defined]
|
||||
"layout_d": LayoutTag[instance_layout_D], # type: ignore[name-defined]
|
||||
"element_accumulator": DataTypeTag[operation.accumulator_type()], # type: ignore[name-defined]
|
||||
"opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class], # type: ignore[name-defined] # noqa: B950
|
||||
"opcode_class": OpcodeClassTag[ # type: ignore[name-defined]
|
||||
operation.tile_description.math_instruction.opcode_class
|
||||
],
|
||||
"arch": f"cutlass::arch::Sm{operation.arch:d}",
|
||||
"tile_shape_m": str(operation.tile_description.tile_shape[0]),
|
||||
"tile_shape_n": str(operation.tile_description.tile_shape[1]),
|
||||
|
|
@ -168,7 +170,9 @@ if try_import_cutlass():
|
|||
operation.tile_description.math_instruction.instruction_shape[2]
|
||||
),
|
||||
"kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]), # type: ignore[name-defined]
|
||||
"epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]), # type: ignore[name-defined]
|
||||
"epilogue_schedule": str(
|
||||
EpilogueScheduleTag[operation.epilogue_schedule] # type: ignore[name-defined]
|
||||
),
|
||||
"epilogue_functor": epilogue_functor,
|
||||
"stages": stage_count_string,
|
||||
"align_a": str(operation.A.alignment),
|
||||
|
|
|
|||
|
|
@ -56,9 +56,9 @@ def try_import_cutlass() -> bool:
|
|||
"Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
|
||||
)
|
||||
cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
|
||||
assert os.path.isdir(
|
||||
cutlass_library_dir
|
||||
), f"{cutlass_library_dir} is not a directory"
|
||||
assert os.path.isdir(cutlass_library_dir), (
|
||||
f"{cutlass_library_dir} is not a directory"
|
||||
)
|
||||
config.cuda.cutlass_dir = os.path.abspath(
|
||||
os.path.join(
|
||||
cutlass_library_dir,
|
||||
|
|
@ -86,9 +86,9 @@ def try_import_cutlass() -> bool:
|
|||
if os.path.isdir(cutlass_py_full_path):
|
||||
if tmp_cutlass_py_full_path not in sys.path:
|
||||
if os.path.exists(dst_link):
|
||||
assert os.path.islink(
|
||||
dst_link
|
||||
), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||
assert os.path.islink(dst_link), (
|
||||
f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
|
||||
)
|
||||
assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
|
||||
cutlass_py_full_path
|
||||
), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
|
||||
|
|
|
|||
|
|
@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
|||
import cutlass_library.gemm_operation as cutlass_gemm_op
|
||||
import cutlass_library.library as cutlass_lib
|
||||
|
||||
assert isinstance(
|
||||
op, cutlass_gemm_op.GemmOperation
|
||||
), "op argument is required and has to be an instance of GemmOperation"
|
||||
assert isinstance(op, cutlass_gemm_op.GemmOperation), (
|
||||
"op argument is required and has to be an instance of GemmOperation"
|
||||
)
|
||||
|
||||
assert len(self.input_nodes) >= 2 and self.output_node is not None
|
||||
X, W = self.input_nodes[0], self.input_nodes[1]
|
||||
|
|
@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
|
|||
else:
|
||||
input_reorder = None
|
||||
kernel_call_signature = kernel.def_kernel(
|
||||
inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder # type: ignore[arg-type]
|
||||
inputs=inputs, # type: ignore[arg-type]
|
||||
outputs=[Y],
|
||||
names_str=names_str,
|
||||
input_reorder=input_reorder,
|
||||
)
|
||||
test_call_statement = self.test_call_statement(kernel, inputs, names_str)
|
||||
# The layouts might have changed between autotuning and this call if they were FlexibleLayout
|
||||
|
|
|
|||
|
|
@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter):
|
|||
val, n = expr.args
|
||||
val = self._print(val)
|
||||
n = int(n)
|
||||
return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
|
||||
return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"
|
||||
|
||||
|
||||
texpr = HalidePrinter().doprint
|
||||
|
|
@ -856,11 +856,11 @@ class HalideKernel(SIMDKernel):
|
|||
for sym, size in added_sym_size:
|
||||
full_index += stride * sym
|
||||
stride *= size
|
||||
self.index_replacements[
|
||||
node.symbol()
|
||||
] = V.graph.sizevars.simplify_with_ranges(
|
||||
ModularIndexing(full_index, node.divisor, node.length),
|
||||
self.halide_vars, # type: ignore[arg-type]
|
||||
self.index_replacements[node.symbol()] = (
|
||||
V.graph.sizevars.simplify_with_ranges(
|
||||
ModularIndexing(full_index, node.divisor, node.length),
|
||||
self.halide_vars, # type: ignore[arg-type]
|
||||
)
|
||||
)
|
||||
|
||||
# codegen the variable definitions
|
||||
|
|
@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel):
|
|||
|
||||
if isinstance(value, tuple):
|
||||
assert reduction_type == "welford_combine"
|
||||
self.cse.reduction_cache[
|
||||
cache_key
|
||||
] = result_tuple = self.welford_combine_impl(*value)
|
||||
self.cse.reduction_cache[cache_key] = result_tuple = (
|
||||
self.welford_combine_impl(*value)
|
||||
)
|
||||
return result_tuple
|
||||
|
||||
assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
|
||||
|
|
@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel):
|
|||
scan = f"{scan_dom}.x"
|
||||
self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
|
||||
|
||||
assert (
|
||||
len(self.reduction_renames) == 1
|
||||
), "multi-dimensional scan not implemented"
|
||||
assert len(self.reduction_renames) == 1, (
|
||||
"multi-dimensional scan not implemented"
|
||||
)
|
||||
(scan_var,) = [*self.reduction_renames] # type: ignore[misc]
|
||||
scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
|
||||
scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}
|
||||
|
|
|
|||
|
|
@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol):
|
|||
get_size_hint: CachedMethod[[], int]
|
||||
get_symbolic_size: CachedMethod[[], sympy.Expr]
|
||||
|
||||
def _allocate(self, block: Allocation, is_last: bool) -> bool:
|
||||
...
|
||||
def _allocate(self, block: Allocation, is_last: bool) -> bool: ...
|
||||
|
||||
|
||||
class ClearCacheOnAllocateMixin(MemorySplitProtocol):
|
||||
|
|
|
|||
|
|
@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel):
|
|||
threads = [self.pexpr(v.numel) for v in self.active_range_trees()] # type: ignore[misc]
|
||||
args += [f"threads=[{', '.join(threads)}]"]
|
||||
if self.inside_reduction:
|
||||
threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()] # type: ignore[misc]
|
||||
threads = [
|
||||
self.pexpr(v.numel) if v.is_reduction else "1" # type: ignore[misc]
|
||||
for v in self.active_range_trees()
|
||||
]
|
||||
args += [f"group_size=[{', '.join(threads)}]"]
|
||||
|
||||
wrapper.generate_kernel_call(
|
||||
|
|
|
|||
|
|
@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None):
|
|||
all_args = max(args_list, key=len)[:]
|
||||
arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
|
||||
for args in args_list:
|
||||
assert OrderedSet(args).issubset(
|
||||
OrderedSet(all_args)
|
||||
), f"{args} v.s. {all_args}"
|
||||
assert OrderedSet(args).issubset(OrderedSet(all_args)), (
|
||||
f"{args} v.s. {all_args}"
|
||||
)
|
||||
|
||||
return all_args, arg_types
|
||||
|
||||
|
|
@ -149,7 +149,9 @@ class MultiKernel:
|
|||
Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
|
||||
The generated definition for the multi-kernel will looks like:
|
||||
```
|
||||
multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
|
||||
multi_kernel_kernel1 = MultiKernelCall(
|
||||
[kernel1, kernel2], multi_kernel_definition_code
|
||||
)
|
||||
```
|
||||
|
||||
Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
|
||||
|
|
|
|||
|
|
@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate):
|
|||
template_params=(",\n" + 12 * " ").join(template_params),
|
||||
), self._template_from_string(template_type).render(operation_name=op.name())
|
||||
|
||||
def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str: # type: ignore[override, name-defined]
|
||||
def render( # type: ignore[override]
|
||||
self,
|
||||
kernel: ROCmTemplateKernel,
|
||||
op: "CKGroupedConvFwdOp", # type: ignore[name-defined]
|
||||
**kwargs,
|
||||
) -> str:
|
||||
template_buffer_node = kwargs.get("template_buffer_node", None)
|
||||
if template_buffer_node is not None:
|
||||
self.output_node = template_buffer_node
|
||||
|
|
|
|||
|
|
@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate):
|
|||
operation_name=operation_name
|
||||
)
|
||||
|
||||
def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str: # type: ignore[override]
|
||||
def render( # type: ignore[override]
|
||||
self,
|
||||
kernel: ROCmTemplateKernel,
|
||||
op: "CKGemmOperation",
|
||||
**kwargs,
|
||||
) -> str:
|
||||
"""
|
||||
The primary entry point for the code rendering process used in this template.
|
||||
"""
|
||||
|
|
@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate):
|
|||
* Template instance {op}
|
||||
*
|
||||
* {torch.__version__=}
|
||||
* torch.version.git_version={getattr(torch.version, 'git_version', 'None')}
|
||||
* torch.version.git_version={getattr(torch.version, "git_version", "None")}
|
||||
*/
|
||||
"""
|
||||
epilogue = None
|
||||
|
|
|
|||
|
|
@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling):
|
|||
"""
|
||||
Codegen a ROCm template, possibly with fused epilogues
|
||||
"""
|
||||
assert self.is_rocm_cpp_template(
|
||||
template_node
|
||||
), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
|
||||
assert self.is_rocm_cpp_template(template_node), (
|
||||
"Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
|
||||
)
|
||||
template_node = cast(SchedulerNode, template_node)
|
||||
_, (_numel, rnumel) = template_node.group
|
||||
assert rnumel == 1
|
||||
|
|
|
|||
|
|
@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller):
|
|||
],
|
||||
bmreq: ROCmBenchmarkRequest,
|
||||
template: "ROCmTemplate", # type: ignore[name-defined]
|
||||
info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]], # type: ignore[type-arg]
|
||||
info_kwargs: Optional[
|
||||
dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
|
||||
], # type: ignore[type-arg]
|
||||
) -> None:
|
||||
super().__init__(name, input_nodes, layout, description="")
|
||||
self.category = category
|
||||
|
|
|
|||
|
|
@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate):
|
|||
"""
|
||||
kernel_name = f"rocm_{self.name}"
|
||||
kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
|
||||
with patch.object(
|
||||
V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
|
||||
), ROCmTemplateKernel(
|
||||
kernel_name=kernel_name,
|
||||
runtime_arg_info=self.get_runtime_arg_info(),
|
||||
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
||||
) as kernel:
|
||||
with (
|
||||
patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
|
||||
ROCmTemplateKernel(
|
||||
kernel_name=kernel_name,
|
||||
runtime_arg_info=self.get_runtime_arg_info(),
|
||||
runtime_arg_values=self.get_runtime_arg_values(**kwargs),
|
||||
) as kernel,
|
||||
):
|
||||
code = self.render(kernel=kernel, **kwargs)
|
||||
_, call_args, _, _ = kernel.args.python_argdefs()
|
||||
log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)
|
||||
|
|
|
|||
|
|
@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
|||
continue
|
||||
|
||||
while current_group < len(remaining) and sv.statically_known_equals(
|
||||
remaining[current_group], 1 # type: ignore[arg-type]
|
||||
remaining[current_group],
|
||||
1, # type: ignore[arg-type]
|
||||
):
|
||||
# scroll to next group with remaining elements
|
||||
current_group += 1
|
||||
|
|
@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
|||
)
|
||||
return_getters_groups.append(return_getters)
|
||||
|
||||
assert all(
|
||||
V.graph.sizevars.size_hint(s) == 1 for s in remaining
|
||||
), f"failed to set ranges {remaining} {lengths}"
|
||||
assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
|
||||
f"failed to set ranges {remaining} {lengths}"
|
||||
)
|
||||
|
||||
return new_ranges, return_getters_groups
|
||||
|
||||
|
|
@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
|||
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
|
||||
if len(replacements) > 0:
|
||||
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
|
||||
self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
|
||||
self.range_tree_nodes[sym].expr,
|
||||
replacements, # type: ignore[index]
|
||||
)
|
||||
self.range_tree_nodes[sym].codegen() # type: ignore[index]
|
||||
return expr
|
||||
|
|
@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling):
|
|||
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
|
||||
)
|
||||
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
|
||||
with config.patch(
|
||||
"benchmark_kernel", benchmark_kernel
|
||||
), V.set_kernel_handler(kernel):
|
||||
with (
|
||||
config.patch("benchmark_kernel", benchmark_kernel),
|
||||
V.set_kernel_handler(kernel),
|
||||
):
|
||||
src_code = kernel.codegen_kernel()
|
||||
else:
|
||||
prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(
|
||||
|
|
|
|||
|
|
@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
self.block_ptr_id = itertools.count()
|
||||
self.block_ptr_to_buffer = dict[str, str]()
|
||||
self.helper_functions = HelperFunctions()
|
||||
self.pointer_advancements: dict[
|
||||
SymT, dict[str, list[sympy.Expr]]
|
||||
] = collections.defaultdict(dict)
|
||||
self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
|
||||
collections.defaultdict(dict)
|
||||
)
|
||||
self._load_counts: collections.Counter[str] = collections.Counter()
|
||||
|
||||
# A set of autotuning hints to pass as part of triton_meta
|
||||
|
|
@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
continue
|
||||
|
||||
advancements = self.pointer_advancements[symt]
|
||||
assert (
|
||||
block_ptr not in advancements
|
||||
), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
|
||||
assert block_ptr not in advancements, (
|
||||
"duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
|
||||
)
|
||||
advancements[block_ptr] = advance_offsets
|
||||
else:
|
||||
block_ptr = indexing.format(var)
|
||||
|
|
@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
buffer.splice(
|
||||
f"""\
|
||||
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
|
||||
{result_var} = {self.reduction_resize(f'{result_var}_idx')}
|
||||
{result_var} = {self.reduction_resize(f"{result_var}_idx")}
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
|
||||
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
|
||||
)
|
||||
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
|
||||
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
|
||||
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
|
||||
{accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
|
||||
"""
|
||||
)
|
||||
final_argreduce(
|
||||
|
|
@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
)
|
||||
self.compute.splice(
|
||||
f"""\
|
||||
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
|
||||
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
|
||||
{accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
|
||||
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
|
||||
{accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
|
||||
{accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
|
||||
"""
|
||||
)
|
||||
result_mean = result_var
|
||||
|
|
@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
self.filter_masks(masks)
|
||||
masks = sorted(masks)
|
||||
assert not self._load_mask, "ops.sort not supported inside ops.masked"
|
||||
assert (
|
||||
self.persistent_reduction
|
||||
), "ops.sort is only supported in persistent reductions"
|
||||
assert self.persistent_reduction, (
|
||||
"ops.sort is only supported in persistent reductions"
|
||||
)
|
||||
|
||||
cse_compute = functools.partial(self.cse.generate, self.compute)
|
||||
dim = self.triton_tensor_ndim() - self.num_reduction_dims
|
||||
|
|
@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
{}
|
||||
import torch
|
||||
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
|
||||
""".format(
|
||||
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
||||
)
|
||||
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
|
||||
)
|
||||
|
||||
def _get_heuristic(self):
|
||||
|
|
@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
|||
inductor_meta["profile_bandwidth"] = config.profile_bandwidth
|
||||
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
|
||||
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
|
||||
inductor_meta[
|
||||
"profile_bandwidth_with_do_bench_using_profiling"
|
||||
] = config.profile_bandwidth_with_do_bench_using_profiling
|
||||
inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
|
||||
config.profile_bandwidth_with_do_bench_using_profiling
|
||||
)
|
||||
if config.coordinate_descent_tuning:
|
||||
inductor_meta[
|
||||
"coordinate_descent_tuning"
|
||||
] = config.coordinate_descent_tuning
|
||||
inductor_meta[
|
||||
"coordinate_descent_search_radius"
|
||||
] = config.coordinate_descent_search_radius
|
||||
inductor_meta[
|
||||
"coordinate_descent_check_all_directions"
|
||||
] = config.coordinate_descent_check_all_directions
|
||||
inductor_meta["coordinate_descent_tuning"] = (
|
||||
config.coordinate_descent_tuning
|
||||
)
|
||||
inductor_meta["coordinate_descent_search_radius"] = (
|
||||
config.coordinate_descent_search_radius
|
||||
)
|
||||
inductor_meta["coordinate_descent_check_all_directions"] = (
|
||||
config.coordinate_descent_check_all_directions
|
||||
)
|
||||
return inductor_meta
|
||||
|
||||
def codegen_kernel(self, name=None):
|
||||
|
|
@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling):
|
|||
) -> tuple[float, str]:
|
||||
"""Benchmark an already compiled module"""
|
||||
device_interface = get_interface_for_device(V.graph.device_type)
|
||||
with preserve_rng_state(), device_interface.device(
|
||||
V.graph.get_current_device_or_throw()
|
||||
): # type: ignore[attr-defined]
|
||||
with (
|
||||
preserve_rng_state(),
|
||||
device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined]
|
||||
):
|
||||
ms = None
|
||||
|
||||
def cache_file_path():
|
||||
|
|
@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
device = node.get_device()
assert device is not None
backend = node.scheduler.get_backend(device)
assert isinstance(
backend, (SIMDScheduling, CUDACombinedScheduling)
), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
)

with V.graph.set_current_device(device):
# Don't increment kernel count when generating debug string.
@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition(
|
|||
# rnumel > 2048 usually has long execution time
|
||||
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
|
||||
long_reduction = [
|
||||
n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
||||
n
|
||||
for n in reduction
|
||||
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
||||
]
|
||||
short_reduction = [n for n in reduction if n not in long_reduction]
|
||||
if long_reduction:
@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition(
|
|||
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
|
||||
],
|
||||
list[list[BaseSchedulerNode]],
|
||||
]
|
||||
],
|
||||
) -> None:
|
||||
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
|
||||
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
@ -593,9 +595,9 @@ class ComboKernel(Kernel):
|
|||
num_persistent_reduction = len(
|
||||
[e for e in heuristics_list if e == "persistent_reduction"]
|
||||
)
|
||||
assert (
|
||||
num_reduction == 0
|
||||
), "combining pointwise and reduction are not supported yet."
|
||||
assert num_reduction == 0, (
|
||||
"combining pointwise and reduction are not supported yet."
|
||||
)
|
||||
heuristics = (
|
||||
"pointwise_with_reduction"
|
||||
if num_persistent_reduction > 0
@ -784,13 +786,13 @@ class ComboKernel(Kernel):
|
|||
name, tree, suffix=str(num)
|
||||
)
|
||||
if not tree.is_reduction:
|
||||
assert isinstance(
|
||||
grid[i][num], str
|
||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
assert isinstance(grid[i][num], str), (
|
||||
f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
)
|
||||
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
||||
assert (
|
||||
grid[i][num] == numel_sign + numel_name
|
||||
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
|
||||
assert grid[i][num] == numel_sign + numel_name, (
|
||||
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
|
||||
)
|
||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||
|
||||
if not tree.is_reduction or sub_kernel.inside_reduction:
@ -807,13 +809,13 @@ class ComboKernel(Kernel):
|
|||
continue
|
||||
expr = V.graph.sizevars.size_hint(tree.numel)
|
||||
if not tree.is_reduction:
|
||||
assert isinstance(
|
||||
grid[i][num], str
|
||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
assert isinstance(grid[i][num], str), (
|
||||
f"Grid {grid[i][num]} should be a dynamic shape."
|
||||
)
|
||||
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
||||
assert (
|
||||
grid[i][num] == numel_sign + numel_name
|
||||
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
|
||||
assert grid[i][num] == numel_sign + numel_name, (
|
||||
f"grid mismatch: {grid[i][num]} vs {numel_name}"
|
||||
)
|
||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||
extra_args.append(expr)
@ -1015,9 +1017,7 @@ class ComboKernel(Kernel):
|
|||
{}
|
||||
import torch
|
||||
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
|
||||
""".format(
|
||||
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
||||
)
|
||||
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
|
||||
)
|
||||
|
||||
def uniquify_block_sizes(
@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel):
|
|||
|
||||
def initialize_range_tree(self, pid_cache):
|
||||
prefixes = ["y", "x", "r0_"]
|
||||
assert len(self.numels) <= len(
|
||||
prefixes
|
||||
), "z dimension not supported for split scan"
|
||||
assert len(self.numels) <= len(prefixes), (
|
||||
"z dimension not supported for split scan"
|
||||
)
|
||||
active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
|
||||
|
||||
grid_dims = {"r0_": 0, "x": 1, "y": 2}
@ -184,7 +184,8 @@ def config_of(
|
|||
if isinstance(x, TensorArg):
|
||||
if include_tensor:
|
||||
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
|
||||
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
|
||||
x.offset * x.dtype.itemsize,
|
||||
alignment, # type: ignore[arg-type]
|
||||
)
|
||||
return offset_aligned and not is_unaligned_buffer(x)
|
||||
else:
@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
|
|||
# NB: this is symbolic so that we don't try to reuse a buffer
|
||||
# for s0 for s1, just because they happen to share the same
|
||||
# size hint
|
||||
sympy_str(input_size)
|
||||
== sympy_str(output_size)
|
||||
sympy_str(input_size) == sympy_str(output_size)
|
||||
) or (
|
||||
# statically known that 0.95 * input_size <= output_size <= input_size
|
||||
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str:
|
|||
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
|
||||
if len(container_match) == 1:
|
||||
contained_type = container_match[0]
|
||||
assert (
|
||||
contained_type in PYTHON_TO_CPP
|
||||
), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
|
||||
assert contained_type in PYTHON_TO_CPP, (
|
||||
f"unsupported {py_container} type in convert_arg_type: {contained_type}"
|
||||
)
|
||||
cpp_contained_type = PYTHON_TO_CPP[contained_type]
|
||||
return f"{cpp_container}<{cpp_contained_type}>"
@ -367,9 +366,9 @@ class SymbolicCallArg:
|
|||
class MemoryPlanningState:
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.reuse_pool: dict[
|
||||
ReuseKey, list[FreeIfNotReusedLine]
|
||||
] = collections.defaultdict(list)
|
||||
self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
|
||||
collections.defaultdict(list)
|
||||
)
|
||||
self.total_allocated_buffer_size: int = 0
|
||||
|
||||
def __contains__(self, key: ReuseKey) -> bool:
@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
|
|||
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
self.last_seen_device_guard_index == self.device_idx
|
||||
), "AOTInductor only supports running on one CUDA device"
|
||||
assert self.last_seen_device_guard_index == self.device_idx, (
|
||||
"AOTInductor only supports running on one CUDA device"
|
||||
)
|
||||
else:
|
||||
if self.last_seen_device_guard_index is None:
|
||||
code.writeline(
@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen):
|
|||
equals_1 = isinstance(
|
||||
arg, (int, sympy.Integer)
|
||||
) and V.graph.sizevars.statically_known_equals(
|
||||
arg, 1 # type: ignore[arg-type]
|
||||
arg,
|
||||
1, # type: ignore[arg-type]
|
||||
)
|
||||
add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||
buf_name = arg
|
||||
buf = V.graph.get_buffer(arg)
|
||||
else:
|
||||
assert (
|
||||
raw_arg is not None
|
||||
), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
|
||||
assert raw_arg is not None, (
|
||||
"V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
|
||||
)
|
||||
buf_name = f"tmp_arg_{index}"
|
||||
buf = raw_arg
@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||
and kernel_name not in self.kernel_autotune_names
|
||||
):
|
||||
# Create example args for autotune in a separate epilogue
|
||||
assert arg_types is not None and len(call_args) == len(
|
||||
arg_types
|
||||
), "call_args and arg_types do not match"
|
||||
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||
"call_args and arg_types do not match"
|
||||
)
|
||||
|
||||
tensor_args = {}
|
||||
all_args = []
@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||
# create a dummy raw_args for uniform behavior in the following loop
|
||||
raw_args = [None] * len(call_args)
|
||||
else:
|
||||
assert len(raw_args) == len(
|
||||
call_args
|
||||
), "call_args and raw_args do not match"
|
||||
assert len(raw_args) == len(call_args), (
|
||||
"call_args and raw_args do not match"
|
||||
)
|
||||
|
||||
for i, (arg, arg_type, raw_arg) in enumerate(
|
||||
zip(call_args, arg_types, raw_args)
@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||
if isinstance(layout, ir.NoneLayout):
|
||||
return
|
||||
if isinstance(layout, ir.NonOwningLayout):
|
||||
assert isinstance(
|
||||
layout.view, ir.ReinterpretView
|
||||
), f"unexpected {type(layout.view)}: {layout.view}"
|
||||
assert isinstance(layout.view, ir.ReinterpretView), (
|
||||
f"unexpected {type(layout.view)}: {layout.view}"
|
||||
)
|
||||
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
|
||||
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
|
||||
self.codegen_allocation(layout.view.data.data)
@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen):
|
|||
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
|
||||
# All inputs of hops must be explicitly passed in.
|
||||
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
|
||||
assert len(outer_inputs) == len(
|
||||
subgraph.graph.graph_input_names
|
||||
), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
|
||||
assert len(outer_inputs) == len(subgraph.graph.graph_input_names), (
|
||||
f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
|
||||
)
|
||||
for inner_input, outer_input in zip(
|
||||
subgraph.graph.graph_input_names, outer_inputs
|
||||
):
@ -219,8 +219,7 @@ def _schedule_for_comm(
|
|||
|
||||
for snode, deps in unmet_deps.items():
|
||||
assert len(deps) == 0, (
|
||||
"Detected unscheduled nodes. "
|
||||
f"Nodes with unmet dependencies: {unmet_deps}"
|
||||
f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
|
||||
)
|
||||
return scheduled
@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
|
|||
node.op == "call_function"
|
||||
and node.target == torch.ops.inductor.resize_storage_bytes_.default
|
||||
):
|
||||
assert (
|
||||
node.args[0].op == "placeholder"
|
||||
), f"""\
|
||||
assert node.args[0].op == "placeholder", f"""\
|
||||
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
|
||||
"""
|
||||
graph_input = node.args[0]
@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
|
|||
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
|
||||
fsdp_copy_node = node
|
||||
unsharded_param = node.args[0]
|
||||
assert (
|
||||
unsharded_param.op == "placeholder"
|
||||
), f"""
|
||||
assert unsharded_param.op == "placeholder", f"""
|
||||
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
|
||||
Offending node: {unsharded_param}. Graph: {graph}
|
||||
"""
@ -281,9 +281,9 @@ def _unlift_graph(
|
|||
elif node_name in graph_signature.inputs_to_buffers:
|
||||
buffer_name = graph_signature.inputs_to_buffers[node_name]
|
||||
lifted_inputs.append(buffer_name)
|
||||
gm.meta[
|
||||
get_cloned_parameter_buffer_name(buffer_name)
|
||||
] = clone_preserve_strides(state_dict[buffer_name])
|
||||
gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
|
||||
clone_preserve_strides(state_dict[buffer_name])
|
||||
)
|
||||
else:
|
||||
assert node_name in graph_signature.user_inputs
|
||||
lifted_inputs.append(None)
@ -542,7 +542,7 @@ def fake_tensor_prop(
|
|||
|
||||
# pass config dict back to user
|
||||
def get_patched_config_dict(
|
||||
config_patches: Optional[Union[str, dict[str, Any]]] = None
|
||||
config_patches: Optional[Union[str, dict[str, Any]]] = None,
|
||||
) -> dict[str, Any]:
|
||||
with config.patch(config_patches):
|
||||
return config.get_config_copy()
@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol):
|
|||
gm: GraphModule,
|
||||
example_inputs: Sequence[InputType],
|
||||
**kwargs: Unpack[_CompileFxKwargs],
|
||||
) -> OutputCode:
|
||||
...
|
||||
) -> OutputCode: ...
|
||||
|
||||
|
||||
def compile_fx_inner(
@ -662,9 +661,9 @@ def _compile_fx_inner(
|
|||
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
|
||||
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
|
||||
|
||||
assert isinstance(
|
||||
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
|
||||
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
|
||||
assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
|
||||
f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
|
||||
)
|
||||
|
||||
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
|
||||
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
@ -679,9 +678,10 @@ def _compile_fx_inner(
|
|||
|
||||
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
|
||||
|
||||
with _WaitCounter(
|
||||
"pytorch.wait_counter.fx_codegen_and_compile"
|
||||
).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard():
|
||||
with (
|
||||
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
|
||||
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
|
||||
):
|
||||
use_cache = (
|
||||
not config.force_disable_caches
|
||||
and (config.fx_graph_cache or fx_graph_remote_cache)
@ -865,8 +865,7 @@ class FxCompile(ABC):
|
|||
example_inputs: Sequence[InputType],
|
||||
inputs_to_check: Sequence[int],
|
||||
graph_kwargs: _CompileFxKwargs,
|
||||
) -> OutputCode:
|
||||
...
|
||||
) -> OutputCode: ...
|
||||
|
||||
|
||||
class _InProcessFxCompile(FxCompile):
@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile):
|
|||
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
|
||||
aot_mode: bool = V.aot_compilation
|
||||
is_inference: bool = graph_kwargs.get("is_inference", False)
|
||||
extern_node_serializer: Optional[
|
||||
Callable[[list[ExternKernelNode]], Any]
|
||||
] = graph_kwargs.get("extern_node_serializer", None)
|
||||
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
|
||||
graph_kwargs.get("extern_node_serializer", None)
|
||||
)
|
||||
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
|
||||
"boxed_forward_device_index", None
|
||||
)
|
||||
|
||||
with _WaitCounter(
|
||||
"pytorch.wait_counter.actual_codegen_and_compile"
|
||||
).guard(), dynamo_utils.preserve_rng_state():
|
||||
with (
|
||||
_WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
|
||||
dynamo_utils.preserve_rng_state(),
|
||||
):
|
||||
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
|
||||
import time
|
||||
@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile):
|
|||
# See details in vllm/compilation/pass_manager.py.
|
||||
log.warning("failed to log pt2_configs")
|
||||
|
||||
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
|
||||
example_inputs
|
||||
), maybe_disable_graph_partition(cpp_wrapper, aot_mode):
|
||||
with (
|
||||
V.set_fake_mode(fake_mode),
|
||||
maybe_disable_comprehensive_padding(example_inputs),
|
||||
maybe_disable_graph_partition(cpp_wrapper, aot_mode),
|
||||
):
|
||||
const_output_index = None
|
||||
const_graph = None
|
||||
const_code = None
@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile):
|
|||
if graph.aot_mode:
|
||||
from .codecache import AotCodeCompiler
|
||||
|
||||
assert (
|
||||
graph.cpp_wrapper
|
||||
), "AOT mode only supports C++ wrapper"
|
||||
assert graph.cpp_wrapper, (
|
||||
"AOT mode only supports C++ wrapper"
|
||||
)
|
||||
code, linemap = graph.codegen_with_cpp_wrapper()
|
||||
output_code_log.debug("Output code: \n%s", code)
|
||||
@ -1509,10 +1511,13 @@ def cudagraphify(
|
|||
def run(new_inputs: Sequence[InputType]) -> Any:
|
||||
nonlocal compiled_fn
|
||||
if compiled_fn is None:
|
||||
with dynamo_utils.dynamo_timed(
|
||||
"cudagraphify",
|
||||
log_pt2_compile_event=True,
|
||||
), dynamo_utils.preserve_rng_state():
|
||||
with (
|
||||
dynamo_utils.dynamo_timed(
|
||||
"cudagraphify",
|
||||
log_pt2_compile_event=True,
|
||||
),
|
||||
dynamo_utils.preserve_rng_state(),
|
||||
):
|
||||
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
|
||||
return compiled_fn(new_inputs)
@ -1669,13 +1674,16 @@ def compile_fx_aot(
|
|||
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
|
||||
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
|
||||
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
|
||||
with V.set_aot_compilation(True), torch._guards.compile_context(
|
||||
saved_compile_context
|
||||
), chromium_event_timed(
|
||||
"compile_fx_aot",
|
||||
log_pt2_compile_event=True,
|
||||
reset_event_log_on_exit=True,
|
||||
), get_metrics_context():
|
||||
with (
|
||||
V.set_aot_compilation(True),
|
||||
torch._guards.compile_context(saved_compile_context),
|
||||
chromium_event_timed(
|
||||
"compile_fx_aot",
|
||||
log_pt2_compile_event=True,
|
||||
reset_event_log_on_exit=True,
|
||||
),
|
||||
get_metrics_context(),
|
||||
):
|
||||
compiled_artifacts = compile_fx(
|
||||
model_,
|
||||
example_inputs_,
@ -1875,12 +1883,15 @@ def compile_fx(
|
|||
|
||||
# TODO: This probably shouldn't be a recursive call
|
||||
if config.cpp_wrapper:
|
||||
with config.patch(
|
||||
{
|
||||
"cpp_wrapper": False, # reset to break recursive call to compile_fx
|
||||
**get_cpp_wrapper_config(),
|
||||
}
|
||||
), V.set_real_inputs(example_inputs_):
|
||||
with (
|
||||
config.patch(
|
||||
{
|
||||
"cpp_wrapper": False, # reset to break recursive call to compile_fx
|
||||
**get_cpp_wrapper_config(),
|
||||
}
|
||||
),
|
||||
V.set_real_inputs(example_inputs_),
|
||||
):
|
||||
inputs_: Sequence[InputType] = example_inputs_
|
||||
|
||||
if isinstance(model_, GraphModule):
@ -1940,10 +1951,10 @@ def compile_fx(
|
|||
|
||||
# Do the actual work
|
||||
|
||||
with _use_lazy_graph_module(
|
||||
dynamo_config.use_lazy_graph_module
|
||||
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta(
|
||||
config.trace.enabled
|
||||
with (
|
||||
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
|
||||
enable_python_dispatcher(),
|
||||
torch.fx.traceback.preserve_node_meta(config.trace.enabled),
|
||||
):
|
||||
# Pre-grad passes cannot be run if we weren't given a GraphModule.
|
||||
# Dynamo will always produce a GraphModule, but this handles cases
@ -2085,9 +2096,9 @@ def compile_fx(
|
|||
boxed_forward_device_index=forward_device,
|
||||
)
|
||||
|
||||
fw_compiler: Callable[
|
||||
[GraphModule, Sequence[InputType]], OutputCode
|
||||
] = functools.partial(fw_compiler_base, is_inference=False)
|
||||
fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
|
||||
functools.partial(fw_compiler_base, is_inference=False)
|
||||
)
|
||||
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
|
||||
|
||||
if config.freezing and not torch.is_grad_enabled():
@ -2124,9 +2135,10 @@ def compile_fx(
|
|||
) -> OutputCode:
|
||||
from torch._dynamo.convert_frame import compile_lock
|
||||
|
||||
with dynamo_utils.dynamo_timed(
|
||||
"compile_fx.<locals>.bw_compiler"
|
||||
), compile_lock:
|
||||
with (
|
||||
dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
|
||||
compile_lock,
|
||||
):
|
||||
model_outputs_node = output_node(gm)
|
||||
if config.bw_outputs_user_visible:
|
||||
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
@ -2194,10 +2206,11 @@ def compile_fx(
|
|||
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
|
||||
return inference_compiler(unlifted_gm, example_inputs_)
|
||||
|
||||
with V.set_fake_mode(fake_mode), torch._guards.tracing(
|
||||
tracing_context
|
||||
), compiled_autograd._disable(), functorch_config.patch(
|
||||
unlift_effect_tokens=True
|
||||
with (
|
||||
V.set_fake_mode(fake_mode),
|
||||
torch._guards.tracing(tracing_context),
|
||||
compiled_autograd._disable(),
|
||||
functorch_config.patch(unlift_effect_tokens=True),
|
||||
):
|
||||
try:
|
||||
return aot_autograd(
@ -530,7 +530,8 @@ class CompilerBisector:
|
|||
)
|
||||
if result:
|
||||
curr_subsystem = cls.get_subsystem_object(
|
||||
curr_backend, cls.get_subsystem() # type: ignore[arg-type]
|
||||
curr_backend,
|
||||
cls.get_subsystem(), # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
if isinstance(curr_subsystem, BinarySubsystem):
@ -80,9 +80,9 @@ fx_graph_cache: bool = Config(
|
|||
fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
|
||||
|
||||
# should we bundle triton caching into fx graph cache
|
||||
bundle_triton_into_fx_graph_cache: Optional[
|
||||
bool
|
||||
] = bundle_triton_into_fx_graph_cache_default()
|
||||
bundle_triton_into_fx_graph_cache: Optional[bool] = (
|
||||
bundle_triton_into_fx_graph_cache_default()
|
||||
)
|
||||
|
||||
# Enable autotune local cache.
|
||||
#
@ -1390,12 +1390,12 @@ class halide:
|
|||
|
||||
# Halide autoscheduler to use, choices are:
|
||||
# "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
|
||||
scheduler_cuda: Literal[
|
||||
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
|
||||
] = "Anderson2021"
|
||||
scheduler_cpu: Literal[
|
||||
"Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
|
||||
] = "Adams2019"
|
||||
scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
|
||||
"Anderson2021"
|
||||
)
|
||||
scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
|
||||
"Adams2019"
|
||||
)
|
||||
|
||||
# Controls `no_asserts` flag passed to Halide target (warning: can false positive)
|
||||
asserts = False
@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter):
|
|||
and is_woq_int8_pattern(next(iter(node.users)))
|
||||
)
|
||||
) and is_const_source(
|
||||
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
|
||||
node.args[0], # type: ignore[arg-type]
|
||||
self.lifted_constant_names,
|
||||
):
|
||||
# Case 1: int8_weight -> dq -> bf16_weight
|
||||
# Case 2: int8_weight -> permute -> dq -> bf16_weight
@ -1633,8 +1633,8 @@ class CppBuilder:
|
|||
"""
|
||||
)
|
||||
|
||||
assert os.path.exists(
|
||||
cmake_path
|
||||
), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
|
||||
assert os.path.exists(cmake_path), (
|
||||
f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
|
||||
)
|
||||
with open(cmake_path, "a") as f:
|
||||
f.write(contents)
@ -119,6 +119,7 @@ from . import config
@dataclasses.dataclass(frozen=True)
class GraphID:
"Unique counter of a cuda graph recording"

id: int
@ -622,11 +623,15 @@ class CUDAWarmupNode:
|
|||
refs = list(self.path_live_weakrefs())
|
||||
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
|
||||
|
||||
with torch.cuda.device(
|
||||
self.device_index
|
||||
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
|
||||
self.device_index, self.cuda_graphs_pool, self.stream
|
||||
), get_history_recording():
|
||||
with (
|
||||
torch.cuda.device(self.device_index),
|
||||
disable_conv_cache_emptying(),
|
||||
clear_cublas_manager(),
|
||||
_use_cuda_memory_pool_manager(
|
||||
self.device_index, self.cuda_graphs_pool, self.stream
|
||||
),
|
||||
get_history_recording(),
|
||||
):
|
||||
out = self.wrapped_function.model(new_inputs)
|
||||
|
||||
# We need to know which outputs are allocated within the cudagraph pool
@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage()
|
|||
|
||||
class AliasesPriorGraphOutput(OutputAliasInfo):
|
||||
"Marks that the graph output aliases an output of a prior graph"
|
||||
|
||||
__slots__ = ["index"]
|
||||
|
||||
index: PathOutputIndex
@ -1200,14 +1206,18 @@ class CUDAGraphNode:
|
|||
]
|
||||
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
|
||||
|
||||
with preserve_rng_state(), torch.cuda.device(
|
||||
self.device
|
||||
), clear_cublas_manager(), torch.cuda.graph(
|
||||
self.graph,
|
||||
stream=self.stream,
|
||||
pool=self.cuda_graphs_pool,
|
||||
capture_error_mode="thread_local",
|
||||
), get_history_recording():
|
||||
with (
|
||||
preserve_rng_state(),
|
||||
torch.cuda.device(self.device),
|
||||
clear_cublas_manager(),
|
||||
torch.cuda.graph(
|
||||
self.graph,
|
||||
stream=self.stream,
|
||||
pool=self.cuda_graphs_pool,
|
||||
capture_error_mode="thread_local",
|
||||
),
|
||||
get_history_recording(),
|
||||
):
|
||||
static_outputs = model(inputs)
|
||||
|
||||
# running model should reclaim memory
@ -1247,11 +1257,13 @@ class CUDAGraphNode:
|
|||
self.output_storage_alias.append(UnaliasedStorage)
|
||||
continue
|
||||
|
||||
torch._check(
|
||||
o.is_cuda or o.untyped_storage().data_ptr() == 0,
|
||||
lambda: (
|
||||
"Expected all cuda outputs in cuda graph recording. Non cuda output "
|
||||
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
|
||||
(
|
||||
torch._check(
|
||||
o.is_cuda or o.untyped_storage().data_ptr() == 0,
|
||||
lambda: (
|
||||
"Expected all cuda outputs in cuda graph recording. Non cuda output "
|
||||
f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
|
||||
),
|
||||
),
|
||||
)
|
||||
@ -1291,9 +1303,9 @@ class CUDAGraphNode:
|
|||
if self.stack_traces is None:
|
||||
self.stack_traces = [None for _ in range(len(outputs))]
|
||||
else:
|
||||
assert len(self.stack_traces) == len(
|
||||
outputs
|
||||
), "Wrong number of stack traces passed in"
|
||||
assert len(self.stack_traces) == len(outputs), (
|
||||
"Wrong number of stack traces passed in"
|
||||
)
|
||||
|
||||
assert not self.outputs_weakrefs
|
||||
for out, static_output_tensor in zip(outputs, self.static_output_tensors):
@ -1599,12 +1611,14 @@ class CUDAGraphNode:
|
|||
self.stream.wait_stream(torch.cuda.current_stream())
|
||||
recording_inputs: list[InputType] = []
|
||||
|
||||
with warnings.catch_warnings(record=True), torch.cuda.device(
|
||||
self.device
|
||||
), _use_cuda_memory_pool_manager(
|
||||
self.device,
|
||||
mem_pool=self.cuda_graphs_pool,
|
||||
stream=self.stream,
|
||||
with (
|
||||
warnings.catch_warnings(record=True),
|
||||
torch.cuda.device(self.device),
|
||||
_use_cuda_memory_pool_manager(
|
||||
self.device,
|
||||
mem_pool=self.cuda_graphs_pool,
|
||||
stream=self.stream,
|
||||
),
|
||||
):
|
||||
for i, inp in enumerate(inputs):
|
||||
if not isinstance(inp, torch.Tensor):
@ -1736,12 +1750,8 @@ def check_memory_pool(
|
|||
pool_id: tuple[int, int],
|
||||
live_storages_ptrs: list[StorageWeakRefWrapper],
|
||||
) -> None:
|
||||
assert all(
|
||||
isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
|
||||
) # noqa: C419
|
||||
unique_storages = {
|
||||
stor.data_ptr() for stor in live_storages_ptrs if stor()
|
||||
} # noqa: set_linter
|
||||
assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs) # noqa: C419
|
||||
unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()} # noqa: set_linter
|
||||
|
||||
# check if there is a divergence first, then do the expensive snapshot call after
|
||||
# we know it will error
@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager:
|
|||
self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
|
||||
self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()
|
||||
|
||||
with warnings.catch_warnings(record=True), torch.cuda.graph(
|
||||
self.graph,
|
||||
pool=self.cuda_graphs_thread_pool,
|
||||
stream=self.stream,
|
||||
capture_error_mode="thread_local",
|
||||
with (
|
||||
warnings.catch_warnings(record=True),
|
||||
torch.cuda.graph(
|
||||
self.graph,
|
||||
pool=self.cuda_graphs_thread_pool,
|
||||
stream=self.stream,
|
||||
capture_error_mode="thread_local",
|
||||
),
|
||||
):
|
||||
pass
|
||||
@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager:
|
|||
constants: tuple[torch.Tensor, ...],
|
||||
placeholders: tuple[PlaceholderInfo, ...],
|
||||
mutated_input_idxs: tuple[int, ...],
|
||||
) -> tuple[ModelType, OutputType,]:
|
||||
) -> tuple[
|
||||
ModelType,
|
||||
OutputType,
|
||||
]:
|
||||
id = self.new_func_id()
|
||||
self.ids_to_stack_traces[id] = stack_traces
|
||||
self.ids_to_funcs[id] = WrappedFunction(
@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType]
@dataclasses.dataclass(frozen=True)
class FunctionID:
"Unique counter of a function wrapped in cudagraphify_impl"

id: int
@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:
|
|||
|
||||
|
||||
def check_multiple_devices_or_any_cpu_nodes(
|
||||
device_node_mapping: dict[torch.device, torch.fx.Node]
|
||||
device_node_mapping: dict[torch.device, torch.fx.Node],
|
||||
) -> Optional[str]:
|
||||
if cpu_node := device_node_mapping.get(torch.device("cpu")):
|
||||
msg = f"cpu device ({cpu_node.name})"
@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes(
|
|||
|
||||
|
||||
def check_lowering_disable_cudagraph(
|
||||
device_node_mapping: dict[torch.device, torch.fx.Node]
|
||||
device_node_mapping: dict[torch.device, torch.fx.Node],
|
||||
) -> Optional[str]:
|
||||
return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)
@ -276,9 +277,9 @@ def log_data_ptr_mismatch(
|
|||
Logs the mismatch between input data pointers and recorded data pointers.
|
||||
This checks only idxs in target_idxs.
|
||||
"""
|
||||
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
|
||||
placeholders
|
||||
), "length mismatch between inputs, recorded_data_ptr, and placeholders"
|
||||
assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), (
|
||||
"length mismatch between inputs, recorded_data_ptr, and placeholders"
|
||||
)
|
||||
|
||||
t_tensors = [inputs[i] for i in target_idxs]
|
||||
t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]
@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name(
|
|||
|
||||
|
||||
def get_node_name_to_buf_meta(
|
||||
node_name_to_buf_name: dict[str, str]
|
||||
node_name_to_buf_name: dict[str, str],
|
||||
) -> dict[str, BufMeta]:
|
||||
buf_name_to_n_node = {}
|
||||
for node_name, buf_name in node_name_to_buf_name.items():
@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)
|
|||
|
||||
|
||||
def register_decomposition(
|
||||
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
|
||||
ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
|
||||
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
|
||||
for op in [ops] if callable(ops) else ops: # type: ignore[attr-defined]
|
||||
if op in decompositions:
@ -194,7 +194,9 @@ class MemoryDep(Dep):
|
|||
)
|
||||
new_index = sympy_subs(sympy.expand(self.index), replacement) # type: ignore[arg-type] # next PR
|
||||
|
||||
out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())) # type: ignore[arg-type]
|
||||
out = MemoryDep(
|
||||
self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
|
||||
) # type: ignore[arg-type]
|
||||
return out
|
||||
|
||||
@property
@ -649,11 +651,16 @@ def extract_loop_body_with_args(
|
|||
inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name])) # type: ignore[arg-type]
|
||||
for entry in fn.memory_usage[MemoryUsageType.STORE]:
|
||||
inner.store(
|
||||
entry.buffer_name, name_to_index[entry.index_name], None, entry.mode # type: ignore[arg-type]
|
||||
entry.buffer_name,
|
||||
name_to_index[entry.index_name],
|
||||
None, # type: ignore[arg-type]
|
||||
entry.mode,
|
||||
)
|
||||
for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
|
||||
inner.store_reduction(
|
||||
entry.buffer_name, name_to_index[entry.index_name], None # type: ignore[arg-type]
|
||||
entry.buffer_name,
|
||||
name_to_index[entry.index_name],
|
||||
None, # type: ignore[arg-type]
|
||||
)
|
||||
for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
|
||||
inner.index_expr(name_to_index[entry.index_name], None)
@ -661,7 +668,11 @@ def extract_loop_body_with_args(
|
|||
# All that matters is that we record the buffer name, so place it in the
|
||||
# "boundaries" name position to ensure that it's recorded.
|
||||
inner.bucketize(
|
||||
None, (entry.buffer_name, None, None, None), None, None, None # type: ignore[arg-type]
|
||||
None,
|
||||
(entry.buffer_name, None, None, None),
|
||||
None,
|
||||
None, # type: ignore[arg-type]
|
||||
None, # type: ignore[arg-type]
|
||||
)
|
||||
# fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
|
||||
return inner
@ -801,8 +812,9 @@ def extract_free_unbacked_symbols(
|
|||
handler = FreeUnbackedSymbolsOpsHandler()
|
||||
# NB: I cargo culted the allow_indexing patch here, I don't understand why
|
||||
# people do this all over
|
||||
with V.set_ops_handler(handler), patch.object(
|
||||
FlexibleLayout, "allow_indexing", True
|
||||
with (
|
||||
V.set_ops_handler(handler),
|
||||
patch.object(FlexibleLayout, "allow_indexing", True),
|
||||
):
|
||||
fn(*args)
|
||||
return handler.symbols
@ -19,8 +19,7 @@ T = TypeVar("T")

class DTypeVar(Protocol):
@property
def dtype(self) -> torch.dtype:
...
def dtype(self) -> torch.dtype: ...


DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]
@ -526,6 +526,7 @@ class ConfigFuzzer:
|
|||
```python
|
||||
import torch._inductor.config as cfg
|
||||
|
||||
|
||||
def create_simple_test_model_gpu() -> FactoryOutputType:
|
||||
batch_size = 32
|
||||
seq_length = 50
@ -539,6 +540,8 @@ class ConfigFuzzer:
|
|||
return True
|
||||
|
||||
return test_fn
|
||||
|
||||
|
||||
fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2)
|
||||
|
||||
# Test every pair of configs:
@ -550,7 +553,9 @@ class ConfigFuzzer:
|
|||
ret = fuzzer.bisect(num_attempts=10)
|
||||
|
||||
# reproduce a failing config
|
||||
fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}])
|
||||
fuzzer.reproduce(
|
||||
[{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]
|
||||
)
|
||||
```
|
||||
|
||||
The list of known failures on inductor config are:
@ -531,7 +531,11 @@ def tuned_b2b_gemm(
|
|||
A.realize()
|
||||
B.realize()
|
||||
C.realize()
|
||||
layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]]) # type: ignore[index]
|
||||
layout = FixedLayout(
|
||||
A.get_device_or_error(),
|
||||
A.get_dtype(),
|
||||
[A.shape[0], C.shape[1]], # type: ignore[index]
|
||||
)
|
||||
subgraph_buffer = build_subgraph_buffer(
|
||||
[create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())],
|
||||
subgraph,
@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None:
|
|||
node_indices = {node: i for i, node in enumerate(graph.nodes)}
|
||||
for allreduce in comm_blocks:
|
||||
# Find the earliest/first user -- target_node.
|
||||
assert (
|
||||
len(allreduce.outputs) >= 1
|
||||
), f"Found a allreduce that has zero outputs/users -- {allreduce}."
|
||||
assert len(allreduce.outputs) >= 1, (
|
||||
f"Found a allreduce that has zero outputs/users -- {allreduce}."
|
||||
)
|
||||
# Initialize the target node to avoid typing issues.
|
||||
target_node = next(iter(next(iter(allreduce.outputs)).users))
|
||||
target_node_index = 2**31
@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
|
|||
# argument. `graph.get_attr` and
|
||||
# `graph.call_function` does not allow the `name` argument.
|
||||
conv_get_node = graph.create_node(
|
||||
op="get_attr", target=conv_node.target, name="get_conv" # type: ignore[union-attr]
|
||||
op="get_attr",
|
||||
target=conv_node.target, # type: ignore[union-attr]
|
||||
name="get_conv",
|
||||
)
|
||||
bn_get_node = graph.create_node(
|
||||
op="get_attr", target=bn_node.target, name="get_bn"
@ -866,15 +866,18 @@ def _get_sfdp_patterns():
|
|||
name += "_bs1"
|
||||
|
||||
training_name = name + "_training"
|
||||
yield training_name, {
|
||||
"search_fn": pattern,
|
||||
"replace_fn": replacement,
|
||||
"example_inputs": args,
|
||||
"trace_fn": joint_fwd_bwd,
|
||||
"pass_dicts": patterns,
|
||||
"extra_check": extra_check,
|
||||
"scalar_workaround": workaround,
|
||||
}
|
||||
yield (
|
||||
training_name,
|
||||
{
|
||||
"search_fn": pattern,
|
||||
"replace_fn": replacement,
|
||||
"example_inputs": args,
|
||||
"trace_fn": joint_fwd_bwd,
|
||||
"pass_dicts": patterns,
|
||||
"extra_check": extra_check,
|
||||
"scalar_workaround": workaround,
|
||||
},
|
||||
)
|
||||
|
||||
if workaround:
|
||||
assert len(workaround) == 1 and "dropout_p" in workaround
@ -886,18 +889,21 @@ def _get_sfdp_patterns():
|
|||
workaround = {}
|
||||
|
||||
inference_name = name + "_inference"
|
||||
yield inference_name, {
|
||||
"search_fn": pattern,
|
||||
"replace_fn": replacement,
|
||||
"example_inputs": args,
|
||||
"trace_fn": fwd_only,
|
||||
"pass_dicts": patterns,
|
||||
"extra_check": extra_check,
|
||||
"scalar_workaround": workaround,
|
||||
# with dropout turned into clone, we end up with a number of
|
||||
# semantically identical graphs
|
||||
"skip_duplicates": True,
|
||||
}
|
||||
yield (
|
||||
inference_name,
|
||||
{
|
||||
"search_fn": pattern,
|
||||
"replace_fn": replacement,
|
||||
"example_inputs": args,
|
||||
"trace_fn": fwd_only,
|
||||
"pass_dicts": patterns,
|
||||
"extra_check": extra_check,
|
||||
"scalar_workaround": workaround,
|
||||
# with dropout turned into clone, we end up with a number of
|
||||
# semantically identical graphs
|
||||
"skip_duplicates": True,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@functools.lru_cache(None)
@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion):
|
|||
args=(batch_biases[i],),
|
||||
kwargs={"size": broadcast_shape},
|
||||
)
|
||||
broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape) # type: ignore[assignment]
|
||||
broadcast_bias.meta["val"] = aten.broadcast_to(
|
||||
batch_biases_meta[i]["val"], broadcast_shape
|
||||
) # type: ignore[assignment]
|
||||
new_bias_add = graph.call_function( # type: ignore[operator]
|
||||
aten.add.Tensor, args=((broadcast_bias, new_mm))
|
||||
)
@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion):
|
|||
group_biases = None # type: ignore[assignment]
|
||||
if all(weight is None for weight in group_weights):
|
||||
group_weights = None # type: ignore[assignment]
|
||||
assert all(
|
||||
eps == group_epss[0] for eps in group_epss
|
||||
), "all epsilon values must be equal"
|
||||
assert all(eps == group_epss[0] for eps in group_epss), (
|
||||
"all epsilon values must be equal"
|
||||
)
|
||||
|
||||
with graph.inserting_before(subset[0]): # type: ignore[operator]
|
||||
stack_input = graph.call_function( # type: ignore[operator]
@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
|
|||
# for relu op, we also use the inplace to construct the key
|
||||
# we batch the ops with same parent to enable followup split cat
|
||||
parent = node.args[0]
|
||||
parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else "" # type: ignore[union-attr]
|
||||
parent = (
|
||||
parent.target # type: ignore[union-attr]
|
||||
if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
|
||||
else ""
|
||||
)
|
||||
group_key = (
|
||||
"batch_aten_" + self.op.__name__.lower().split(".")[0],
|
||||
str(input.meta["val"].shape),
@ -1293,9 +1299,9 @@ def get_fusion_candidates(
|
|||
"""
|
||||
q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()
|
||||
|
||||
candidate_dict: collections.defaultdict[
|
||||
Any, list[torch.fx.Node]
|
||||
] = collections.defaultdict(list)
|
||||
candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
|
||||
collections.defaultdict(list)
|
||||
)
|
||||
|
||||
if root_node.target in SEARCH_EXCLUSIONS:
|
||||
return candidate_dict
@ -763,9 +763,7 @@ def _get_node_to_ancestors(
|
|||
"""
|
||||
Compute the ancestors for all nodes in a graph.
|
||||
"""
|
||||
node_to_ancestors = defaultdict(
|
||||
OrderedSet[torch.fx.Node]
|
||||
) # type: ignore[var-annotated]
|
||||
node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node]) # type: ignore[var-annotated]
|
||||
for node in graph.nodes:
|
||||
node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
|
||||
for dep in node.all_input_nodes:
@ -558,9 +558,9 @@ if torch._C._has_mkldnn:
|
|||
binary_nodes = filter_nodes(match.nodes, binary_op)
|
||||
|
||||
def _get_compute_node(_binary_node, _other_index):
|
||||
assert (
|
||||
len(_binary_node.all_input_nodes) == 2
|
||||
), "Binary node should have 2 input nodes."
|
||||
assert len(_binary_node.all_input_nodes) == 2, (
|
||||
"Binary node should have 2 input nodes."
|
||||
)
|
||||
_compute_index = 1 if (_other_index == 0) else 0
|
||||
return _binary_node.args[_compute_index]
@ -614,9 +614,9 @@ if torch._C._has_mkldnn:
|
|||
else:
|
||||
computation_args += [1.0, None, [], None]
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
] += len(match.nodes)
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
|
||||
len(match.nodes)
|
||||
)
|
||||
return L[fusion_op](*computation_args)
|
||||
|
||||
return fn
@ -659,9 +659,9 @@ if torch._C._has_mkldnn:
|
|||
else:
|
||||
computation_args += [1.0, None, [], None]
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||
counters["inductor"][
|
||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
||||
] += len(match.nodes)
|
||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
|
||||
len(match.nodes)
|
||||
)
|
||||
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
|
||||
other.realize()
|
||||
if not _can_be_inplace(other) or other.data.shape != list(
@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn:
|
|||
)
|
||||
batch_size = input.meta.get("val").shape[0]
|
||||
if has_free_symbols(batch_size):
|
||||
assert (
|
||||
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
|
||||
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
||||
assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
|
||||
f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
||||
)
|
||||
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
|
||||
packed_weight_inputs = (
|
||||
transpose_weight_node,
@ -437,7 +437,7 @@ def _should_pad_bench(
|
|||
return False
|
||||
|
||||
def realize_symbols(
|
||||
ds: Union[torch.Size, tuple[torch.SymInt, ...]]
|
||||
ds: Union[torch.Size, tuple[torch.SymInt, ...]],
|
||||
) -> list[int]:
|
||||
return [d if isinstance(d, int) else d.node.hint for d in ds]
@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
|||
pattern_matcher_pass.apply
|
||||
)
|
||||
if not is_same_dict(counters["inductor"], inductor_before_change):
|
||||
optimus_scuba_log[
|
||||
f"{pattern_matcher_pass.pass_name}_post_grad"
|
||||
] = upload_graph(gm.graph)
|
||||
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
|
||||
upload_graph(gm.graph)
|
||||
)
|
||||
if config.b2b_gemm_pass:
|
||||
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]
@ -277,9 +277,9 @@ def pre_grad_passes(
|
|||
for _ in range(counter):
|
||||
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||
if not is_same_dict(counters["inductor"], inductor_before_change):
|
||||
optimus_scuba_log[
|
||||
f"{pattern_matcher_pass.pass_name}_pre_grad"
|
||||
] = upload_graph(gm.graph)
|
||||
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
|
||||
upload_graph(gm.graph)
|
||||
)
|
||||
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
|
||||
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]
@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering(
|
|||
accum.realize()
|
||||
from .mkldnn_fusion import _can_be_inplace
|
||||
|
||||
assert _can_be_inplace(
|
||||
accum
|
||||
), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
|
||||
assert _can_be_inplace(accum), (
|
||||
"QConv Binary Inplace Fusion requires accum is not an alias or mutation."
|
||||
)
|
||||
|
||||
computation_args = (
|
||||
x,
@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
|||
def clone_to_new_node(graph, source_node, user_node):
|
||||
# Clone the source_node to a new node
|
||||
# Replace user_node's input from source_node to new_node
|
||||
assert (
|
||||
source_node.op == "call_function"
|
||||
), "clone_to_new_node only support node.op call_function"
|
||||
assert source_node.op == "call_function", (
|
||||
"clone_to_new_node only support node.op call_function"
|
||||
)
|
||||
with graph.inserting_before(user_node):
|
||||
new_node = graph.call_function(
|
||||
source_node.target,
@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
|||
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
|
||||
return _node
|
||||
else:
|
||||
assert (
|
||||
len(_node.args) >= 1
|
||||
), "In in dequant pattern, each node should have more than 1 arg."
|
||||
assert len(_node.args) >= 1, (
|
||||
"In in dequant pattern, each node should have more than 1 arg."
|
||||
)
|
||||
return _find_first_node_in_dequant_pattern(_node.args[0])
|
||||
|
||||
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
@ -616,7 +616,8 @@ def merge_splits(
|
|||
dim=first_split_dim,
|
||||
)
|
||||
first_split_num_to_user = {
|
||||
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
|
||||
user.args[1]: user
|
||||
for user in first_split.users.keys() # type: ignore[union-attr]
|
||||
}
|
||||
|
||||
new_split_num = 0
@ -706,7 +707,11 @@ class SplitCatSimplifier:
|
|||
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
|
||||
)
|
||||
self.replace_cat(
|
||||
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
|
||||
graph,
|
||||
split_node,
|
||||
next_users,
|
||||
user_inputs_list_new,
|
||||
transform_params_list, # type: ignore[arg-type]
|
||||
)
|
||||
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
|
||||
counters["inductor"]["unbind_stack_pass"] += 1
@ -913,7 +918,9 @@ class SplitCatSimplifier:
|
|||
)
|
||||
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
|
||||
new_split.meta["example_value"] = torch.split(
|
||||
split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr]
|
||||
split_input.meta["example_value"], # type: ignore[union-attr]
|
||||
[r[1] - r[0] for r in split_ranges],
|
||||
dim=split_dim,
|
||||
)
|
||||
counters["inductor"]["scmerge_split_added"] += 1
|
||||
split_items = []
@ -1005,7 +1012,10 @@ class SplitCatSimplifier:
|
|||
stacked_input = graph.call_function(
|
||||
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
||||
)
|
||||
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
|
||||
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
|
||||
to_stack_meta,
|
||||
dim=stack_dim, # type: ignore[arg-type]
|
||||
)
|
||||
to_stack, to_stack_meta = [], []
|
||||
stack_dim = None
|
||||
user_inputs_new_transformed.append(stacked_input)
@ -1023,19 +1033,28 @@ class SplitCatSimplifier:
|
|||
user_input_new = graph.call_function(
|
||||
torch.unflatten, args=(user_input_new, *unflatten_params)
|
||||
)
|
||||
user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
||||
user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type]
|
||||
user_input_new_meta, # type: ignore[arg-type]
|
||||
*unflatten_params, # type: ignore[arg-type]
|
||||
)
|
||||
if movedim_params:
|
||||
user_input_new_meta = user_input_new.meta["example_value"]
|
||||
user_input_new = graph.call_function(
|
||||
torch.movedim, args=(user_input_new, *movedim_params)
|
||||
)
|
||||
user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
||||
user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type]
|
||||
user_input_new_meta, # type: ignore[arg-type]
|
||||
*movedim_params, # type: ignore[arg-type]
|
||||
)
|
||||
if flatten_params:
|
||||
user_input_new_meta = user_input_new.meta["example_value"]
|
||||
user_input_new = graph.call_function(
|
||||
torch.flatten, args=(user_input_new, *flatten_params)
|
||||
)
|
||||
user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
||||
user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type]
|
||||
user_input_new_meta,
|
||||
*flatten_params, # type: ignore[arg-type]
|
||||
)
|
||||
user_inputs_new_transformed.append(user_input_new)
|
||||
user_inputs_new_transformed_meta.append(
|
||||
user_input_new.meta["example_value"]
@ -1044,7 +1063,10 @@ class SplitCatSimplifier:
|
|||
stacked_input = graph.call_function(
|
||||
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
||||
)
|
||||
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
|
||||
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
|
||||
to_stack_meta,
|
||||
dim=stack_dim, # type: ignore[arg-type]
|
||||
)
|
||||
user_inputs_new_transformed.append(stacked_input)
|
||||
user_inputs_new_transformed_meta.append(
|
||||
stacked_input.meta["example_value"]
@ -1058,14 +1080,15 @@ class SplitCatSimplifier:
|
|||
kwargs={"dim": cat_dim},
|
||||
)
|
||||
new_cat_node.meta["example_value"] = torch.cat(
|
||||
user_inputs_new_transformed_meta, dim=cat_dim
|
||||
user_inputs_new_transformed_meta,
|
||||
dim=cat_dim,
|
||||
)
|
||||
counters["inductor"]["scmerge_cat_added"] += 1
|
||||
else:
|
||||
new_cat_node = user_inputs_new_transformed[-1]
|
||||
new_cat_node.meta[
|
||||
"example_value"
|
||||
] = user_inputs_new_transformed_meta[-1]
|
||||
new_cat_node.meta["example_value"] = (
|
||||
user_inputs_new_transformed_meta[-1]
|
||||
)
|
||||
|
||||
if (
|
||||
user_node.target == torch.cat
@ -1077,7 +1100,11 @@ class SplitCatSimplifier:
|
|||
new_cat_node = graph.call_function(
|
||||
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
|
||||
)
|
||||
new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr]
|
||||
new_cat_node.meta["example_value"] = torch.flatten(
|
||||
new_cat_node_meta,
|
||||
cat_dim,
|
||||
cat_dim + 1,
|
||||
)
|
||||
user_node.replace_all_uses_with(new_cat_node)
|
||||
new_cats.append(new_cat_node)
@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier):
|
|||
]
|
||||
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
|
||||
getitem_indices
|
||||
) != len(
|
||||
unbind_node.meta["example_value"]
|
||||
):
|
||||
) != len(unbind_node.meta["example_value"]):
|
||||
return
|
||||
num_unbind = len(getitem_indices)
|
||||
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
|
|||
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
|
||||
# update the split sections
|
||||
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
|
||||
split_node, indices # type: ignore[arg-type]
|
||||
split_node,
|
||||
indices, # type: ignore[arg-type]
|
||||
)
|
||||
# padding others with zeros to keep the same dict size
|
||||
for i in indices[1:]:
@ -1613,10 +1639,12 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
|
|||
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
|
||||
# check the split dim, and construct the slice tuple
|
||||
start_fused_size = calculate_fused_tensor_size(
|
||||
split_node, list(range(indices[0])) # type: ignore[arg-type]
|
||||
split_node,
|
||||
list(range(indices[0])), # type: ignore[arg-type]
|
||||
)
|
||||
end_fused_size = start_fused_size + calculate_fused_tensor_size(
|
||||
split_node, indices # type: ignore[arg-type]
|
||||
split_node,
|
||||
indices, # type: ignore[arg-type]
|
||||
)
|
||||
slice_list = []
|
||||
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
|
|||
continue
|
||||
# check the cat node has consecutive indices
|
||||
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
|
||||
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
|
||||
if (
|
||||
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
|
||||
and len(getitem_nodes) != len(cat_inputs)
|
||||
):
|
||||
continue
|
||||
# replace the users of the cat node to be the input of the split node
|
||||
cat_node.replace_all_uses_with(split_input)
@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
|
|||
continue
|
||||
# check the cat node has consecutive indices
|
||||
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
|
||||
if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type]
|
||||
if (
|
||||
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
|
||||
or len(select_nodes) != len(cat_inputs)
|
||||
):
|
||||
continue
|
||||
# check all the select nodes can be merged to the cat node input
|
||||
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
|
|||
args=(new_cat_args,),
kwargs={"dim": cat_dim},
)
new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type]
new_cat_node.meta["example_value"] = torch.cat(
new_cat_args_meta, dim=cat_dim
) # type: ignore[arg-type]
cat_node.replace_all_uses_with(new_cat_node)
new_cat_node.meta.update(cat_node.meta)
# remove inputs of cat_node if they have no users
|
||||
|
|
@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack(
|
|||
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
)
reshape_node.meta["example_value"] = torch.Tensor.view(
permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type]
permute_node.meta["example_value"],
tuple(stack_node_shape), # type: ignore[arg-type]
)
return reshape_node
|
||||
|
||||
|
|
@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
|
|||
cat_inputs.append(decomposed_stack_node)
# cat_arg must be the split input
view_shape_list = get_view_shape_list(cat_arg, stack_dim)
stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr]
stack_node_shape = torch.reshape(
cat_arg.meta["example_value"], tuple(view_shape_list)
).shape # type: ignore[union-attr]
cat_inputs.append(
convert_reshape_cat_arg_to_stack(
graph,
|
||||
|
|
|
|||
|
|
@ -105,9 +105,9 @@ class FakeTensorUpdater:
|
|||
if new is None:
return old is None
if not isinstance(new, torch.Tensor):
assert isinstance(
new, (torch.SymInt, torch.SymBool, torch.SymFloat)
), f"Unknown type {type(new)} in {self.graph}"
assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
f"Unknown type {type(new)} in {self.graph}"
)
return (
new.node.shape_env._maybe_evaluate_static(
sympy.Eq(new.node.expr, old.node.expr)
|
||||
|
|
|
|||
|
|
@ -136,7 +136,9 @@ else:
|
|||
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
assert isinstance(
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
), (
"get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
)
if isinstance(constant_buffer, sympy.core.numbers.Integer):
return torch.int64
|
||||
|
||||
|
|
@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.reuse_shape_env = True
self._shape_env = shape_env
# We're going to mutate ras_by_symbol as we finish generating them
self.ras_by_symbol: dict[
Optional[sympy.Symbol], list[RuntimeAssert]
] = shape_env.deferred_runtime_asserts.copy()
self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
shape_env.deferred_runtime_asserts.copy()
)
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
self.sizevars = SizeVarAllocator(shape_env)
self.graph_input_names: list[str] = []
|
||||
|
|
@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
self.cache_linemap: list[
tuple[int, str]
] = (
[]
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run
# Used if lowering encounters cases where cudagraphs are not supported
self.disable_cudagraphs_reason: Optional[str] = None
|
||||
|
||||
|
|
@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
)

def placeholder(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Expr, TensorBox, None]:
self.placeholder_idx += 1
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||
|
|
@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
return target(*args, **kwargs)

if target not in lowerings:
assert isinstance(
target, torch._ops.OpOverload
), f"{target} is not an OpOverload"
assert isinstance(target, torch._ops.OpOverload), (
f"{target} is not an OpOverload"
)
base_name = target.name().split(".")[0]
if base_name in FALLBACK_ALLOW_LIST:
make_fallback(target, warn=False, override_decomp=True)
|
||||
|
|
@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
return len(t.shape) == 1 and t.shape[0] <= 8

def get_attr(
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[()], # type: ignore[override]
kwargs: dict[str, object],
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
# this is a constant
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
||||
|
|
@ -1241,7 +1247,10 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
raise AssertionError

def output(
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
self,
target: str, # type: ignore[override]
args: tuple[object], # type: ignore[override]
kwargs: dict[str, object],
) -> None:
result = super().output(target, args, kwargs) # type: ignore[arg-type]
if not isinstance(result, (tuple, list)):
|
||||
|
|
@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
if is_call_function:
args, kwargs = self.fetch_args_kwargs_from_env(n)
origins |= gather_origins(args, kwargs)
with ir.IRNode.current_origins(origins), self.set_current_node(
n
), V.set_current_node(n):
with (
ir.IRNode.current_origins(origins),
self.set_current_node(n),
V.set_current_node(n),
):
if (
n.op == "call_function"
and n.target is not operator.getitem
|
||||
|
|
@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
):
debug("fallback_handler")
result = fallback_handler(n.target, add_to_fallback_set=False)(
*args, **kwargs # type: ignore[possibly-undefined]
*args, # type: ignore[possibly-undefined]
**kwargs, # type: ignore[possibly-undefined]
)
elif (
n.op == "call_function"
|
||||
|
|
@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
wrapper_code_gen_cls = get_wrapper_codegen_for_device(
self.device_type, self.cpp_wrapper
)
assert (
wrapper_code_gen_cls is not None
), f"Device {self.device_type} not supported"
assert wrapper_code_gen_cls is not None, (
f"Device {self.device_type} not supported"
)
self.wrapper_code = wrapper_code_gen_cls.create(
is_subgraph,
subgraph_name,
|
||||
|
|
@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
compiled = self.compile_to_module().call

def materialize(
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
) -> Union[int, float, torch.Tensor]:
if x is None:
return None
|
||||
|
|
@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter):
|
|||
elif isinstance(x, FakeTensor):
return defake(x)
else:
assert isinstance(
x, torch.Tensor
), "Unknown type when creating real inputs" + str(type(x))
assert isinstance(x, torch.Tensor), (
"Unknown type when creating real inputs" + str(type(x))
)
return x

tracing_context = torch._guards.TracingContext.try_get()
|
||||
|
|
|
|||
|
|
@ -5,21 +5,22 @@ propagation of sympy expressions downstream of ops.index_expr calls.
|
|||
|
||||
For example, say we have the IR:

tmp0 = ops.index_expr(x, torch.int32)
tmp1 = ops.constant(2, torch.int32)
tmp2 = ops.mul(tmp0, tmp1)
tmp3 = ops.indirect_indexing(tmp2, x_size)
tmp4 = ops.load("buf0", tmp3)
tmp0 = ops.index_expr(x, torch.int32)
tmp1 = ops.constant(2, torch.int32)
tmp2 = ops.mul(tmp0, tmp1)
tmp3 = ops.indirect_indexing(tmp2, x_size)
tmp4 = ops.load("buf0", tmp3)

The underlying handler would just see:

ops.load("buf0", x * 2)
ops.load("buf0", x * 2)

This is limited by the set of operators handled in the sympy expression
printers. So simple operations like minimum and maximum cannot be translated to
SymPy expressions yet, despite sympy.Min and sympy.Max existing.

"""

import itertools
from collections.abc import Sequence
from dataclasses import dataclass
|
||||
|
|
@ -179,9 +180,9 @@ class IndexPropVar:
|
|||
return IndexPropVar(expr, is_symbolic=True)

def __post_init__(self):
assert not self.is_symbolic or isinstance(
self.value, TypedExpr
), "Symbolic IndexPropVar must contain a TypedExpr"
assert not self.is_symbolic or isinstance(self.value, TypedExpr), (
"Symbolic IndexPropVar must contain a TypedExpr"
)


IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]
|
||||
|
|
@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler):
|
|||
name: Literal["indirect_indexing"],
args: Sequence[Any],
kwargs: dict[str, Any],
) -> IndexPropVar:
...
) -> IndexPropVar: ...

@overload
def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
) -> IndexPropResult:
...
) -> IndexPropResult: ...

def fallback(
self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
|
||||
|
|
@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler):
|
|||
is_valid_expr = new_expr is not NotImplemented and (
# Inductor doesn't expect floating point in sympy expressions, but
# allow floating point constants to be propagated
new_expr.is_constant()
or new_expr.expr.is_integer
new_expr.is_constant() or new_expr.expr.is_integer
)
if not is_valid_expr:
return self.fallback(name, args, kwargs)
|
||||
|
|
|
|||
|
|
@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
|
|||
int,
EffectfulKernel,
),
), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
), (
f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
)

# Be picky about the accepted data structure (don't use pytree here)
_check_tensorbox(node_or_nodes)
|
||||
|
|
@ -298,13 +300,11 @@ def get_stride_order(
|
|||
|
||||
|
||||
@overload
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None:
...
def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...


@overload
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor:
...
def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...


def ir_node_to_tensor(
|
||||
|
|
@ -346,7 +346,7 @@ def may_convert_to_optional(
|
|||
|
||||
|
||||
def get_device_type(
|
||||
x: Union[IRNode, OutputSpec, torch.device, None, str]
|
||||
x: Union[IRNode, OutputSpec, torch.device, None, str],
|
||||
) -> Optional[str]:
|
||||
if isinstance(x, str) or x is None:
|
||||
return x
|
||||
|
|
@ -698,8 +698,7 @@ class IRNode:
|
|||
if TYPE_CHECKING:
|
||||
|
||||
@property
|
||||
def dtype(self) -> torch.dtype:
|
||||
...
|
||||
def dtype(self) -> torch.dtype: ...
|
||||
|
||||
|
||||
@ir_dataclass(frozen=False)
|
||||
|
|
@ -839,8 +838,9 @@ class Loops(IRNode):
|
|||
@cache_on_self
|
||||
def inner_fn_opcount(self) -> OpCountResult:
|
||||
opcounter = OpCounterCSE(V.MockHandler())
|
||||
with V.set_ops_handler(opcounter), patch.object(
|
||||
FlexibleLayout, "allow_indexing", True
|
||||
with (
|
||||
V.set_ops_handler(opcounter),
|
||||
patch.object(FlexibleLayout, "allow_indexing", True),
|
||||
):
|
||||
self.inner_fn(*self.inner_fn_args())
|
||||
return opcounter.getvalue()
|
||||
|
|
@ -1364,9 +1364,9 @@ class Reduction(Loops):
|
|||
# "all" is desugared to `!any(!val)`
|
||||
}
|
||||
|
||||
assert (
|
||||
reduction_type in rtypes_to_inits.keys()
|
||||
), f"{reduction_type} not supported for zero-dimension tensors!"
|
||||
assert reduction_type in rtypes_to_inits.keys(), (
|
||||
f"{reduction_type} not supported for zero-dimension tensors!"
|
||||
)
|
||||
|
||||
def const_fn(index: int) -> OpsValue:
|
||||
return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
|
||||
|
|
@ -1575,9 +1575,9 @@ class Reduction(Loops):
|
|||
new_ranges: Sequence[Integer],
|
||||
new_reduction_ranges: Sequence[Integer],
|
||||
) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
|
||||
assert all(
|
||||
r == 1 for r in original_ranges
|
||||
), f"Only enabled for numel_hint == 1, found {original_ranges=}"
|
||||
assert all(r == 1 for r in original_ranges), (
|
||||
f"Only enabled for numel_hint == 1, found {original_ranges=}"
|
||||
)
|
||||
reindex = View.dynamic_reshape_indexer(
|
||||
original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
|
||||
)
|
||||
|
|
@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction):
|
|||
if reduction_numel == 1:
|
||||
|
||||
def copy(
|
||||
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
|
||||
loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
|
||||
) -> TensorBox:
|
||||
def inner_fn(idx: Sequence[Expr]) -> OpsValue:
|
||||
reduction_index = [sympy.S.Zero for _ in reduction_ranges]
|
||||
|
|
@ -2571,9 +2571,9 @@ class ExpandView(BaseView):
|
|||
# NB: new_size[i] == old_size[i] is expected to already be
|
||||
# guarded because the meta formula was expected to have taught
|
||||
# us this equality.
|
||||
assert (
|
||||
sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
|
||||
), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
|
||||
assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
|
||||
"Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
|
||||
)
|
||||
return new_size
|
||||
|
||||
@classmethod
|
||||
|
|
@ -3382,9 +3382,9 @@ class Layout(OutputSpec):
|
|||
)
|
||||
|
||||
def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
|
||||
assert (
|
||||
FlexibleLayout.allow_indexing
|
||||
), f"convert {type(self).__name__} to FixedLayout first"
|
||||
assert FlexibleLayout.allow_indexing, (
|
||||
f"convert {type(self).__name__} to FixedLayout first"
|
||||
)
|
||||
return self.as_fixed().make_indexer()
|
||||
|
||||
def __eq__(self, other) -> bool: # type: ignore[no-untyped-def]
|
||||
|
|
@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout):
|
|||
return target
|
||||
|
||||
result = unwrap_views(self.target)
|
||||
assert isinstance(
|
||||
result, Buffer
|
||||
), "MutationLayoutSHOULDREMOVE must refer to a buffer"
|
||||
assert isinstance(result, Buffer), (
|
||||
"MutationLayoutSHOULDREMOVE must refer to a buffer"
|
||||
)
|
||||
return result
|
||||
|
||||
def real_layout(self): # type: ignore[no-untyped-def]
|
||||
|
|
@ -3803,7 +3803,9 @@ class Buffer(IRNode):
|
|||
assert isinstance(self.layout, FlexibleLayout)
|
||||
self.layout = self.layout.as_same_order(stride)
|
||||
|
||||
def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None: # type: ignore[no-untyped-def]
|
||||
def freeze_layout_with_exact_strides( # type: ignore[no-untyped-def]
|
||||
self, exact_strides, allow_padding=False
|
||||
) -> None:
|
||||
assert isinstance(self.layout, FlexibleLayout)
|
||||
self.layout = self.layout.as_exact_strides(
|
||||
exact_strides, allow_padding=allow_padding
|
||||
|
|
@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer):
|
|||
torch.ops.higher_order.flex_attention_backward,
|
||||
)
|
||||
current_node = V.graph.current_node.target
|
||||
assert (
|
||||
current_node in allowed_set
|
||||
), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
|
||||
assert current_node in allowed_set, (
|
||||
f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
|
||||
)
|
||||
device = self.inputs[0].get_device()
|
||||
self.outputs += [
|
||||
MutationOutput(NoneLayout(device=device), buf, self)
|
||||
|
|
@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel):
|
|||
x_unwrap_view.freeze_layout()
|
||||
|
||||
index_args, var_ranges = dependencies.index_vars_squeeze(
|
||||
x.get_size(), prefix="r" # type: ignore[arg-type]
|
||||
x.get_size(),
|
||||
prefix="r", # type: ignore[arg-type]
|
||||
)
|
||||
range_vars = index_args[0]
|
||||
index = x.make_indexer()(range_vars)
|
||||
|
|
@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel):
|
|||
# pass in a list of const arg names for arg_properties lookup.
|
||||
name_to_arg_properties = None
|
||||
if names and self.arg_properties:
|
||||
assert len(self.constant_args) == len(
|
||||
names
|
||||
), "names passed to codegen_const_args does not match self.constant_args"
|
||||
assert len(self.constant_args) == len(names), (
|
||||
"names passed to codegen_const_args does not match self.constant_args"
|
||||
)
|
||||
name_to_arg_properties = {
|
||||
arg.get("name"): arg for arg in self.arg_properties
|
||||
}
|
||||
|
|
@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel):
|
|||
args = []
|
||||
for i, x in enumerate(inputs):
|
||||
if V.graph.cpp_wrapper:
|
||||
assert self.arg_properties and i < len(
|
||||
self.arg_properties
|
||||
), "Invalid access to ExternKernel.arg_properties"
|
||||
assert self.arg_properties and i < len(self.arg_properties), (
|
||||
"Invalid access to ExternKernel.arg_properties"
|
||||
)
|
||||
type_ = self.arg_properties[i].get("type")
|
||||
args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
|
||||
else:
|
||||
|
|
@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel):
|
|||
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
|
||||
return OrderedSet()
|
||||
|
||||
def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None: # type: ignore[no-untyped-def]
|
||||
def __init__( # type: ignore[no-untyped-def]
|
||||
self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args
|
||||
) -> None:
|
||||
inputs = []
|
||||
kwargs = {}
|
||||
constant_args = []
|
||||
|
|
@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc):
|
|||
elif isinstance(output, torch.SymInt):
|
||||
return output.node.expr
|
||||
else:
|
||||
assert (
|
||||
output is None
|
||||
), f"FallbackKernel output type {type(output)} is not supported"
|
||||
assert output is None, (
|
||||
f"FallbackKernel output type {type(output)} is not supported"
|
||||
)
|
||||
return None
|
||||
|
||||
outputs = generate_output(example_output, [])
|
||||
|
|
@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel):
|
|||
)
|
||||
self.codegen_size_asserts(wrapper)
|
||||
|
||||
def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None: # type: ignore[no-untyped-def]
|
||||
def __init__( # type: ignore[no-untyped-def]
|
||||
self,
|
||||
layout: OutputSpec,
|
||||
input,
|
||||
indices: list[tuple[Any, ...]],
|
||||
) -> None:
|
||||
super().__init__(None, layout, [input], ())
|
||||
self.name = V.graph.register_buffer(self)
|
||||
V.graph.register_operation(self)
|
||||
|
|
@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel):
|
|||
assert p.get_dtype() == torch.bool, p
|
||||
assert len(p.get_size()) == 0, p
|
||||
|
||||
assert (
|
||||
len(all_inputs) > 0
|
||||
), "torch.while_loop is assumed to have at least one operand."
|
||||
assert len(all_inputs) > 0, (
|
||||
"torch.while_loop is assumed to have at least one operand."
|
||||
)
|
||||
|
||||
device = all_inputs[0].get_device()
|
||||
|
||||
|
|
@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel):
|
|||
# This is identical to FallbackKernel.set_cpp_kernel(), minus the
|
||||
# part that checks against input aliasing and mutation.
|
||||
def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
|
||||
assert (
|
||||
type(self.op_overload) is torch._ops.OpOverload
|
||||
), "Setting cpp kernel needs a valid op_overload"
|
||||
assert type(self.op_overload) is torch._ops.OpOverload, (
|
||||
"Setting cpp kernel needs a valid op_overload"
|
||||
)
|
||||
kernel = self.op_overload
|
||||
self.cpp_kernel_name = kernel._schema.name
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,5 @@
|
|||
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel"""
"""Triton Implementation of the flex_attention Kernel"""

import copy
import logging
|
||||
|
|
@ -60,9 +60,9 @@ def construct_strides(
|
|||
) -> Sequence[int]:
|
||||
"""From a list of sizes and a fill order, construct the strides of the permuted tensor."""
|
||||
# Initialize strides
|
||||
assert len(sizes) == len(
|
||||
fill_order
|
||||
), "Length of sizes must match the length of the fill order"
|
||||
assert len(sizes) == len(fill_order), (
|
||||
"Length of sizes must match the length of the fill order"
|
||||
)
|
||||
strides = [0] * len(sizes)
|
||||
|
||||
# Start with stride 1 for the innermost dimension
|
||||
|
|
@ -1151,10 +1151,14 @@ def lower_cpu(
|
|||
SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
|
||||
), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
|
||||
), (
|
||||
"Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
|
||||
)
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
|
||||
), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
|
||||
), (
|
||||
"KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
|
||||
)
|
||||
CppFlexAttentionTemplate.add_choices(
|
||||
choices=_choices,
|
||||
input_nodes=input_nodes,
|
||||
|
|
@ -1364,15 +1368,15 @@ def flex_attention(
|
|||
|
||||
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
|
||||
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
|
||||
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Gt(seq_len_q, 0)
|
||||
), "Query length must be greater than 0"
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Gt(seq_len_kv, 0)
|
||||
), "Key length must be greater than 0"
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
|
||||
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
)
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
|
||||
"Query length must be greater than 0"
|
||||
)
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
|
||||
"Key length must be greater than 0"
|
||||
)
|
||||
|
||||
B = Bq
|
||||
|
||||
|
|
@ -2291,9 +2295,9 @@ def process_joint_outputs(
|
|||
JointOutputResult containing processed buffers and gradients
|
||||
"""
|
||||
assert isinstance(all_joint_outputs, list)
|
||||
assert (
|
||||
all_joint_outputs[0] is not None
|
||||
), "joint_subgraph_buffer is None - this is a bug!"
|
||||
assert all_joint_outputs[0] is not None, (
|
||||
"joint_subgraph_buffer is None - this is a bug!"
|
||||
)
|
||||
|
||||
joint_buffer = all_joint_outputs[0]
|
||||
other_grads = all_joint_outputs[num_placeholders - 1 :]
|
||||
|
|
@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs):
|
|||
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
|
||||
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
|
||||
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
|
||||
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
|
||||
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
)
|
||||
|
||||
kernel_options = dict(kernel_options)
|
||||
# Mark symbols in custom kernel options as static shapes and add guards.
|
||||
|
|
@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs):
|
|||
grad_key = broadcasted_grad_key
|
||||
grad_value = broadcasted_grad_value
|
||||
else:
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)
|
||||
), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}" # noqa: B950
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
|
||||
f"Bq and Bkv must broadcastable. "
|
||||
f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
|
||||
f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
|
||||
)
|
||||
grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
|
||||
grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)
|
||||
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
# mypy: allow-untyped-defs
""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""

from typing import Any

import sympy
|
||||
|
|
@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
|
||||
Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
|
||||
|
||||
assert V.graph.sizevars.evaluate_expr(
|
||||
sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
|
||||
), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
|
||||
f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
|
||||
)
|
||||
|
||||
B = Bq
|
||||
kernel_options = dict(kernel_options)
|
||||
|
|
@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
|
|||
max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
seq_len_q,
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
)
|
||||
* gqa_shared_heads
|
||||
),
|
||||
|
|
|
|||
|
|
@ -65,7 +65,8 @@ def filtered_configs(
|
|||
m = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
m, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
m,
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
|
|
@ -73,7 +74,8 @@ def filtered_configs(
|
|||
n = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
n, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
n,
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size,
|
||||
|
|
@ -81,7 +83,8 @@ def filtered_configs(
|
|||
k = max(
|
||||
next_power_of_2(
|
||||
V.graph.sizevars.size_hint(
|
||||
k, fallback=torch._inductor.config.unbacked_symint_fallback # type: ignore[arg-type]
|
||||
k,
|
||||
fallback=torch._inductor.config.unbacked_symint_fallback, # type: ignore[arg-type]
|
||||
)
|
||||
),
|
||||
min_block_size_k,
|
||||
|
|
@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
|
|||
"""
|
||||
even_k_symbolic = (
|
||||
# it isn't worth guarding on this
|
||||
sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
|
||||
== config.kwargs["BLOCK_K"]
|
||||
sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
|
||||
)
|
||||
allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
|
||||
not inductor_config.force_same_precision
|
||||
|
|
|
|||
|
|
@ -194,11 +194,12 @@ class LoopBody:
|
|||
# There is indeed an issue due to symbol name conflicting.
|
||||
# y0 maybe reused for the y dimension later.
|
||||
(
|
||||
iter_vars,
|
||||
reduce_vars,
|
||||
), var_ranges = dependencies.index_vars_no_squeeze(
|
||||
iter_sizes, reduce_sizes, prefix="t"
|
||||
)
|
||||
(
|
||||
iter_vars,
|
||||
reduce_vars,
|
||||
),
|
||||
var_ranges,
|
||||
) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t")
|
||||
new_body = LoopBody(
|
||||
old_body,
|
||||
[iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
|
||||
|
|
@ -234,7 +235,8 @@ class LoopBody:
|
|||
new_sizes = (new_iter_size, reduce_size)
|
||||
|
||||
(iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
|
||||
*new_sizes, prefix="t" # type: ignore[arg-type]
|
||||
*new_sizes,
|
||||
prefix="t", # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
inverse_order = {b: a for a, b in enumerate(new_order)}
|
||||
|
|
@ -254,7 +256,8 @@ class LoopBody:
|
|||
|
||||
# use the original symbol prefix so we can do multiple round of reordering
|
||||
(iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
|
||||
*new_sizes, prefix="p" # type: ignore[arg-type]
|
||||
*new_sizes,
|
||||
prefix="p", # type: ignore[arg-type]
|
||||
)
|
||||
new_body = LoopBody(
|
||||
loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2
|
||||
|
|
@ -385,9 +388,9 @@ class LoopBody:
|
|||
def indexing_from_args(self, indices):
|
||||
index = [*itertools.chain.from_iterable(indices)]
|
||||
assert len(index) == len(self.var_ranges), (index, self.var_ranges)
|
||||
assert all(
|
||||
v not in self.var_ranges for v in index
|
||||
), f"{self.var_ranges=}, {indices=}"
|
||||
assert all(v not in self.var_ranges for v in index), (
|
||||
f"{self.var_ranges=}, {indices=}"
|
||||
)
|
||||
replacements = dict(zip(self.var_ranges.keys(), index))
|
||||
return {
|
||||
name: sympy_subs(expr, replacements)
|
||||
|
|
|
|||
|
|
@ -346,7 +346,8 @@ def transform_args(
|
|||
# only consider tensor kwargs for promotion, for now
|
||||
promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
|
||||
dtype = get_promoted_dtype(
|
||||
*promoting_args, type_promotion_kind=type_promotion_kind # type: ignore[arg-type]
|
||||
*promoting_args,
|
||||
type_promotion_kind=type_promotion_kind, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
device = (
|
||||
|
|
@ -448,9 +449,9 @@ def _register_lowering(
|
|||
(fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
|
||||
):
|
||||
# explicitly assert for "out=" ops for better error messages
|
||||
assert not any(
|
||||
x == "out" for x in kwargs.keys()
|
||||
), "out= ops aren't yet supported"
|
||||
assert not any(x == "out" for x in kwargs.keys()), (
|
||||
"out= ops aren't yet supported"
|
||||
)
|
||||
|
||||
args, kwargs = transform_args(
|
||||
args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
|
||||
|
|
@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b):
|
|||
|
||||
|
||||
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
|
||||
assert (
|
||||
override_return_dtype is None or type_promotion_kind is None
|
||||
), "only one of override_return_dtype or type_promotion_kind may be given"
|
||||
assert override_return_dtype is None or type_promotion_kind is None, (
|
||||
"only one of override_return_dtype or type_promotion_kind may be given"
|
||||
)
|
||||
|
||||
if override_return_dtype is None and type_promotion_kind is None:
|
||||
type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
|
||||
|
|
@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
|
|||
if isinstance(input, (list, tuple)):
|
||||
a_list_input = input
|
||||
break
|
||||
assert (
|
||||
a_list_input is not None
|
||||
), "at least one input must be a list to a foreach op"
|
||||
assert a_list_input is not None, (
|
||||
"at least one input must be a list to a foreach op"
|
||||
)
|
||||
|
||||
# broadcast scalar inputs to match length of list inputs
|
||||
broadcast_inputs = []
|
||||
|
|
@ -1321,12 +1322,12 @@ def quantized_decomposed_quantize_per_channel(
|
|||
|
||||
if input.get_dtype() == torch.bfloat16:
|
||||
input = to_dtype(input, torch.float32)
|
||||
assert (
|
||||
input.get_dtype() == torch.float32
|
||||
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
assert axis < len(
|
||||
input.get_size()
|
||||
), f"Expecting axis to be < {len(input.get_size())}"
|
||||
assert input.get_dtype() == torch.float32, (
|
||||
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
assert axis < len(input.get_size()), (
|
||||
f"Expecting axis to be < {len(input.get_size())}"
|
||||
)
|
||||
|
||||
input_loader = input.make_loader()
|
||||
scales_loader = scales.make_loader()
|
||||
|
|
@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel(
|
|||
) -> TensorBox:
|
||||
assert len(scales.get_size()) == 1, "expect scales 1 dim"
|
||||
assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
|
||||
assert (
|
||||
input.get_dtype() == dtype
|
||||
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
assert axis < len(
|
||||
input.get_size()
|
||||
), f"Expecting axis to be < {len(input.get_size())}"
|
||||
assert input.get_dtype() == dtype, (
|
||||
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
assert axis < len(input.get_size()), (
|
||||
f"Expecting axis to be < {len(input.get_size())}"
|
||||
)
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
|
|
@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default(
|
|||
) -> TensorBox:
|
||||
if input.get_dtype() == torch.bfloat16:
|
||||
input = to_dtype(input, torch.float32)
|
||||
assert (
|
||||
input.get_dtype() == torch.float32
|
||||
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
assert input.get_dtype() == torch.float32, (
|
||||
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
|
||||
input_loader = input.make_loader()
|
||||
|
||||
|
|
@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default(
|
|||
*,
|
||||
out_dtype: Optional[torch.dtype] = None,
|
||||
) -> TensorBox:
|
||||
assert (
|
||||
input.get_dtype() == dtype
|
||||
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
assert input.get_dtype() == dtype, (
|
||||
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
|
|
@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor(
|
|||
) -> TensorBox:
|
||||
if input.get_dtype() == torch.bfloat16:
|
||||
input = to_dtype(input, torch.float32)
|
||||
assert (
|
||||
input.get_dtype() == torch.float32
|
||||
), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
assert input.get_dtype() == torch.float32, (
|
||||
f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
assert len(scale.get_size()) == 0 or (
|
||||
len(scale.get_size()) == 1 and scale.get_size()[0] == 1
|
||||
), "expect scale as scalar tensor"
|
||||
|
|
@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor(
|
|||
assert len(zero_point.get_size()) == 0 or (
|
||||
len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
|
||||
), "expect zero_point as scalar tensor"
|
||||
assert (
|
||||
input.get_dtype() == dtype
|
||||
), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
assert input.get_dtype() == dtype, (
|
||||
f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
|
||||
)
|
||||
|
||||
if out_dtype is None:
|
||||
out_dtype = torch.float32
|
||||
|
|
@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=
|
|||
|
||||
|
||||
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
|
||||
assert (
|
||||
op not in decompositions or override_decomp
|
||||
), f"both a fallback and a decomp for same op: {op}"
|
||||
assert op not in decompositions or override_decomp, (
|
||||
f"both a fallback and a decomp for same op: {op}"
|
||||
)
|
||||
if (
|
||||
warn
|
||||
and bool(os.getenv("CI"))
|
||||
|
|
@ -2086,9 +2087,9 @@ def native_dropout(x, p, train):
|
|||
|
||||
@register_lowering(aten.bernoulli_, type_promotion_kind=None)
|
||||
def bernoulli_(x, *args):
|
||||
assert config.fallback_random or x.get_device() == torch.device(
|
||||
"cpu"
|
||||
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
|
||||
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
|
||||
"this should be handled in decomps unless config.fallback_random or the device is CPU"
|
||||
)
|
||||
x.realize()
|
||||
op_overload = (
|
||||
aten.bernoulli_.float
|
||||
|
|
@ -2101,9 +2102,9 @@ def bernoulli_(x, *args):
|
|||
|
||||
@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
|
||||
def bernoulli_p(x, *args):
|
||||
assert config.fallback_random or x.get_device() == torch.device(
|
||||
"cpu"
|
||||
), "this should be handled in decomps unless config.fallback_random or the device is CPU"
|
||||
assert config.fallback_random or x.get_device() == torch.device("cpu"), (
|
||||
"this should be handled in decomps unless config.fallback_random or the device is CPU"
|
||||
)
|
||||
return bernoulli_(clone(x), *args)
|
||||
|
||||
|
||||
|
|
@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device):
|
|||
i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
|
||||
for i in indices
|
||||
if i is not None
|
||||
), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
|
||||
), (
|
||||
f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
|
||||
)
|
||||
if any(
|
||||
i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
|
||||
):
|
||||
|
|
@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
|
|||
)
|
||||
result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
|
||||
if isinstance(
|
||||
result.data.data, Reduction # type: ignore[attr-defined]
|
||||
result.data.data, # type: ignore[attr-defined]
|
||||
Reduction,
|
||||
): # Only realize if reduction isn't unrolled
|
||||
result.realize()
|
||||
return result
|
||||
|
|
@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
|
|||
return None
|
||||
|
||||
handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
|
||||
with V.set_ops_handler(handler), patch.object(
|
||||
ir.FlexibleLayout, "allow_indexing", True
|
||||
with (
|
||||
V.set_ops_handler(handler),
|
||||
patch.object(ir.FlexibleLayout, "allow_indexing", True),
|
||||
):
|
||||
out = x.inner_fn(*x.inner_fn_args())
|
||||
|
||||
|
|
@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload):
|
|||
A context manager to force fallback an op. Used in unit test
|
||||
for FallbackKernel.
|
||||
"""
|
||||
assert isinstance(
|
||||
op, torch._ops.OpOverload
|
||||
), "Only OpOverload to make the clean up easier"
|
||||
assert isinstance(op, torch._ops.OpOverload), (
|
||||
"Only OpOverload to make the clean up easier"
|
||||
)
|
||||
old_handler = lowerings.get(op)
|
||||
try:
|
||||
register_lowering(op)(fallback_handler(op))
|
||||
|
|
|
|||
|
|
@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer:
|
|||
class MemoryPlanningInfoForNode:
|
||||
index: int = 0
|
||||
size: int = 0
|
||||
pred_buffers: OrderedSet[
|
||||
Union[SchedulerBuffer, FreeableInputBuffer]
|
||||
] = dataclasses.field(default_factory=OrderedSet)
|
||||
pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
|
||||
dataclasses.field(default_factory=OrderedSet)
|
||||
)
|
||||
pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
|
||||
default_factory=OrderedSet
|
||||
)
|
||||
|
|
@ -87,9 +87,9 @@ def get_freeable_input_buf(
|
|||
|
||||
# get freeable input buffers' successor nodes and their sizes
|
||||
# note that different deps can have the same name, so we use name as keys
|
||||
dep_name_to_succ_nodes: dict[
|
||||
str, OrderedSet[BaseSchedulerNode]
|
||||
] = collections.defaultdict(OrderedSet)
|
||||
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
|
||||
collections.defaultdict(OrderedSet)
|
||||
)
|
||||
dep_name_to_size: dict[str, int] = dict()
|
||||
for node in nodes:
|
||||
for dep in node.read_writes.reads:
|
||||
|
|
@ -112,7 +112,7 @@ def get_freeable_input_buf(
|
|||
|
||||
|
||||
def compute_size_for_scheduler_buffer(
|
||||
name_to_buf: dict[str, SchedulerBuffer]
|
||||
name_to_buf: dict[str, SchedulerBuffer],
|
||||
) -> dict[str, tuple[int, int]]:
|
||||
"""
|
||||
Compute the size of each scheduler buffer, including (1) memory allocated when
|
||||
|
|
@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers(
|
|||
|
||||
# get buffer's successor nodes
|
||||
# note that different deps can have the same name, so we use name as keys
|
||||
dep_name_to_succ_nodes: dict[
|
||||
str, OrderedSet[BaseSchedulerNode]
|
||||
] = collections.defaultdict(OrderedSet)
|
||||
dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
|
||||
collections.defaultdict(OrderedSet)
|
||||
)
|
||||
for node in nodes:
|
||||
for dep in node.unmet_dependencies:
|
||||
dep_name_to_succ_nodes[dep.name].add(node)
|
||||
|
|
|
|||
|
|
@ -138,12 +138,12 @@ class MetricTable:
|
|||
return
|
||||
|
||||
row_dict = row_fn()
|
||||
assert len(self.column_names) == len(
|
||||
row_dict
|
||||
), f"{len(self.column_names)} v.s. {len(row_dict)}"
|
||||
assert OrderedSet(self.column_names) == OrderedSet(
|
||||
row_dict.keys()
|
||||
), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
|
||||
assert len(self.column_names) == len(row_dict), (
|
||||
f"{len(self.column_names)} v.s. {len(row_dict)}"
|
||||
)
|
||||
assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
|
||||
f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
|
||||
)
|
||||
|
||||
bn = get_benchmark_name()
|
||||
# assert bn is not None
|
||||
|
|
@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
|
|||
name = name.strip()
|
||||
if not name:
|
||||
continue
|
||||
assert (
|
||||
name in REGISTERED_METRIC_TABLES
|
||||
), f"Metric table name {name} is not registered"
|
||||
assert name in REGISTERED_METRIC_TABLES, (
|
||||
f"Metric table name {name} is not registered"
|
||||
)
|
||||
enabled.add(name)
|
||||
return enabled
|
||||
|
||||
|
|
|
|||
|
|
@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
|
|||
unary_algorithm,
|
||||
]
|
||||
|
||||
assert (
|
||||
binary_attr == "sum"
|
||||
), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
|
||||
assert binary_attr == "sum", (
|
||||
"For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
|
||||
)
|
||||
|
||||
V.graph.mark_buffer_mutated(qaccum.get_name())
|
||||
packed = QConvPointWiseBinaryPT2E(
|
||||
|
|
|
|||
|
|
@ -575,9 +575,9 @@ def register_onednn_fusion_ops():
|
|||
algorithm,
|
||||
layout=None,
|
||||
):
|
||||
assert (
|
||||
packed_weight.get_dtype() is torch.int8
|
||||
), "Only int8 weights are supported by oneDNN qlinear."
|
||||
assert packed_weight.get_dtype() is torch.int8, (
|
||||
"Only int8 weights are supported by oneDNN qlinear."
|
||||
)
|
||||
x_size = x.get_size()
|
||||
if len(x_size) > 2:
|
||||
# GEMM template needs 2D input, normalize input shape here
|
||||
|
|
@ -928,9 +928,9 @@ def register_onednn_fusion_ops():
|
|||
# we will do accum dtype convertion here.
|
||||
x2 = to_dtype(x2, output_dtype)
|
||||
else:
|
||||
assert (
|
||||
x2.get_dtype() == output_dtype
|
||||
), "dtype of accum for qlinear post op sum should be the same as output"
|
||||
assert x2.get_dtype() == output_dtype, (
|
||||
"dtype of accum for qlinear post op sum should be the same as output"
|
||||
)
|
||||
x2_dtype = x2.get_dtype()
|
||||
bias_dtype = bias.get_dtype() if bias is not None else None
|
||||
choices: list[ChoiceCaller] = []
|
||||
|
|
|
|||
|
|
@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]):
|
|||
assert self_arg == "self"
|
||||
code.write(
|
||||
f"""
|
||||
def {target}(self, {', '.join(args)}):
|
||||
return self._default({target!r}, ({', '.join(args)}, ), {{}})
|
||||
def {target}(self, {", ".join(args)}):
|
||||
return self._default({target!r}, ({", ".join(args)}, ), {{}})
|
||||
""".strip()
|
||||
)
|
||||
code.write("\n\n")
|
||||
|
|
@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler):
|
|||
)
|
||||
formatter._output.writeline(f"{lhs} = {name}")
|
||||
|
||||
with V.set_ops_handler(formatter), patch.object(
|
||||
FlexibleLayout, "allow_indexing", True
|
||||
with (
|
||||
V.set_ops_handler(formatter),
|
||||
patch.object(FlexibleLayout, "allow_indexing", True),
|
||||
):
|
||||
result = ir_fn(*args)
|
||||
return formatter.getvalue(result)
|
||||
|
|
|
|||
|
|
@ -188,7 +188,9 @@ def package_aoti(
|
|||
) or (
|
||||
isinstance(archive_file, (str, os.PathLike))
|
||||
and os.fspath(archive_file).endswith(".pt2")
|
||||
), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
|
||||
), (
|
||||
f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
|
||||
)
|
||||
|
||||
# Save using the PT2 packaging format
|
||||
# (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)
|
||||
|
|
@ -285,9 +287,9 @@ class AOTICompiledModel:
|
|||
def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel: # type: ignore[type-arg]
|
||||
assert (
|
||||
isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
|
||||
) or (
|
||||
isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
|
||||
), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
|
||||
) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
|
||||
f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
|
||||
)
|
||||
|
||||
if isinstance(path, (io.IOBase, IO)):
|
||||
with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
|
||||
|
|
|
|||
|
|
@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node]
|
|||
class SearchFn(Protocol):
__name__: str

def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


class ReplaceFn(Protocol):
def __call__(self, *args: Any, **kwargs: Any) -> Any:
...
def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


class TraceFn(Protocol):
def __call__(
self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
) -> torch.fx.GraphModule:
...
) -> torch.fx.GraphModule: ...


T = TypeVar("T")
|
||||
|
|
@ -365,8 +362,7 @@ class PatternExpr(ABC):
|
|||
"""
|
||||
|
||||
@abstractmethod
|
||||
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
|
||||
...
|
||||
def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ...
|
||||
|
||||
def match(self, node: torch.fx.Node) -> MatchResult:
|
||||
try:
|
||||
|
|
@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr):
|
|||
|
||||
@property
|
||||
@abstractmethod
|
||||
def op(self) -> str:
|
||||
...
|
||||
def op(self) -> str: ...
|
||||
|
||||
def fns_repr(self) -> str:
|
||||
first_repr = self.fns[0]
|
||||
|
|
@ -997,8 +992,9 @@ class PatternPrettyPrinter:
|
|||
|
||||
|
||||
class _PassDictsType(Protocol):
|
||||
def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
|
||||
...
|
||||
def __getitem__(
|
||||
self, k: tuple[str, torch.fx.node.Target]
|
||||
) -> list[PatternEntry]: ...
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
|
|
@ -1925,7 +1921,10 @@ def fx_to_pattern(
|
|||
get_attr = _not_implemented
|
||||
|
||||
def placeholder(
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> Union[ExclusiveKeywordArg, KeywordArg]:
|
||||
n = next(argnum)
|
||||
if n < len(argnames):
|
||||
|
|
@ -1942,7 +1941,10 @@ def fx_to_pattern(
|
|||
return KeywordArg(name)
|
||||
|
||||
def call_function(
|
||||
self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any] # type: ignore[override]
|
||||
self,
|
||||
target: str, # type: ignore[override]
|
||||
args: Sequence[Any],
|
||||
kwargs: Mapping[str, Any],
|
||||
) -> PatternExpr:
|
||||
process_arg_fn = process_arg
|
||||
# Indexing is critical for matching getitem nodes, so we can't ignore int args here
|
||||
|
|
|
|||
|
|
@ -24,7 +24,7 @@ T = TypeVar("T")
|
|||
|
||||
|
||||
def time_and_count(
|
||||
fn: Callable[Concatenate[Any, P], T]
|
||||
fn: Callable[Concatenate[Any, P], T],
|
||||
) -> Callable[Concatenate[Any, P], T]:
|
||||
"""Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
|
||||
counters. It is expected that `fn` is a method of `Benchmarker` or one of its
|
||||
|
|
|
|||
|
|
@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None:
|
|||
# right now, if a pre-hook is attached to the config, it will not be saved;
|
||||
# and then it won't be used when the config is loaded from cache.
|
||||
# So we assert - if we do get a pre_hook, it might get ignored after caching.
|
||||
assert (
|
||||
getattr(cfg, "pre_hook", None) is None
|
||||
), "triton configs with pre_hooks not supported"
|
||||
assert getattr(cfg, "pre_hook", None) is None, (
|
||||
"triton configs with pre_hooks not supported"
|
||||
)
|
||||
|
||||
|
||||
def create_bandwidth_info_str(
|
||||
|
|
|
|||
|
|
@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface):
|
|||
self.launchers = []
|
||||
|
||||
def __getstate__(self) -> dict[str, Any]:
|
||||
assert (
|
||||
not self.launchers
|
||||
), "pickle should not be called with after make_launchers()"
|
||||
assert not self.launchers, (
|
||||
"pickle should not be called with after make_launchers()"
|
||||
)
|
||||
return {
|
||||
**self.__dict__,
|
||||
"lock": None,
|
||||
|
|
@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface):
|
|||
assert isinstance(
|
||||
arg,
|
||||
torch.Tensor,
|
||||
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
||||
), (
|
||||
"self.reset_to_zero_arg_names should only contain valid argument names"
|
||||
)
|
||||
arg.zero_()
|
||||
|
||||
for name, arg in kwargs.items():
|
||||
|
|
@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface):
|
|||
assert isinstance(
|
||||
arg,
|
||||
torch.Tensor,
|
||||
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
||||
), (
|
||||
"self.reset_to_zero_arg_names should only contain valid argument names"
|
||||
)
|
||||
arg.zero_()
|
||||
|
||||
def maybe_clone_args(
|
||||
|
|
@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface):
|
|||
assert not (
|
||||
self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
|
||||
and "R0_BLOCK" in launcher.config.kwargs
|
||||
), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
||||
), (
|
||||
"Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
|
||||
)
|
||||
start_time = time.time_ns()
|
||||
best_config = self.coordesc_tuner.autotune(
|
||||
benchmark_one_config, launcher.config, None
|
||||
|
|
@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface):
|
|||
)
|
||||
return config2launcher.get(best_config)
|
||||
|
||||
def run(
|
||||
self, *args, grid, stream, benchmark_run=False, **kwargs
|
||||
): # type:ignore[override]
|
||||
def run(self, *args, grid, stream, benchmark_run=False, **kwargs): # type:ignore[override]
|
||||
if self.triton_interpret:
|
||||
return self.fn[grid](
|
||||
*args,
|
||||
|
|
@ -1192,12 +1196,12 @@ class TritonCompileResult:
|
|||
|
||||
exec(
|
||||
f"""
|
||||
def launcher({', '.join(def_args)}, grid, stream):
|
||||
def launcher({", ".join(def_args)}, grid, stream):
|
||||
if callable(grid):
|
||||
grid_0, grid_1, grid_2 = grid(grid_meta)
|
||||
else:
|
||||
grid_0, grid_1, grid_2 = grid
|
||||
runner({', '.join(runner_args)})
|
||||
runner({", ".join(runner_args)})
|
||||
return bin
|
||||
""".lstrip(),
|
||||
scope,
|
||||
|
|
@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]):
|
|||
if block_suffix in var:
|
||||
prefix = var.removesuffix(block_suffix)
|
||||
max_block = TRITON_MAX_BLOCK[prefix]
|
||||
assert (
|
||||
val <= max_block
|
||||
), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
||||
assert val <= max_block, (
|
||||
f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
|
||||
)
|
||||
|
||||
|
||||
def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
|
||||
|
|
@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, in
|
|||
prefix = f"r{idx}_"
|
||||
max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
|
||||
dim = min(max_size, remaining)
|
||||
assert (
|
||||
remaining % dim == 0
|
||||
), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
||||
assert remaining % dim == 0, (
|
||||
f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
|
||||
)
|
||||
rnumels[prefix] = dim
|
||||
remaining //= dim
|
||||
|
||||
# Sanity check the results.
|
||||
final_numel = conditional_product(*rnumels.values())
|
||||
assert (
|
||||
r == final_numel
|
||||
), f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
||||
assert all(
|
||||
rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
|
||||
), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
||||
assert r == final_numel, (
|
||||
f"Expected ND reduction size ({rnumels}) to have {r} elements."
|
||||
)
|
||||
assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
|
||||
f"rnumels exceed size_hints. {rnumels} > {size_hints}"
|
||||
)
|
||||
|
||||
return rnumels
|
||||
|
||||
|
|
@ -1967,9 +1971,9 @@ def cooperative_reduction(
|
|||
size_hints["x"] = 1
|
||||
|
||||
# Cooperative reductions currently only support a single reduction dimension.
|
||||
assert (
|
||||
len(size_hints) == 2
|
||||
), "Cooperative reductions don't support tiling reduction dims"
|
||||
assert len(size_hints) == 2, (
|
||||
"Cooperative reductions don't support tiling reduction dims"
|
||||
)
|
||||
xnumel, rnumel = size_hints["x"], size_hints["r0_"]
|
||||
|
||||
# TODO(jansel): we should base target on the SM count of the local GPU
|
||||
|
|
@ -2274,9 +2278,9 @@ def grid_combo_kernels(
|
|||
assert min_blocks_d is not None
|
||||
min_blocks = min_blocks_d
|
||||
else:
|
||||
assert (
|
||||
min_blocks_d is None or min_blocks == min_blocks_d
|
||||
), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
|
||||
assert min_blocks_d is None or min_blocks == min_blocks_d, (
|
||||
f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
|
||||
)
|
||||
else:
|
||||
# sequential dispatch
|
||||
seq_numels = list(numels)
|
||||
|
|
|
|||
|
|
@ -200,9 +200,9 @@ class BaseSchedulerNode:
def __init__(self, scheduler: Scheduler) -> None:
self.scheduler: Scheduler = scheduler
self.debug_device_str: Callable[
[BaseSchedulerNode], list[str]
] = lambda *args, **kwargs: []
self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
lambda *args, **kwargs: []
)

def _init_from_node(self, node: ir.Operation) -> None:
self.node: Optional[ir.Operation] = node

@ -232,7 +232,7 @@ class BaseSchedulerNode:
buf = IndentedBuffer()
buf.splice(
f"""\
{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})
{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
{name}.writes = {pformat(self.read_writes.writes)}
{name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
{name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}

@ -525,9 +525,9 @@ class BaseSchedulerNode:
V.kernel.mutations.add(input_buf.get_name())
V.kernel.mutations.add(buf.get_name())
V.kernel.inplace_update_buffers[
buf.get_name()
] = input_buf.get_name()
V.kernel.inplace_update_buffers[buf.get_name()] = (
input_buf.get_name()
)
break

def codegen_originating_info(

@ -693,7 +693,7 @@ class BaseSchedulerNode:
continue

def get_buf_bytes(
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]]
buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
) -> int:
if not buf:
return 0

@ -794,12 +794,11 @@ class BaseSchedulerNode:
# runtime for that today
return 0
with FakeTensorMode() as fake_mode, FlopCounterMode(
display=False
) as flop_counter_mode, V.set_current_node(
self.node.fx_node
), V.set_fake_mode(
fake_mode
with (
FakeTensorMode() as fake_mode,
FlopCounterMode(display=False) as flop_counter_mode,
V.set_current_node(self.node.fx_node),
V.set_fake_mode(fake_mode),
):
from .ir import ir_node_to_tensor

@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode):
return self._sizes

def is_reduction(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return bool(self.node.get_reduction_type())

def is_split_scan(self) -> bool:
assert isinstance(
self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
), f"{type(self.node)=}"
assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
f"{type(self.node)=}"
)
return isinstance(self.node, ir.ComputedBuffer) and isinstance(
self.node.data, ir.SplitScan
)

@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode):
def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
var_ranges = self.ranges_from_index_vars(index_vars)
try:
with V.set_ops_handler(
SimplifyIndexing(V.get_ops_handler(), var_ranges)
), V.kernel.set_current_node(self):
with (
V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
V.kernel.set_current_node(self),
):
self._body(*index_vars)
except Exception:
log.fatal("Error in codegen for %s", self.node)

@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode):

def refresh_group_node_dependencies(
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode]
group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
) -> None:
snodes = group_snode.snodes
group_snode.set_read_writes(

@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):

@staticmethod
def set_group_algorithm_for_combo_kernels(
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
) -> None:
ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
custom_group_algorithm

@ -1975,9 +1975,9 @@ class Scheduler:
for node in self.nodes:
node.prune_deps()
self.name_to_donated_buffer: dict[
str, SchedulerDonatedBuffer
] = self.get_donated_buffers()
self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
self.get_donated_buffers()
)
self.name_to_node: dict[str, BaseSchedulerNode] = {
n.get_name(): n for n in self.nodes
}

@ -2099,9 +2099,9 @@ class Scheduler:
node.log_details()

def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
assert (
node.get_origins() is not None
), "All nodes passed to scheduling must have an origin"
assert node.get_origins() is not None, (
"All nodes passed to scheduling must have an origin"
)
if node.is_no_op():
return NopKernelSchedulerNode(self, node)
elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):

@ -2260,9 +2260,9 @@ class Scheduler:
)
# if a kernel takes unbacked symints, register dependencies
for s in unbacked_symbol_uses:
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node}"
)
if (r := unbacked_symbol_to_origin_node[s]) is not None:
for buf in self.name_to_node[r].get_outputs():
node.add_fake_dep(StarDep(buf.get_name()))

@ -2310,9 +2310,9 @@ class Scheduler:
for alt_name in buf.get_mutations():
self.mutation_renames[rename(alt_name)] = buf.get_name()
self.mutation_renames[alt_name] = buf.get_name()
self.mutation_real_name[
buf.get_name()
] = self.mutation_real_name.get(alt_name, alt_name)
self.mutation_real_name[buf.get_name()] = (
self.mutation_real_name.get(alt_name, alt_name)
)

# make sure outputs aren't dead-code-eliminated
for buf_name in V.graph.get_output_names():

@ -2322,9 +2322,9 @@ class Scheduler:
# make sure unbacked symints aren't dead-code-eliminated
for out in V.graph.graph_outputs:
for s in out.get_unbacked_symbol_uses():
assert (
s in unbacked_symbol_to_origin_node
), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
assert s in unbacked_symbol_to_origin_node, (
f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
)
if r := unbacked_symbol_to_origin_node[s]:
for buf_name in self.name_to_node[r].get_buffer_names():
log.debug(

@ -3304,15 +3304,15 @@ class Scheduler:
rhs_dep = node2_name2dep[buf_name]

if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
reasons[
buf_name
] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
reasons[buf_name] = (
f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
)
continue

if lhs_dep.get_numel() != rhs_dep.get_numel():
reasons[
buf_name
] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
reasons[buf_name] = (
f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
)
continue

# same numel but different MemoryDep.size. Should be broadcasting

@ -3340,9 +3340,9 @@ class Scheduler:
layout_str = ""
if not isinstance(buf, ir.TorchBindObject):
layout_str = f"Layout: {buf.layout}"
reasons[
buf_name
] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
reasons[buf_name] = (
f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
)

return str(reasons)

@ -3903,9 +3903,9 @@ class Scheduler:
self.free_buffers()

def create_backend(self, device: torch.device) -> BaseScheduling:
assert (
not is_gpu(device.type) or device.index is not None
), f"{device} should have been normalized in lowering"
assert not is_gpu(device.type) or device.index is not None, (
f"{device} should have been normalized in lowering"
)
V.graph.add_device_info(device)

device_scheduling = get_scheduling_for_device(device.type)

@ -4135,9 +4135,9 @@ class Scheduler:
partitions, signatures = self.graph_partition()

for partition, signature in zip(partitions, signatures):
assert (
len(partition) >= 1
), f"Each partition must have at least one node but found {len(partition)}"
assert len(partition) >= 1, (
f"Each partition must have at least one node but found {len(partition)}"
)

if signature.skip_cudagraph:
self._codegen(partition)
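
The scheduler hunks above add a second pattern: a with statement over several context managers is rewrapped using the parenthesized form (Python 3.10+ syntax), one manager per line, instead of breaking inside a call's argument list. A small sketch using made-up file names, not code from this patch:

def process(*files):
    # Read every open file handle passed in.
    return [f.read() for f in files]

# Old wrapping: the line break falls inside one context manager's argument list.
with open("a.txt") as f, open(
    "b.txt"
) as g:
    process(f, g)

# New wrapping: the whole context list is parenthesized, one manager per line.
with (
    open("a.txt") as f,
    open("b.txt") as g,
):
    process(f, g)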
@ -168,9 +168,9 @@ class PartialRender:
)
else:
return
assert (
self.replacement_hooks[hook_key] is not None
), "hook_key can only be called once"
assert self.replacement_hooks[hook_key] is not None, (
"hook_key can only be called once"
)
self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
self.replacement_hooks[hook_key] = None

@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler): # type: ignore[name-defined]
This is used by flex_attention's backwards grad for captured buffers, see
zeros_and_scatter lowering
"""
assert (
self.mask is not None
), "Mask is required for inner stores in modifications"
assert self.mask is not None, (
"Mask is required for inner stores in modifications"
)
assert mode == "atomic_add", "Only atomic_add is supported for inner stores"

buf_name = self._add_kernel_input(name)

@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel):
def _get_subgraph(self, subgraph_number: int):
assert isinstance(subgraph_number, int)
assert isinstance(self.subgraphs, list)
assert subgraph_number < len(
self.subgraphs
), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
assert (
self.body.getvalue() == ""
), "Body should be clear before adding a modification"
assert subgraph_number < len(self.subgraphs), (
f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
)
assert self.body.getvalue() == "", (
"Body should be clear before adding a modification"
)
return self.subgraphs[subgraph_number]

def _handle_scatter_graph(self, scatter_graph):

@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel):
Args:
scatter_graph: The scatter graph to process
"""
assert isinstance(
scatter_graph, ir.ComputedBuffer
), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
assert isinstance(scatter_graph, ir.ComputedBuffer), (
f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
)

def contiguous_strides(x):
# We always create a fresh contiguous grad for scattering into

@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel):
x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
)
return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, []) # type: ignore[attr-defined]
return scatter_graph.data.store_output( # type: ignore[attr-defined]
scatter_graph.name, contiguous_strides, []
)

def modification(
self,

@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel):
self, subgraph_number, fixed_inputs, mask
)
with V.set_ops_handler(modification_handler):
assert isinstance(
subgraph, (ir.ComputedBuffer, list)
), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
assert isinstance(subgraph, (ir.ComputedBuffer, list)), (
f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
)
# Handle scatter stores
if isinstance(subgraph, list):
for scatter_graph in subgraph:

@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate):
"subgraphs": subgraphs,
}
with patch.object(
V.graph, "get_dtype", self._fake_get_dtype(fake_out)
), V.graph.set_current_device(layout.device), TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel:
with (
patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
V.graph.set_current_device(layout.device),
TritonTemplateKernel(
kernel_name=kernel_name,
output_node=fake_out,
workspace_arg=workspace_arg,
use_jit=False,
**kernel_options,
) as kernel,
):
try:
template = kernel.render(self.template, kwargs)
with kernel.set_subgraph_body("<STORE_OUTPUT>"):

@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller):

def output_node(self):
if self.choice.use_fallback_kernel:
assert (
self.choice.op_overload is not None
), "Please provide an op_overload to use ir.FallbackKernel"
assert self.choice.op_overload is not None, (
"Please provide an op_overload to use ir.FallbackKernel"
)
inner = ir.FallbackKernel.create(
self.choice.op_overload, *self.input_nodes, **self.kwargs
)

@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache):
input_gen_fns = {}

def get_inputs(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
) -> AutotuneArgs:
# de-duplicate args
unique_example_inputs = {

@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache):
return timings

def benchmark_in_sub_process(
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
):
from . import autotune_process

@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache):
map(
str,
V.graph.sizevars.size_hints(
n.get_size(), fallback=config.unbacked_symint_fallback # type: ignore[arg-type]
n.get_size(),
fallback=config.unbacked_symint_fallback, # type: ignore[arg-type]
),
)
)

@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs):
_ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

if "return_multi_template" not in kwargs:
kwargs[
"return_multi_template"
] = torch._inductor.config.benchmark_epilogue_fusion
kwargs["return_multi_template"] = (
torch._inductor.config.benchmark_epilogue_fusion
)

return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)

def add_feedback_saver(
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
):
global _ALGORITHM_SELECTOR_CACHE
if _ALGORITHM_SELECTOR_CACHE is None:
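
Two further patterns run through the select_algorithm hunks above: a lone parameter that is wrapped onto its own line gains a trailing comma, and when a subscripted assignment grows too long the break moves into a parenthesized value rather than into the target. A sketch with placeholder names, assuming the real statements are too long to fit on one line (not code from this patch):

# A single parameter wrapped onto its own line gains a trailing comma.
def benchmark_choices(
    choices: list[str],
) -> None:
    for choice in choices:
        print(choice)

def default_fast_path() -> bool:
    # Placeholder for whatever computes the default setting.
    return True

# The break lands inside a parenthesized value, not inside the subscripted target.
options: dict[str, bool] = {}
options["use_fast_path"] = (
    default_fast_path()
)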
@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler): # type: ignore[name-defined]
def __init__(self, inner, var_ranges: VarRanges) -> None:
super().__init__(inner)
self.name = "SimplifyIndexing"
self._simplify: Callable[
[Expr], Expr
] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
self._simplify: Callable[[Expr], Expr] = (
lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
)

def load(self, name: str, index: sympy.Expr):
return self._inner.load(name, self._simplify(index))
@ -283,9 +283,9 @@ def ceildiv(
# TODO: There is a bug in a call to this function, to repro:
# python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
# --amp --only YituTechConvBert --dynamic-shapes
assert isinstance(numer, int) and isinstance(
denom, int
), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
assert isinstance(numer, int) and isinstance(denom, int), (
f"{numer}: {type(numer)}, {denom}: {type(denom)}"
)
return runtime_ceildiv(numer, denom)

@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:

def convert_shape_to_inductor(
lst: Iterable[Union[int, torch.SymInt]]
lst: Iterable[Union[int, torch.SymInt]],
) -> list[sympy.Expr]:
"""
Gets the shape and stride of a tensor. For non-symbolic tensors, this is

@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True)

class CachedMethod(Protocol, Generic[P, RV]):
@staticmethod
def clear_cache(cache: Any) -> None:
...
def clear_cache(cache: Any) -> None: ...

def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
...
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...

# See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature

@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str:

@functools.lru_cache(None)
def try_import_ck_lib() -> (
tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
):
def try_import_ck_lib() -> tuple[
Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
]:
try:
import ck4inductor # type: ignore[import]
from ck4inductor.universal_gemm.gen_instances import ( # type: ignore[import]

@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
return DummyModule()

with mock.patch.object(
GraphLowering, "compile_to_module", patched_compile_to_module
), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
with (
mock.patch.object(
GraphLowering, "compile_to_module", patched_compile_to_module
),
mock.patch.object(GraphLowering, "save_output_code", save_output_code),
):
torch._dynamo.reset()
# Note the return here is None
_ = fn(*args, **kwargs)

@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
source_codes = get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]

def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
_, source_codes = run_and_get_code(fn, *args, **kwargs)
# Can have two outputs if backwards was eagerly compiled
assert (
1 <= len(source_codes) <= 2
), f"expected one or two code outputs got {len(source_codes)}"
assert 1 <= len(source_codes) <= 2, (
f"expected one or two code outputs got {len(source_codes)}"
)
return source_codes[0]

@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:

def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
assert isinstance(
val, sympy.Expr
), "only support sympy.Expr as input to get_sympy_Expr_dtype"
assert isinstance(val, sympy.Expr), (
"only support sympy.Expr as input to get_sympy_Expr_dtype"
)
if val.is_integer: # type: ignore[attr-defined]
return torch.int64
else:

@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) ->

def is_output_of_multi_outputs_template(
input_buf: Optional[Union[Buffer, Operation]]
input_buf: Optional[Union[Buffer, Operation]],
) -> bool:
"""
Check if input buffer is a output of multi-outputs template buffer

@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing(
if node not in (EnableReduction, DisableReduction):
if node.node is not None:
V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [
origin.name for origin in node.node.origins # type: ignore[attr-defined]
origin.name
for origin in node.node.origins # type: ignore[attr-defined]
]
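
The utility hunks above also collapse stub bodies that are only "..." onto the def line and wrap a long return annotation inside its subscript instead of parenthesizing the whole annotation. A self-contained sketch with invented names, assuming the return annotation is too long for one line (not code from this patch):

from typing import Protocol


class SupportsRun(Protocol):
    # A body that is only "..." now sits on the same line as the def.
    def run(self, times: int, repeat: int) -> float: ...


# The subscript of the return type is wrapped, rather than the whole
# annotation being wrapped in parentheses.
def load_plugins() -> tuple[
    str, list[str], dict[str, int]
]:
    return ("example", [], {})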
@ -314,9 +314,9 @@ class _V:
KernelFormatterHandler = KernelFormatterHandler
WrapperHandler = WrapperHandler

set_ops_handler: Callable[
[OpsHandler[Any]], AbstractContextManager[None]
] = _ops._set_handler
set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = (
_ops._set_handler
)
get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
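
The same value-parenthesization applies to annotated assignments, as in the _V hunk above: when the annotation plus value overflow the line, the value is parenthesized rather than the Callable annotation being split. A sketch with placeholder handler names, assuming the real statement exceeds the line limit (not code from this patch):

from collections.abc import Callable


def _print_handler(value: int) -> None:
    # Placeholder default handler.
    print(value)


# The annotation stays on one line; the over-long value is parenthesized.
set_int_handler: Callable[[int], None] = (
    _print_handler
)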
@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes

class BenchmarkCallableType(Protocol):
def __call__(self, times: int, repeat: int) -> float:
...
def __call__(self, times: int, repeat: int) -> float: ...

_kernel_category_choices = [

@ -138,9 +137,9 @@ def benchmark_all_kernels(
)
else:
ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
assert (
len(triton_kernel.launchers) == 1
), "Autotuner should have selected the best config"
assert len(triton_kernel.launchers) == 1, (
"Autotuner should have selected the best config"
)
launcher = triton_kernel.launchers[0]
print(
get_info_str(

@ -256,9 +255,9 @@ def parse_profile_event_list(
"triton_unknown",
"unknown",
]
assert OrderedSet(all_events.keys()).issubset(
OrderedSet(category_list)
), f"{list(all_events.keys())}"
assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
f"{list(all_events.keys())}"
)

per_category_wall_time = {}
total_device_ms = 0.0