Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00).
[BE][PYFMT] migrate PYFMT for torch._inductor to ruff format (#144550)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144550
Approved by: https://github.com/jansel

Parent: 34d726011f
Commit: 1cb4e2df65
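Most hunks below follow one recurring rewrite: for an assert with a long message, Black wrapped the condition in parentheses, while ruff format keeps the condition on the assert line and parenthesizes only the message. A minimal standalone sketch of the two layouts (illustrative values, not taken from any particular file):

    kwargs = None  # example value so the snippet runs on its own

    # Black-era layout (before): the condition is wrapped in parentheses.
    assert (
        kwargs is None or len(kwargs) == 0
    ), "when checking for accuracy, the inputs must have been flattened and kwargs is None"

    # ruff format layout (after): the condition stays on the assert line,
    # and only the message is parenthesized.
    assert kwargs is None or len(kwargs) == 0, (
        "when checking for accuracy, the inputs must have been flattened and kwargs is None"
    )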
@@ -53,7 +53,6 @@ USE_BLACK_FILELIST = re.compile(
         # torch/_[e-h]*/**
         "torch/_[e-h]*/**",
         # torch/_i*/**
-        "torch/_i*/**",
         # torch/_[j-z]*/**
         "torch/_[j-z]*/**",
         # torch/[a-c]*/**
@@ -66,7 +66,9 @@ def aoti_compile_and_package(
     .. code-block:: python

         ep = torch.export.export(M(), ...)
-        aoti_file = torch._inductor.aoti_compile_and_package(ep, package_path="my_package.pt2")
+        aoti_file = torch._inductor.aoti_compile_and_package(
+            ep, package_path="my_package.pt2"
+        )
         compiled_model = torch._inductor.aoti_load_package("my_package.pt2")

     To compile and save multiple models into a single ``.pt2`` artifact, you can do
@@ -75,11 +77,16 @@ def aoti_compile_and_package(
     .. code-block:: python

         ep1 = torch.export.export(M1(), ...)
-        aoti_file1 = torch._inductor.aot_compile(ep1, ..., options={"aot_inductor.package": True})
+        aoti_file1 = torch._inductor.aot_compile(
+            ep1, ..., options={"aot_inductor.package": True}
+        )
         ep2 = torch.export.export(M2(), ...)
-        aoti_file2 = torch._inductor.aot_compile(ep2, ..., options={"aot_inductor.package": True})
+        aoti_file2 = torch._inductor.aot_compile(
+            ep2, ..., options={"aot_inductor.package": True}
+        )

         from torch._inductor.package import package_aoti, load_package

         package_aoti("my_package.pt2", {"model1": aoti_file1, "model2": aoti_file2})

         compiled_model1 = load_package("my_package.pt2", "model1")
@@ -123,7 +130,9 @@ def aoti_compile_and_package(
             isinstance(package_path, (str, os.PathLike))
             and os.fspath(package_path).endswith(".pt2")
         )
-    ), f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
+    ), (
+        f"Expect package path to be a file ending in .pt2, is None, or is a buffer. Instead got {package_path}"
+    )

     inductor_configs = inductor_configs or {}
     inductor_configs["aot_inductor.package"] = True
@@ -168,9 +177,9 @@ def _aoti_compile_and_package_inner(
     """

     if check_accuracy:
-        assert (
-            kwargs is None or len(kwargs) == 0
-        ), "when checking for accuracy, the inputs must have been flattened and kwargs is None"
+        assert kwargs is None or len(kwargs) == 0, (
+            "when checking for accuracy, the inputs must have been flattened and kwargs is None"
+        )

     from .package import package_aoti

@@ -156,8 +156,9 @@ def can_codegen_without_upcasts(
     low_prec_analysis = RecordLowPrecisionOps(disallow_fp32_ops)

     # Need to turn off upcasting to do analysis of whether we can turn it off
-    with config.patch("triton.codegen_upcast_to_fp32", False), V.set_ops_handler(
-        low_prec_analysis
+    with (
+        config.patch("triton.codegen_upcast_to_fp32", False),
+        V.set_ops_handler(low_prec_analysis),
     ):
         prologue._body(*prologue.get_ranges())

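The hunk above shows the other rewrite that recurs throughout this commit: multiple context managers are moved into a parenthesized with (...) block, one per line. A small self-contained sketch of the same transformation, assuming Python 3.10+ for the parenthesized form (nullcontext is used only so the example runs on its own):

    from contextlib import nullcontext

    # Before: both context managers share one with-statement line.
    with nullcontext("a") as a, nullcontext("b") as b:
        print(a, b)

    # After: ruff format parenthesizes them, one context manager per line.
    with (
        nullcontext("a") as a,
        nullcontext("b") as b,
    ):
        print(a, b)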
@@ -245,8 +245,7 @@ class AsyncCompile:

    def use_process_pool(self):
        return (
-            get_compile_threads() > 1
-            and self.process_pool().ready_future.done()  # type: ignore[union-attr]
+            get_compile_threads() > 1 and self.process_pool().ready_future.done()  # type: ignore[union-attr]
        )

    def triton(self, kernel_name: str, source_code: str, device_str: str = "cuda"):
@@ -24,8 +24,7 @@ if TYPE_CHECKING:
    class Sortable(typing.Protocol):
        """Anything that can be used as a list.sort() key (int/tuple/etc)"""

-        def __lt__(self, other: typing.Self) -> bool:
-            ...
+        def __lt__(self, other: typing.Self) -> bool: ...


class InductorChoices:
@@ -100,7 +99,9 @@ class InductorChoices:
        # to pick the faster one.
        if config.triton.multi_kernel:
            threshold *= 16
-        return V.graph.sizevars.statically_known_leq(features.reduction_numel, threshold)  # type: ignore[arg-types]
+        return V.graph.sizevars.statically_known_leq(
+            features.reduction_numel, threshold
+        )  # type: ignore[arg-types]

    @staticmethod
    def want_no_x_dim(features: SIMDKernelFeatures) -> bool:
@@ -417,9 +417,9 @@ def write_atomic(
) -> None:
    # Write into temporary file first to avoid conflicts between threads
    # Avoid using a named temporary file, as those have restricted permissions
-    assert isinstance(
-        content, (str, bytes)
-    ), "Only strings and byte arrays can be saved in the cache"
+    assert isinstance(content, (str, bytes)), (
+        "Only strings and byte arrays can be saved in the cache"
+    )
    path = Path(path_)
    if make_dirs:
        path.parent.mkdir(parents=True, exist_ok=True)
@@ -975,9 +975,9 @@ class FxGraphCache:
        symints = FxGraphCache._filter_backed_symints(example_inputs)
        hints = [hint_int(s) for s in symints]

-        def iterate_over_candidates() -> (
-            Generator[tuple[CompiledFxGraph, bytes], None, None]
-        ):
+        def iterate_over_candidates() -> Generator[
+            tuple[CompiledFxGraph, bytes], None, None
+        ]:
            if local:
                subdir = FxGraphCache._get_tmp_dir_for_key(key)
                if os.path.exists(subdir):
@@ -1123,9 +1123,9 @@ class FxGraphCache:
        """
        from .compile_fx import CompiledFxGraph

-        assert isinstance(
-            compiled_graph, CompiledFxGraph
-        ), f"serialization for {type(compiled_graph)} NYI"
+        assert isinstance(compiled_graph, CompiledFxGraph), (
+            f"serialization for {type(compiled_graph)} NYI"
+        )
        disk_compiled_graph = copy(compiled_graph)
        disk_compiled_graph.prepare_for_serialization()

@@ -1315,9 +1315,8 @@ class FxGraphCache:
                "distributed_ephemeral_timeout_us", time_saved_ns // 1000
            )
            if (
-                ephemeral_increase := add_ephemeral_timeout_increase_for_distributed(
-                    time_saved_ns
-                )
+                ephemeral_increase
+                := add_ephemeral_timeout_increase_for_distributed(time_saved_ns)
            ) != 0:
                cache_info["ephemeral_timeout_increase"] = ephemeral_increase
            else:
@@ -1556,9 +1555,9 @@ class AotCodeCompiler:
                cpp_path_operator.with_name(f"{cpp_path_operator.stem}_metadata.json")
            )
            for k, v in config.aot_inductor.metadata.items():
-                assert isinstance(k, str) and isinstance(
-                    v, (str)
-                ), "Metadata must only contain strings"
+                assert isinstance(k, str) and isinstance(v, (str)), (
+                    "Metadata must only contain strings"
+                )

            with open(meta_json, "w") as f:
                f.write(json.dumps(config.aot_inductor.metadata))
@@ -341,7 +341,7 @@ class BackendFeature(Enum):


def get_backend_features(
-    device: Union[torch.device, str, None]
+    device: Union[torch.device, str, None],
) -> OrderedSet[BackendFeature]:
    if device is None:
        return OrderedSet()
@@ -986,9 +986,9 @@ class OpOverrides(BasicMathOpsMixin, OpDecompositions, OpsHandler[Any]):
            if cls._is_unimplemented(funcname):
                setattr(cls, funcname, cls._unimplemented(funcname))
            else:
-                assert (
-                    funcname not in cls.__dict__
-                ), f"multiple definitions of {funcname} on {cls.__name__}"
+                assert funcname not in cls.__dict__, (
+                    f"multiple definitions of {funcname} on {cls.__name__}"
+                )
                impl.__name__ = funcname
                setattr(cls, funcname, staticmethod(impl))

@@ -2229,7 +2229,7 @@ class KernelTemplate:

    @staticmethod
    def _fake_get_dtype(
-        fake_outs: Union[list[Buffer], Buffer]
+        fake_outs: Union[list[Buffer], Buffer],
    ) -> Callable[[str], torch.dtype]:
        _get_dtype_real = V.graph.get_dtype
        if isinstance(fake_outs, (list, tuple)):
@@ -483,9 +483,9 @@ class OuterLoopFusedSchedulerNode(FusedSchedulerNode):
        outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]],
        outer_loop_fusion_depth,
    ):
-        self.outer_fused_nodes: list[
-            Union[FusedSchedulerNode, SchedulerNode]
-        ] = outer_fused_nodes
+        self.outer_fused_nodes: list[Union[FusedSchedulerNode, SchedulerNode]] = (
+            outer_fused_nodes
+        )
        self.outer_loop_fusion_depth = outer_loop_fusion_depth
        flatten_snodes = []
        for _node in self.outer_fused_nodes:
@@ -1361,9 +1361,9 @@ class CppVecOverrides(CppOverrides):

    @staticmethod
    def remainder(a, b):
-        assert (
-            a.dtype == b.dtype
-        ), "remainder vec implementation expect the same inputs' dtype."
+        assert a.dtype == b.dtype, (
+            "remainder vec implementation expect the same inputs' dtype."
+        )
        return f"{a} - ({CppVecOverrides.floordiv(a, b)}) * {b}"

    @staticmethod
@@ -1468,9 +1468,9 @@ class CppVecOverrides(CppOverrides):
    @staticmethod
    def floordiv(a, b):
        if is_float_dtype(a.dtype):
-            assert (
-                a.dtype == b.dtype
-            ), "div_floor_floating_vec implementation expect the same inputs' dtype."
+            assert a.dtype == b.dtype, (
+                "div_floor_floating_vec implementation expect the same inputs' dtype."
+            )
            return f"div_floor_floating_vec({a}, {b})"
        else:
            assert all(is_integer_dtype(item.dtype) for item in [a, b])
@@ -1629,9 +1629,9 @@ class CppVecOverrides(CppOverrides):
            assert isinstance(other_vec_var, CppCSEVariable), other_vec_var
            body_vec_var.dtype = dtype
            other_vec_var.dtype = dtype
-            overrides: type[
-                Union[CppOverrides, CppVecOverrides]
-            ] = V.kernel.overrides  # type: ignore[has-type]
+            overrides: type[Union[CppOverrides, CppVecOverrides]] = (
+                V.kernel.overrides
+            )  # type: ignore[has-type]
            code.writeline(
                f"return {overrides.where(new_mask, body_vec_var, other_vec_var)};"
            )
@@ -2108,9 +2108,9 @@ class CppKernel(Kernel):

    def set_ranges(self, lengths, reduction_lengths):
        if self.call_ranges:
-            assert self.call_ranges == tuple(lengths) + tuple(
-                reduction_lengths
-            ), f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+            assert self.call_ranges == tuple(lengths) + tuple(reduction_lengths), (
+                f"{self.call_ranges} == {tuple(lengths)} + {tuple(reduction_lengths)}"
+            )
            assert self.reduction_depth == len(lengths)
        else:
            self.call_ranges = tuple(lengths) + tuple(reduction_lengths)
@@ -2787,9 +2787,9 @@ class CppVecKernel(CppKernel):
                self.weight_recps_val = self.weight_recps_cse.generate(
                    self.compute, f"reduction {self.weight_recp_vec_range}", write=False
                )
-                self.weight_recps_cse.reduction_cache[
-                    self.weight_recp_vec_range
-                ] = self.weight_recps_val
+                self.weight_recps_cse.reduction_cache[self.weight_recp_vec_range] = (
+                    self.weight_recps_val
+                )
            self.non_parallel_reduction_prefix.writeline(
                self.welford_weight_reciprocal_vec(dtype)
            )
@@ -4969,9 +4969,9 @@ class CppScheduling(BaseScheduling):
        ]

        counters["inductor"]["cpp_epilogue_fusion_counter"] += len(epilogue_nodes)
-        assert self.is_cpp_template(
-            template_node
-        ), "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        assert self.is_cpp_template(template_node), (
+            "Template node passed to CppScheduler.codegen_template must be a SchedulerNode that wraps a CppTemplateBuffer"
+        )
        template_node = cast(SchedulerNode, template_node)
        _, (_, rnumel) = template_node.group
        assert rnumel == ()
@@ -4979,9 +4979,9 @@ class CppScheduling(BaseScheduling):
        epilogue_ir_nodes: list[Optional[ir.Operation]] = [
            n.node for n in epilogue_nodes
        ]
-        assert all(
-            isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes
-        ), "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        assert all(isinstance(n, ir.ComputedBuffer) for n in epilogue_ir_nodes), (
+            "Epilogue nodes must all be instances of ir.ComputedBuffer"
+        )

        def template_buffer_has_other_users(
            template_buffer, outputs_by_name, epilogue_nodes
@@ -5019,16 +5019,16 @@ class CppScheduling(BaseScheduling):
        if is_multi_outputs_template(template_node.node):
            # For multi outputs template, allocate buffers for each output after the epilogue
            # codegen to which determines if the buffer has been removed.
-            assert (
-                len(template_node.outputs) == 1
-            ), "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+            assert len(template_node.outputs) == 1, (
+                "Multi outputs template should be with 1 output template buffer of MultiOutputLayout"
+            )
            for user in template_node.outputs[0].users:
-                assert isinstance(
-                    user.node, ExternKernelSchedulerNode
-                ), "Multi outputs template should be with ExternKernelSchedulerNode"
+                assert isinstance(user.node, ExternKernelSchedulerNode), (
+                    "Multi outputs template should be with ExternKernelSchedulerNode"
+                )
-                assert isinstance(
-                    user.node.node, ir.MultiOutput
-                ), "Multi outputs template has multi users with MultiOutput"
+                assert isinstance(user.node.node, ir.MultiOutput), (
+                    "Multi outputs template has multi users with MultiOutput"
+                )
                user.node.mark_run()

        kernel.call_kernel(kernel_name, ctb)
@@ -5347,9 +5347,9 @@ class LoopNest:
        return self.loops is not None and self.loops[0].is_reduction

    def mark_parallel(self, par_depth):
-        assert (
-            par_depth <= self.max_parallel_depth()
-        ), "Parallel depth cannot exceed the maximal allowed parallel depth"
+        assert par_depth <= self.max_parallel_depth(), (
+            "Parallel depth cannot exceed the maximal allowed parallel depth"
+        )
        assert self.loops is not None
        assert len(self.loops) >= par_depth
        loop = self.loops[0]
@@ -862,7 +862,9 @@ class CppFlexAttentionTemplate(CppTemplate):
            assert all(
                mem.buffer_name in kernel_group.args.input_buffers
                for mem in body.memory_usage[MemoryUsageType.LOAD]
-            ), "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
+            ), (
+                "All the buffers in the score and mask subgraph should be in kernel_group.args.input_buffers"
+            )

            bodies.append(body)
            var_sizes_list.append((var_sizes, ()))
@@ -557,9 +557,9 @@ class CppGemmTemplate(CppTemplate):
            thread_block_m = math.ceil(m_blocks / m_factor)
            return GemmBlocking(thread_block_m, thread_block_n, thread_block_k)

-        assert (
-            not self.is_dynamic_M
-        ), "Unable to determine thread blocking for dynamic M."
+        assert not self.is_dynamic_M, (
+            "Unable to determine thread blocking for dynamic M."
+        )
        register_blocking = self.register_blocking
        m_blocks = math.ceil(self.m / register_blocking.block_m)
        n_blocks = math.ceil(self.n / register_blocking.block_n)
@@ -673,17 +673,17 @@ class CppGemmTemplate(CppTemplate):
        L1_cache_size = (
            torch._C._cpu._L1d_cache_size()
        )  # per core cache size in Bytes
-        assert (
-            L1_cache_size > 0
-        ), f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+        assert L1_cache_size > 0, (
+            f"Expect L1_cache_size > 0 but got {L1_cache_size}"
+        )
        L1 = L1_cache_size * L1_limit_factor

        L2_cache_size = (
            torch._C._cpu._L2_cache_size()
        )  # per core cache size in Bytes
-        assert (
-            L2_cache_size > 0
-        ), f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+        assert L2_cache_size > 0, (
+            f"Expect L2_cache_size > 0 but got {L2_cache_size}"
+        )
        L2 = L2_cache_size * L2_limit_factor

        def get_num_byte(dtype):
@@ -744,9 +744,9 @@ class CppGemmTemplate(CppTemplate):

            return Mc_blocks, Nc_blocks, Kc_blocks

-        assert (
-            not self.is_dynamic_M
-        ), "Unable to determine cache blocking for dynamic M."
+        assert not self.is_dynamic_M, (
+            "Unable to determine cache blocking for dynamic M."
+        )
        register_blocking = self.register_blocking
        thread_blocking = self.thread_blocking(num_threads)

@@ -1114,9 +1114,9 @@ class CppGemmTemplate(CppTemplate):
                LayoutType.VNNI4,
            ], f"We only support {layout_str} for now"
            vnni_size = 4 if micro_gemm.get_b_layout() == LayoutType.VNNI4 else 2
-            assert (
-                k % vnni_size == 0
-            ), f"k should be divisible by vnni_size for {layout_str} layout"
+            assert k % vnni_size == 0, (
+                f"k should be divisible by vnni_size for {layout_str} layout"
+            )
            vnni_view_size = list(new_size)
            vnni_view_size[-2] = k // vnni_size
            vnni_view_size.insert(-1, vnni_size)
@@ -309,9 +309,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
            for W_node in W_nodes:
                assert W_node.get_name() in V.graph.constants
                W_tensor.append(V.graph.constants[W_node.get_name()])
-            new_input_nodes[
-                wgt_start_idx : wgt_start_idx + gemm_grouped_num
-            ] = W_tensor  # type: ignore[assignment]
+            new_input_nodes[wgt_start_idx : wgt_start_idx + gemm_grouped_num] = (
+                W_tensor  # type: ignore[assignment]
+            )
            new_input_nodes, _ = pack_weight(
                *normalize_shapes(*maybe_to_dense(new_input_nodes, layout))
            )
@@ -321,9 +321,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                W_packed = new_input_nodes[idx]
                assert isinstance(W_packed, torch.Tensor)
                W_packed_constant = V.graph.add_tensor_constant(W_packed)
-                template_buffer.inputs[
-                    idx
-                ] = ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
+                template_buffer.inputs[idx] = (
+                    ir.InputsKernel.unwrap_storage_for_input(W_packed_constant)
+                )
            return output

        template = DataProcessorTemplateWrapper(
@@ -419,9 +419,9 @@ class CppGroupedGemmTemplate(CppGemmTemplate):
                ir.Buffer(name=gemm_output_name, layout=template_buffer.layout)
            )

-        assert (
-            not self.epilogue_creator
-        ), "epilogue_creator is not supported yet in Grouped GEMM Template"
+        assert not self.epilogue_creator, (
+            "epilogue_creator is not supported yet in Grouped GEMM Template"
+        )

        kernel_args: dict[str, Optional[ir.IRNode]] = {}
        for x_idx in range(wgt_start_idx):
@@ -231,9 +231,9 @@ micro_gemm_configs: dict[type[CppMicroGemm], list[CppMicroGemmConfig]] = {}

def register_micro_gemm(*configs):
    def inner(cls):
-        assert (
-            cls not in micro_gemm_configs
-        ), f"Duplicate micro_gemm registration for {cls}"
+        assert cls not in micro_gemm_configs, (
+            f"Duplicate micro_gemm registration for {cls}"
+        )
        assert len(configs) > 0, f"No micro_gemm configs provided for {cls}"
        micro_gemm_configs[cls] = list(configs)
        return cls
@@ -44,11 +44,13 @@ class CppTemplate(KernelTemplate):

    def generate(self, **kwargs):
        kernel_name = f"cpp_{self.name}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), patch.object(ir.FlexibleLayout, "allow_indexing", True), CppTemplateKernel(
-            kernel_name=kernel_name, num_threads=self.num_threads
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            patch.object(ir.FlexibleLayout, "allow_indexing", True),
+            CppTemplateKernel(
+                kernel_name=kernel_name, num_threads=self.num_threads
+            ) as kernel,
+        ):
            code = kernel.render(self, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            log.debug("Generated Code:\n%s", code)
@@ -377,7 +377,10 @@ class CppTemplateKernel(CppKernel):
            )
            epilogue_nodes = scope.localize_nodes(epilogue_nodes)
            return self.store_pointwise_nodes(
-                dst, epilogue_nodes, offsets, reindexers  # type: ignore[arg-type]
+                dst,
+                epilogue_nodes,  # type: ignore[arg-type]
+                offsets,
+                reindexers,
            )
        else:
            if dst.get_name() != src.get_name():
@@ -110,9 +110,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
        Only valid when cuda == True.
        """
        assert not gpu, "CppWrapperCpu.generate_kernel_call does not support GPU"
-        assert arg_types is not None and len(call_args) == len(
-            arg_types
-        ), "Mismatch call_args and arg_types in generate_kernel_call"
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call"
+        )
        new_args = []
        for idx, arg in enumerate(call_args):
            if "*" in arg_types[idx]:
@@ -506,9 +506,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                    dtype = may_get_constant_buffer_dtype(
                        V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                    )
-                    assert (
-                        dtype is not None
-                    ), "Fails to get the dtype of the sympy.Expr"
+                    assert dtype is not None, (
+                        "Fails to get the dtype of the sympy.Expr"
+                    )
                    self.codegen_tensor_item(
                        dtype, f"inputs[{idx}]", input_key, self.prefix
                    )
@@ -555,8 +555,7 @@ class CppWrapperCpu(PythonWrapperCodegen):
    def codegen_tensor_dtype_var_decl(self, code: IndentedBuffer, name):
        code.writeline(f"int32_t {name}_dtype;")
        code.writeline(
-            "AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype"
-            f"({name}, &{name}_dtype));"
+            f"AOTI_TORCH_ERROR_CODE_CHECK(aoti_torch_get_dtype({name}, &{name}_dtype));"
        )

    def codegen_input_size_var_decl(self, code: IndentedBuffer, name):
@@ -570,9 +569,9 @@ class CppWrapperCpu(PythonWrapperCodegen):

        # Tell compiler we need to link with the non-mangled symbols
        for kernel in self.initialized_kernels.values():
-            assert hasattr(
-                kernel, "get_signature"
-            ), f"{kernel} must have get_signature implemented"
+            assert hasattr(kernel, "get_signature"), (
+                f"{kernel} must have get_signature implemented"
+            )
            signature = kernel.get_signature()
            self.prefix.writeline(f'extern "C" {signature};')

@@ -597,9 +596,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            )
        )
        for name, kernel in self.initialized_kernels.items():
-            assert hasattr(
-                kernel, "get_signature"
-            ), f"{kernel} must have get_signature implemented"
+            assert hasattr(kernel, "get_signature"), (
+                f"{kernel} must have get_signature implemented"
+            )
            kernel_ptr = f"(*{name})"
            signature = kernel.get_signature().replace(name, kernel_ptr)
            self.prefix.writeline(f" {signature} = torch::aot_inductor::{name};")
@@ -645,9 +644,9 @@ class CppWrapperCpu(PythonWrapperCodegen):

        with self.prefix.indent():
            for idx, (name, inp) in enumerate(V.graph.graph_inputs.items()):
-                assert not isinstance(
-                    inp, sympy.Expr
-                ), f"input {name=} cannot be symbolic"
+                assert not isinstance(inp, sympy.Expr), (
+                    f"input {name=} cannot be symbolic"
+                )
                self.write_input_output_info("inputs_info_", idx, name)

            all_cuda = all(
@@ -718,9 +717,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                opaque_metadata_tensor = torch.ops.mkldnn._get_mkldnn_serialized_md(
                    tensor
                )
-                assert (
-                    opaque_metadata_tensor.dim() == 1
-                ), "Expect opaque_metadata_tensor to be 1-D"
+                assert opaque_metadata_tensor.dim() == 1, (
+                    "Expect opaque_metadata_tensor to be 1-D"
+                )

                opaque_metadata_list = opaque_metadata_tensor.tolist()
                opaque_metadata_str = self.codegen_shape_tuple(opaque_metadata_list)
@@ -757,9 +756,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            )

            for idx, output in enumerate(V.graph.graph_outputs):
-                assert not isinstance(
-                    output, sympy.Expr
-                ), f"output {name=} cannot be symbolic"
+                assert not isinstance(output, sympy.Expr), (
+                    f"output {name=} cannot be symbolic"
+                )
                name = f"output{idx}"
                self.write_input_output_info("outputs_info_", idx, name)

@@ -816,9 +815,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            for idx, (name, _) in enumerate(V.graph.constants.items()):
                if name in V.graph.const_output_index:
                    const_index_mapping[V.graph.const_output_index[name]] = (idx, name)  # type: ignore[call-overload]
-            assert (
-                None not in const_index_mapping
-            ), "Not all constant gets mapped for constant folding graph."
+            assert None not in const_index_mapping, (
+                "Not all constant gets mapped for constant folding graph."
+            )

            self.prefix.writeline(
                f"""
@@ -1117,9 +1116,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
                name = f"{output.get_name()}"
                output_handle_name = f"{name}_handle"
                if output.indices:
-                    assert (
-                        output.indices[0][1] == idx
-                    ), f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                    assert output.indices[0][1] == idx, (
+                        f"expected {output.indices[0][1]=} == {idx=} for {output_name_base=}"
+                    )
                self.writeline(f"AtenTensorHandle {output_handle_name};")
                output_args.append(f"&{output_handle_name}")
                output_raii_handles.append(
@@ -1140,7 +1139,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
        args = args + output_args
        device = d.type if (d := fallback_kernel.get_device()) else self.device
        self.generate_c_shim_extern_kernel_call(
-            fallback_kernel.cpp_kernel_name, args, device  # type: ignore[arg-type]
+            fallback_kernel.cpp_kernel_name,  # type: ignore[arg-type]
+            args,
+            device,
        )
        for raii_handle in output_raii_handles:
            self.writeline(raii_handle)
@@ -1189,9 +1190,9 @@ class CppWrapperCpu(PythonWrapperCodegen):
            if reduce:
                line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
            else:
-                assert (
-                    reduce is None
-                ), "Expect reduce to be None for aten.scatter_ with scalar src"
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
            line += ");"
            self.writeline(line)

@@ -1841,18 +1842,24 @@ class CppWrapperCpu(PythonWrapperCodegen):
                    # Only treat int Scalar as dynamic
                    is_int_type = [isinstance(a, int) for a in arg]
                    if any(is_int_type):
-                        assert all(
-                            is_int_type
-                        ), "AOTInductor only supports int scalars of the same type"
+                        assert all(is_int_type), (
+                            "AOTInductor only supports int scalars of the same type"
+                        )
                        new_int_args.extend([str(a) for a in arg])
                    else:
                        assert isinstance(
-                            arg_type.getElementType(), static_arg_types  # type: ignore[arg-type]
-                        ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                            arg_type.getElementType(),
+                            static_arg_types,  # type: ignore[arg-type]
+                        ), (
+                            f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                        )
                else:
                    assert isinstance(
-                        arg_type, static_arg_types  # type: ignore[arg-type]
-                    ), f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                        arg_type,
+                        static_arg_types,  # type: ignore[arg-type]
+                    ), (
+                        f"Fall through arguments must be one of static_arg_types, got {type(arg_type)}"
+                    )

        for arg, arg_type in zip(raw_args, arg_types):
            if arg is not None:
@@ -2378,9 +2385,9 @@ if (custom_op_wrapper.get() == NULL) {
        return f"&{var_name}"

    if isinstance(type_, torch.ListType):
-        assert isinstance(
-            val, (list, tuple)
-        ), f"{val} does not match with arg type {type_}"
+        assert isinstance(val, (list, tuple)), (
+            f"{val} does not match with arg type {type_}"
+        )
        element_type = type_.getElementType()
        var_name = f"var_array_{next(self.var_array_id)}"
        if len(val) == 0:
@@ -56,9 +56,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
        self.cached_output_id = count()
        self.scalar_to_tensor_id = count()
        self.custom_op_wrapper_loaded = False
-        self.allow_stack_allocation: Optional[
-            bool
-        ] = config.aot_inductor.allow_stack_allocation
+        self.allow_stack_allocation: Optional[bool] = (
+            config.aot_inductor.allow_stack_allocation
+        )
        self.stack_allocated_buffers: dict[BufferName, BufferLike] = {}

    @staticmethod
@@ -126,12 +126,12 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
        Otherwise it uses the CUDA language for codegen.
        Only valid when cuda == True.
        """
-        assert (
-            not gpu
-        ), "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
+        assert not gpu, (
+            "CppWrapperCpuArrayRef.generate_kernel_call does not support GPU"
+        )
-        assert arg_types is not None and len(call_args) == len(
-            arg_types
-        ), "Mismatch call_args and arg_types in generate_kernel_call"
+        assert arg_types is not None and len(call_args) == len(arg_types), (
+            "Mismatch call_args and arg_types in generate_kernel_call"
+        )
        new_args = []
        for idx, arg in enumerate(call_args):
            if "*" in arg_types[idx]:
@@ -328,9 +328,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
                    dtype = may_get_constant_buffer_dtype(
                        V.graph.graph_inputs[input_key]  # type: ignore[arg-type]
                    )
-                    assert (
-                        dtype is not None
-                    ), "Fails to get the dtype of the sympy.Expr"
+                    assert dtype is not None, (
+                        "Fails to get the dtype of the sympy.Expr"
+                    )
                    self.codegen_tensor_item(
                        dtype, f"inputs[{idx}]", input_key, self.prefix
                    )
@@ -724,9 +724,9 @@ class CppWrapperCpuArrayRef(CppWrapperCpu):
            if reduce:
                line += f", {V.graph.wrapper_code.val_to_arg_str(reduce)}"
            else:
-                assert (
-                    reduce is None
-                ), "Expect reduce to be None for aten.scatter_ with scalar src"
+                assert reduce is None, (
+                    "Expect reduce to be None for aten.scatter_ with scalar src"
+                )
            line += ");"
            self.writeline(line)

@@ -60,13 +60,13 @@ class DeferredGpuKernelLine(DeferredLineBase):
            # MultiKernel will select one kernel after running the autotune block
            self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)
        params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )
        for key in self.keys:
-            assert (
-                key in params
-            ), f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
+            assert key in params, (
+                f"{key} not found in CudaKernelParamCache[{self.kernel_name}]"
+            )
            if key == get_cpp_wrapper_cubin_path_name():
                assert os.path.exists(params[key]), f"{params[key]} does not exist"
                self.additional_files.append(params[key])
@@ -122,9 +122,9 @@ class DeferredGpuDefaultGrid:
            grid_fn = self.grid_callable(*grid, **self.grid_extra_kwargs)

        params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )
        return grid_fn(params["meta"])


@@ -153,9 +153,9 @@ class DeferredGpuGridLine(DeferredLineBase):
            self.kernel_name = MultiKernelCall.lookup_choice(self.kernel_name)

        params = CudaKernelParamCache.get(self.kernel_name)
-        assert (
-            params is not None
-        ), f"{self.kernel_name} not found in CudaKernelParamCache"
+        assert params is not None, (
+            f"{self.kernel_name} not found in CudaKernelParamCache"
+        )

        if self.autotune_configs is not None:
            # This indicates the Triton kernel is a user-defined one.
@@ -248,13 +248,13 @@ class CppWrapperGpu(CppWrapperCpu):
        if V.graph.aot_mode and V.graph.inputs_to_check:
            for idx in V.graph.inputs_to_check:
                input_name = V.graph.graph_input_names[idx]
-                assert (
-                    input_name in V.graph.graph_inputs
-                ), f"{input_name} not found in graph inputs"
+                assert input_name in V.graph.graph_inputs, (
+                    f"{input_name} not found in graph inputs"
+                )
                value = V.graph.graph_inputs[input_name]
-                assert isinstance(
-                    value, TensorBox
-                ), f"{input_name} is expected to be tensor but found as {type(value)}"
+                assert isinstance(value, TensorBox), (
+                    f"{input_name} is expected to be tensor but found as {type(value)}"
+                )
                warn_msg = (
                    f"Input {idx} was compiled as {GPU_ALIGN_BYTES}-bytes aligned, "
                    "but it is not aligned at run time. Copying to an aligned tensor "
@@ -87,9 +87,9 @@ class CUDACPPScheduling(BaseScheduling):
        Codegen a CUDA template, possibly with fused epilogues
        """
        counters["inductor"]["cuda_epilogue_fusion_counter"] += len(epilogue_nodes)
-        assert self.is_cuda_cpp_template(
-            template_node
-        ), "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
+        assert self.is_cuda_cpp_template(template_node), (
+            "Template node passed to CUDAScheduler.codegen_template must be a SchedulerNode that wraps a CUDATemplateBuffer"
+        )
        template_node = cast(SchedulerNode, template_node)
        _, (_numel, rnumel) = template_node.group
        assert rnumel == 1
@@ -496,7 +496,9 @@ class CUDATemplateCaller(ChoiceCaller):
        make_kernel_render: Callable[[CUDATemplateBuffer, Optional[list[IRNode]]], str],
        bmreq: CUDABenchmarkRequest,
        template: "CUDATemplate",  # type: ignore[name-defined]
-        info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]],  # type: ignore[type-arg]
+        info_kwargs: Optional[
+            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
+        ],  # type: ignore[type-arg]
        description: str,
    ) -> None:
        super().__init__(name, input_nodes, layout, description)
@@ -71,13 +71,14 @@ class CUDATemplate(KernelTemplate):
        A CUDATemplateCaller object representing the generated CUDA template caller.
        """
        kernel_name = f"cuda_{self.name}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), CUDATemplateKernel(
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            CUDATemplateKernel(
                kernel_name=kernel_name,
                runtime_arg_info=self.get_runtime_arg_info(),
                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
-        ) as kernel:
+            ) as kernel,
+        ):
            code = self.render(kernel=kernel, **kwargs)
            _, call_args, _, _ = kernel.args.python_argdefs()
            autotuning_log.debug("Generated Code:\n%s", code)
@@ -147,7 +147,9 @@ if try_import_cutlass():
            "element_d": DataTypeTag[operation.D.element],  # type: ignore[name-defined]
            "layout_d": LayoutTag[instance_layout_D],  # type: ignore[name-defined]
            "element_accumulator": DataTypeTag[operation.accumulator_type()],  # type: ignore[name-defined]
-            "opcode_class": OpcodeClassTag[operation.tile_description.math_instruction.opcode_class],  # type: ignore[name-defined] # noqa: B950
+            "opcode_class": OpcodeClassTag[  # type: ignore[name-defined]
+                operation.tile_description.math_instruction.opcode_class
+            ],
            "arch": f"cutlass::arch::Sm{operation.arch:d}",
            "tile_shape_m": str(operation.tile_description.tile_shape[0]),
            "tile_shape_n": str(operation.tile_description.tile_shape[1]),
@@ -168,7 +170,9 @@ if try_import_cutlass():
                operation.tile_description.math_instruction.instruction_shape[2]
            ),
            "kernel_schedule": str(KernelScheduleTag[operation.kernel_schedule]),  # type: ignore[name-defined]
-            "epilogue_schedule": str(EpilogueScheduleTag[operation.epilogue_schedule]),  # type: ignore[name-defined]
+            "epilogue_schedule": str(
+                EpilogueScheduleTag[operation.epilogue_schedule]  # type: ignore[name-defined]
+            ),
            "epilogue_functor": epilogue_functor,
            "stages": stage_count_string,
            "align_a": str(operation.A.alignment),
@@ -56,9 +56,9 @@ def try_import_cutlass() -> bool:
            "Found cutlass_library in python search path, overriding config.cuda.cutlass_dir"
        )
        cutlass_library_dir = os.path.dirname(cutlass_library.__file__)
-        assert os.path.isdir(
-            cutlass_library_dir
-        ), f"{cutlass_library_dir} is not a directory"
+        assert os.path.isdir(cutlass_library_dir), (
+            f"{cutlass_library_dir} is not a directory"
+        )
        config.cuda.cutlass_dir = os.path.abspath(
            os.path.join(
                cutlass_library_dir,
@@ -86,9 +86,9 @@ def try_import_cutlass() -> bool:
        if os.path.isdir(cutlass_py_full_path):
            if tmp_cutlass_py_full_path not in sys.path:
                if os.path.exists(dst_link):
-                    assert os.path.islink(
-                        dst_link
-                    ), f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                    assert os.path.islink(dst_link), (
+                        f"{dst_link} is not a symlink. Try to remove {dst_link} manually and try again."
+                    )
                    assert os.path.realpath(os.readlink(dst_link)) == os.path.realpath(
                        cutlass_py_full_path
                    ), f"Symlink at {dst_link} does not point to {cutlass_py_full_path}"
@@ -949,9 +949,9 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
        import cutlass_library.gemm_operation as cutlass_gemm_op
        import cutlass_library.library as cutlass_lib

-        assert isinstance(
-            op, cutlass_gemm_op.GemmOperation
-        ), "op argument is required and has to be an instance of GemmOperation"
+        assert isinstance(op, cutlass_gemm_op.GemmOperation), (
+            "op argument is required and has to be an instance of GemmOperation"
+        )

        assert len(self.input_nodes) >= 2 and self.output_node is not None
        X, W = self.input_nodes[0], self.input_nodes[1]
@@ -977,7 +977,10 @@ class CUTLASSGemmTemplate(CUTLASSTemplate, ABC):
        else:
            input_reorder = None
        kernel_call_signature = kernel.def_kernel(
-            inputs=inputs, outputs=[Y], names_str=names_str, input_reorder=input_reorder  # type: ignore[arg-type]
+            inputs=inputs,  # type: ignore[arg-type]
+            outputs=[Y],
+            names_str=names_str,
+            input_reorder=input_reorder,
        )
        test_call_statement = self.test_call_statement(kernel, inputs, names_str)
        # The layouts might have changed between autotuning and this call if they were FlexibleLayout
@@ -198,7 +198,7 @@ class HalidePrinter(PythonPrinter):
        val, n = expr.args
        val = self._print(val)
        n = int(n)
-        return f"hl.f32({10.**(-n)!r})*hl.round(({val})*hl.f32({10.**n!r}))"
+        return f"hl.f32({10.0 ** (-n)!r})*hl.round(({val})*hl.f32({10.0**n!r}))"


texpr = HalidePrinter().doprint
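The HalidePrinter change above is likewise purely cosmetic: ruff format rewrites the bare float literal 10. as 10.0 and adjusts the spacing around **, but the strings it produces are unchanged. A quick standalone check of that equivalence (illustrative, not repository code):

    n = 3
    assert f"{10.**(-n)!r}" == f"{10.0 ** (-n)!r}" == "0.001"
    assert f"{10.**n!r}" == f"{10.0**n!r}" == "1000.0"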
@@ -856,11 +856,11 @@ class HalideKernel(SIMDKernel):
            for sym, size in added_sym_size:
                full_index += stride * sym
                stride *= size
-            self.index_replacements[
-                node.symbol()
-            ] = V.graph.sizevars.simplify_with_ranges(
-                ModularIndexing(full_index, node.divisor, node.length),
-                self.halide_vars,  # type: ignore[arg-type]
-            )
+            self.index_replacements[node.symbol()] = (
+                V.graph.sizevars.simplify_with_ranges(
+                    ModularIndexing(full_index, node.divisor, node.length),
+                    self.halide_vars,  # type: ignore[arg-type]
+                )
+            )

        # codegen the variable definitions
@@ -1183,9 +1183,9 @@ class HalideKernel(SIMDKernel):
 
         if isinstance(value, tuple):
             assert reduction_type == "welford_combine"
-            self.cse.reduction_cache[
-                cache_key
-            ] = result_tuple = self.welford_combine_impl(*value)
+            self.cse.reduction_cache[cache_key] = result_tuple = (
+                self.welford_combine_impl(*value)
+            )
             return result_tuple
 
         assert isinstance(value, HalideCSEVariable) and value.used_dims is not None
@@ -1304,9 +1304,9 @@ class HalideKernel(SIMDKernel):
             scan = f"{scan_dom}.x"
             self.body.writeline(f"{scan_dom} = hl.RDom([hl.Range(1, {length})])")
 
-            assert (
-                len(self.reduction_renames) == 1
-            ), "multi-dimensional scan not implemented"
+            assert len(self.reduction_renames) == 1, (
+                "multi-dimensional scan not implemented"
+            )
             (scan_var,) = [*self.reduction_renames]  # type: ignore[misc]
             scan_renames_cur = {scan_var: sympy_index_symbol(scan)}
             scan_renames_pri = {scan_var: sympy_index_symbol(scan) - 1}
@@ -214,8 +214,7 @@ class MemorySplitProtocol(Protocol):
     get_size_hint: CachedMethod[[], int]
     get_symbolic_size: CachedMethod[[], sympy.Expr]
 
-    def _allocate(self, block: Allocation, is_last: bool) -> bool:
-        ...
+    def _allocate(self, block: Allocation, is_last: bool) -> bool: ...
 
 
 class ClearCacheOnAllocateMixin(MemorySplitProtocol):
@@ -560,7 +560,10 @@ class MetalKernel(SIMDKernel):
         threads = [self.pexpr(v.numel) for v in self.active_range_trees()]  # type: ignore[misc]
         args += [f"threads=[{', '.join(threads)}]"]
         if self.inside_reduction:
-            threads = [self.pexpr(v.numel) if v.is_reduction else "1" for v in self.active_range_trees()]  # type: ignore[misc]
+            threads = [
+                self.pexpr(v.numel) if v.is_reduction else "1"  # type: ignore[misc]
+                for v in self.active_range_trees()
+            ]
             args += [f"group_size=[{', '.join(threads)}]"]
 
         wrapper.generate_kernel_call(
@@ -33,9 +33,9 @@ def _get_all_args(args_list, arg_types_list=None):
     all_args = max(args_list, key=len)[:]
     arg_types = max(arg_types_list, key=len)[:] if arg_types_list is not None else None
    for args in args_list:
-        assert OrderedSet(args).issubset(
-            OrderedSet(all_args)
-        ), f"{args} v.s. {all_args}"
+        assert OrderedSet(args).issubset(OrderedSet(all_args)), (
+            f"{args} v.s. {all_args}"
+        )
 
     return all_args, arg_types
 
@@ -149,7 +149,9 @@ class MultiKernel:
     Assume we do codegen for a MultiKernel encapsulating kernel1 and kernel2.
     The generated definition for the multi-kernel will looks like:
     ```
-    multi_kernel_kernel1 = MultiKernelCall([kernel1, kernel2], multi_kernel_definition_code)
+    multi_kernel_kernel1 = MultiKernelCall(
+        [kernel1, kernel2], multi_kernel_definition_code
+    )
     ```
 
     Here is an concrete example: https://gist.github.com/shunting314/d9f3fb6bc6cee3dbae005825ca196d39
@@ -516,7 +516,12 @@ class CKGroupedConvFwdTemplate(CKTemplate):
             template_params=(",\n" + 12 * " ").join(template_params),
         ), self._template_from_string(template_type).render(operation_name=op.name())
 
-    def render(self, kernel: ROCmTemplateKernel, op: "CKGroupedConvFwdOp", **kwargs) -> str:  # type: ignore[override, name-defined]
+    def render(  # type: ignore[override]
+        self,
+        kernel: ROCmTemplateKernel,
+        op: "CKGroupedConvFwdOp",  # type: ignore[name-defined]
+        **kwargs,
+    ) -> str:
         template_buffer_node = kwargs.get("template_buffer_node", None)
         if template_buffer_node is not None:
             self.output_node = template_buffer_node
@@ -602,7 +602,12 @@ class CKGemmTemplate(CKTemplate):
             operation_name=operation_name
         )
 
-    def render(self, kernel: ROCmTemplateKernel, op: "CKGemmOperation", **kwargs) -> str:  # type: ignore[override]
+    def render(  # type: ignore[override]
+        self,
+        kernel: ROCmTemplateKernel,
+        op: "CKGemmOperation",
+        **kwargs,
+    ) -> str:
         """
         The primary entry point for the code rendering process used in this template.
         """
@@ -706,7 +711,7 @@ class CKGemmTemplate(CKTemplate):
 * Template instance {op}
 *
 * {torch.__version__=}
-* torch.version.git_version={getattr(torch.version, 'git_version', 'None')}
+* torch.version.git_version={getattr(torch.version, "git_version", "None")}
 */
 """
         epilogue = None
@@ -79,9 +79,9 @@ class ROCmCPPScheduling(BaseScheduling):
         """
         Codegen a ROCm template, possibly with fused epilogues
         """
-        assert self.is_rocm_cpp_template(
-            template_node
-        ), "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
+        assert self.is_rocm_cpp_template(template_node), (
+            "Template node passed to ROCmScheduler.codegen_template must be a SchedulerNode that wraps a ROCmTemplateBuffer"
+        )
         template_node = cast(SchedulerNode, template_node)
         _, (_numel, rnumel) = template_node.group
         assert rnumel == 1
@@ -232,7 +232,9 @@ class ROCmTemplateCaller(ChoiceCaller):
         ],
         bmreq: ROCmBenchmarkRequest,
         template: "ROCmTemplate",  # type: ignore[name-defined]
-        info_kwargs: Optional[dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]],  # type: ignore[type-arg]
+        info_kwargs: Optional[
+            dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]
+        ],  # type: ignore[type-arg]
     ) -> None:
         super().__init__(name, input_nodes, layout, description="")
         self.category = category
@@ -70,13 +70,14 @@ class ROCmTemplate(KernelTemplate):
         """
         kernel_name = f"rocm_{self.name}"
         kernel_hash_name = f"rocm_{self.name}_{next(self.index_counter)}"
-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(self.output_node)
-        ), ROCmTemplateKernel(
-            kernel_name=kernel_name,
-            runtime_arg_info=self.get_runtime_arg_info(),
-            runtime_arg_values=self.get_runtime_arg_values(**kwargs),
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(self.output_node)),
+            ROCmTemplateKernel(
+                kernel_name=kernel_name,
+                runtime_arg_info=self.get_runtime_arg_info(),
+                runtime_arg_values=self.get_runtime_arg_values(**kwargs),
+            ) as kernel,
+        ):
             code = self.render(kernel=kernel, **kwargs)
             _, call_args, _, _ = kernel.args.python_argdefs()
             log.debug("Autotune key: %s, Generated Code:\n%s", kernel_hash_name, code)
@@ -638,7 +638,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
                 continue
 
             while current_group < len(remaining) and sv.statically_known_equals(
-                remaining[current_group], 1  # type: ignore[arg-type]
+                remaining[current_group],
+                1,  # type: ignore[arg-type]
             ):
                 # scroll to next group with remaining elements
                 current_group += 1
@@ -666,9 +667,9 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
             )
             return_getters_groups.append(return_getters)
 
-        assert all(
-            V.graph.sizevars.size_hint(s) == 1 for s in remaining
-        ), f"failed to set ranges {remaining} {lengths}"
+        assert all(V.graph.sizevars.size_hint(s) == 1 for s in remaining), (
+            f"failed to set ranges {remaining} {lengths}"
+        )
 
         return new_ranges, return_getters_groups
 
@ -836,7 +837,8 @@ class SIMDKernel(Kernel[CSEVariableType], Generic[CSEVariableType]):
|
||||||
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
|
replacements[ps] = V.graph.sizevars.lookup_precomputed_size(ps)
|
||||||
if len(replacements) > 0:
|
if len(replacements) > 0:
|
||||||
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
|
self.range_tree_nodes[sym].expr = sympy_subs( # type: ignore[index]
|
||||||
self.range_tree_nodes[sym].expr, replacements # type: ignore[index]
|
self.range_tree_nodes[sym].expr,
|
||||||
|
replacements, # type: ignore[index]
|
||||||
)
|
)
|
||||||
self.range_tree_nodes[sym].codegen() # type: ignore[index]
|
self.range_tree_nodes[sym].codegen() # type: ignore[index]
|
||||||
return expr
|
return expr
|
||||||
|
|
@ -2071,9 +2073,10 @@ class SIMDScheduling(BaseScheduling):
|
||||||
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
|
features=SIMDKernelFeatures(node_schedule, numel, rnumel),
|
||||||
)
|
)
|
||||||
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
|
self.codegen_node_schedule_with_kernel(node_schedule, kernel)
|
||||||
with config.patch(
|
with (
|
||||||
"benchmark_kernel", benchmark_kernel
|
config.patch("benchmark_kernel", benchmark_kernel),
|
||||||
), V.set_kernel_handler(kernel):
|
V.set_kernel_handler(kernel),
|
||||||
|
):
|
||||||
src_code = kernel.codegen_kernel()
|
src_code = kernel.codegen_kernel()
|
||||||
else:
|
else:
|
||||||
prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(
|
prologue, template, epilogue = nodes[0].get_prologue_template_epilogue(
|
||||||
|
|
|
||||||
|
|
@ -1579,9 +1579,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
self.block_ptr_id = itertools.count()
|
self.block_ptr_id = itertools.count()
|
||||||
self.block_ptr_to_buffer = dict[str, str]()
|
self.block_ptr_to_buffer = dict[str, str]()
|
||||||
self.helper_functions = HelperFunctions()
|
self.helper_functions = HelperFunctions()
|
||||||
self.pointer_advancements: dict[
|
self.pointer_advancements: dict[SymT, dict[str, list[sympy.Expr]]] = (
|
||||||
SymT, dict[str, list[sympy.Expr]]
|
collections.defaultdict(dict)
|
||||||
] = collections.defaultdict(dict)
|
)
|
||||||
self._load_counts: collections.Counter[str] = collections.Counter()
|
self._load_counts: collections.Counter[str] = collections.Counter()
|
||||||
|
|
||||||
# A set of autotuning hints to pass as part of triton_meta
|
# A set of autotuning hints to pass as part of triton_meta
|
||||||
|
|
@ -2053,9 +2053,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
continue
|
continue
|
||||||
|
|
||||||
advancements = self.pointer_advancements[symt]
|
advancements = self.pointer_advancements[symt]
|
||||||
assert (
|
assert block_ptr not in advancements, (
|
||||||
block_ptr not in advancements
|
"duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
|
||||||
), "duplicate advancement for pointer '{block_ptr}' at type '{symt}'"
|
)
|
||||||
advancements[block_ptr] = advance_offsets
|
advancements[block_ptr] = advance_offsets
|
||||||
else:
|
else:
|
||||||
block_ptr = indexing.format(var)
|
block_ptr = indexing.format(var)
|
||||||
|
|
@ -2476,7 +2476,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
buffer.splice(
|
buffer.splice(
|
||||||
f"""\
|
f"""\
|
||||||
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
|
{result_var}_val, {result_var}_idx = triton_helpers.{root_op}_with_index({value}, {index}, {dim})
|
||||||
{result_var} = {self.reduction_resize(f'{result_var}_idx')}
|
{result_var} = {self.reduction_resize(f"{result_var}_idx")}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -2576,8 +2576,8 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
|
{accumulator}_next, {accumulator_index}_next = triton_helpers.{root_op}imum_with_index(
|
||||||
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
|
{accumulator}, {accumulator_index}, {value}, {reduction_range_prefix}index
|
||||||
)
|
)
|
||||||
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
|
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
|
||||||
{accumulator_index} = {where_cond(f'{accumulator_index}_next', accumulator_index)}
|
{accumulator_index} = {where_cond(f"{accumulator_index}_next", accumulator_index)}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
final_argreduce(
|
final_argreduce(
|
||||||
|
|
@ -2751,9 +2751,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
)
|
)
|
||||||
self.compute.splice(
|
self.compute.splice(
|
||||||
f"""\
|
f"""\
|
||||||
{accumulator} = {where_cond(f'{accumulator}_next', accumulator)}
|
{accumulator} = {where_cond(f"{accumulator}_next", accumulator)}
|
||||||
{accumulator_m2} = {where_cond(f'{accumulator_m2}_next', accumulator_m2)}
|
{accumulator_m2} = {where_cond(f"{accumulator_m2}_next", accumulator_m2)}
|
||||||
{accumulator_weight} = {where_cond(f'{accumulator_weight}_next', accumulator_weight)}
|
{accumulator_weight} = {where_cond(f"{accumulator_weight}_next", accumulator_weight)}
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
result_mean = result_var
|
result_mean = result_var
|
||||||
|
|
@ -3040,9 +3040,9 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
self.filter_masks(masks)
|
self.filter_masks(masks)
|
||||||
masks = sorted(masks)
|
masks = sorted(masks)
|
||||||
assert not self._load_mask, "ops.sort not supported inside ops.masked"
|
assert not self._load_mask, "ops.sort not supported inside ops.masked"
|
||||||
assert (
|
assert self.persistent_reduction, (
|
||||||
self.persistent_reduction
|
"ops.sort is only supported in persistent reductions"
|
||||||
), "ops.sort is only supported in persistent reductions"
|
)
|
||||||
|
|
||||||
cse_compute = functools.partial(self.cse.generate, self.compute)
|
cse_compute = functools.partial(self.cse.generate, self.compute)
|
||||||
dim = self.triton_tensor_ndim() - self.num_reduction_dims
|
dim = self.triton_tensor_ndim() - self.num_reduction_dims
|
||||||
|
|
@ -3302,9 +3302,7 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
{}
|
{}
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
|
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid
|
||||||
""".format(
|
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
|
||||||
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def _get_heuristic(self):
|
def _get_heuristic(self):
|
||||||
|
|
@ -3344,19 +3342,19 @@ class TritonKernel(SIMDKernel[TritonCSEVariable]):
|
||||||
inductor_meta["profile_bandwidth"] = config.profile_bandwidth
|
inductor_meta["profile_bandwidth"] = config.profile_bandwidth
|
||||||
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
|
inductor_meta["profile_bandwidth_regex"] = config.profile_bandwidth_regex
|
||||||
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
|
inductor_meta["profile_bandwidth_output"] = config.profile_bandwidth_output
|
||||||
inductor_meta[
|
inductor_meta["profile_bandwidth_with_do_bench_using_profiling"] = (
|
||||||
"profile_bandwidth_with_do_bench_using_profiling"
|
config.profile_bandwidth_with_do_bench_using_profiling
|
||||||
] = config.profile_bandwidth_with_do_bench_using_profiling
|
)
|
||||||
if config.coordinate_descent_tuning:
|
if config.coordinate_descent_tuning:
|
||||||
inductor_meta[
|
inductor_meta["coordinate_descent_tuning"] = (
|
||||||
"coordinate_descent_tuning"
|
config.coordinate_descent_tuning
|
||||||
] = config.coordinate_descent_tuning
|
)
|
||||||
inductor_meta[
|
inductor_meta["coordinate_descent_search_radius"] = (
|
||||||
"coordinate_descent_search_radius"
|
config.coordinate_descent_search_radius
|
||||||
] = config.coordinate_descent_search_radius
|
)
|
||||||
inductor_meta[
|
inductor_meta["coordinate_descent_check_all_directions"] = (
|
||||||
"coordinate_descent_check_all_directions"
|
config.coordinate_descent_check_all_directions
|
||||||
] = config.coordinate_descent_check_all_directions
|
)
|
||||||
return inductor_meta
|
return inductor_meta
|
||||||
|
|
||||||
def codegen_kernel(self, name=None):
|
def codegen_kernel(self, name=None):
|
||||||
|
|
@ -4046,9 +4044,10 @@ class TritonScheduling(SIMDScheduling):
|
||||||
) -> tuple[float, str]:
|
) -> tuple[float, str]:
|
||||||
"""Benchmark an already compiled module"""
|
"""Benchmark an already compiled module"""
|
||||||
device_interface = get_interface_for_device(V.graph.device_type)
|
device_interface = get_interface_for_device(V.graph.device_type)
|
||||||
with preserve_rng_state(), device_interface.device(
|
with (
|
||||||
V.graph.get_current_device_or_throw()
|
preserve_rng_state(),
|
||||||
): # type: ignore[attr-defined]
|
device_interface.device(V.graph.get_current_device_or_throw()), # type: ignore[attr-defined]
|
||||||
|
):
|
||||||
ms = None
|
ms = None
|
||||||
|
|
||||||
def cache_file_path():
|
def cache_file_path():
|
||||||
|
|
@ -4322,9 +4321,9 @@ def debug_triton_code(node: BaseSchedulerNode) -> list[str]:
|
||||||
device = node.get_device()
|
device = node.get_device()
|
||||||
assert device is not None
|
assert device is not None
|
||||||
backend = node.scheduler.get_backend(device)
|
backend = node.scheduler.get_backend(device)
|
||||||
assert isinstance(
|
assert isinstance(backend, (SIMDScheduling, CUDACombinedScheduling)), (
|
||||||
backend, (SIMDScheduling, CUDACombinedScheduling)
|
f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
|
||||||
), f"Scheduling backend should be SIMD or CUDACombined when generating debug Triton strings, got: {type(backend)}"
|
)
|
||||||
|
|
||||||
with V.graph.set_current_device(device):
|
with V.graph.set_current_device(device):
|
||||||
# Don't increment kernel count when generating debug string.
|
# Don't increment kernel count when generating debug string.
|
||||||
|
|
|
||||||
|
|
@ -86,7 +86,9 @@ def _default_custom_combo_kernel_horizontal_partition(
|
||||||
# rnumel > 2048 usually has long execution time
|
# rnumel > 2048 usually has long execution time
|
||||||
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
|
# BaseSchedulerNode.group[-1][-1] is rnumel for reduction nodes
|
||||||
long_reduction = [
|
long_reduction = [
|
||||||
n for n in reduction if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
n
|
||||||
|
for n in reduction
|
||||||
|
if V.graph.sizevars.size_hint(n.group[-1][-1]) > 2048 # type: ignore[arg-type]
|
||||||
]
|
]
|
||||||
short_reduction = [n for n in reduction if n not in long_reduction]
|
short_reduction = [n for n in reduction if n not in long_reduction]
|
||||||
if long_reduction:
|
if long_reduction:
|
||||||
|
|
@ -138,7 +140,7 @@ def set_custom_combo_kernel_horizontal_partition(
|
||||||
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
|
dict[BaseSchedulerNode, tuple[Any, Any, Any, Any]],
|
||||||
],
|
],
|
||||||
list[list[BaseSchedulerNode]],
|
list[list[BaseSchedulerNode]],
|
||||||
]
|
],
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
|
"""Sets the algorithm used to partition nodes into horizontal partitions. Nodes in different partitions
|
||||||
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
|
are implemented in different combo kernels. Nodes in the same partition are likely to be implemented
|
||||||
|
|
@ -593,9 +595,9 @@ class ComboKernel(Kernel):
|
||||||
num_persistent_reduction = len(
|
num_persistent_reduction = len(
|
||||||
[e for e in heuristics_list if e == "persistent_reduction"]
|
[e for e in heuristics_list if e == "persistent_reduction"]
|
||||||
)
|
)
|
||||||
assert (
|
assert num_reduction == 0, (
|
||||||
num_reduction == 0
|
"combining pointwise and reduction are not supported yet."
|
||||||
), "combining pointwise and reduction are not supported yet."
|
)
|
||||||
heuristics = (
|
heuristics = (
|
||||||
"pointwise_with_reduction"
|
"pointwise_with_reduction"
|
||||||
if num_persistent_reduction > 0
|
if num_persistent_reduction > 0
|
||||||
|
|
@ -784,13 +786,13 @@ class ComboKernel(Kernel):
|
||||||
name, tree, suffix=str(num)
|
name, tree, suffix=str(num)
|
||||||
)
|
)
|
||||||
if not tree.is_reduction:
|
if not tree.is_reduction:
|
||||||
assert isinstance(
|
assert isinstance(grid[i][num], str), (
|
||||||
grid[i][num], str
|
f"Grid {grid[i][num]} should be a dynamic shape."
|
||||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
)
|
||||||
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
||||||
assert (
|
assert grid[i][num] == numel_sign + numel_name, (
|
||||||
grid[i][num] == numel_sign + numel_name
|
f"numel args mismatch: {grid[i][num]} vs {numel_name}"
|
||||||
), f"numel args mismatch: {grid[i][num]} vs {numel_name}"
|
)
|
||||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||||
|
|
||||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||||
|
|
@ -807,13 +809,13 @@ class ComboKernel(Kernel):
|
||||||
continue
|
continue
|
||||||
expr = V.graph.sizevars.size_hint(tree.numel)
|
expr = V.graph.sizevars.size_hint(tree.numel)
|
||||||
if not tree.is_reduction:
|
if not tree.is_reduction:
|
||||||
assert isinstance(
|
assert isinstance(grid[i][num], str), (
|
||||||
grid[i][num], str
|
f"Grid {grid[i][num]} should be a dynamic shape."
|
||||||
), f"Grid {grid[i][num]} should be a dynamic shape."
|
)
|
||||||
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
numel_sign = grid[i][num][0] if grid[i][num][0] == "-" else ""
|
||||||
assert (
|
assert grid[i][num] == numel_sign + numel_name, (
|
||||||
grid[i][num] == numel_sign + numel_name
|
f"grid mismatch: {grid[i][num]} vs {numel_name}"
|
||||||
), f"grid mismatch: {grid[i][num]} vs {numel_name}"
|
)
|
||||||
grid[i][num] = -expr if numel_sign == "-" else expr
|
grid[i][num] = -expr if numel_sign == "-" else expr
|
||||||
if not tree.is_reduction or sub_kernel.inside_reduction:
|
if not tree.is_reduction or sub_kernel.inside_reduction:
|
||||||
extra_args.append(expr)
|
extra_args.append(expr)
|
||||||
|
|
@ -1015,9 +1017,7 @@ class ComboKernel(Kernel):
|
||||||
{}
|
{}
|
||||||
import torch
|
import torch
|
||||||
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
|
from torch._inductor.runtime.triton_heuristics import grid, split_scan_grid, grid_combo_kernels
|
||||||
""".format(
|
""".format(V.graph.device_ops.import_get_raw_stream_as("get_raw_stream"))
|
||||||
V.graph.device_ops.import_get_raw_stream_as("get_raw_stream")
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
|
||||||
def uniquify_block_sizes(
|
def uniquify_block_sizes(
|
||||||
|
|
|
||||||
|
|
@@ -57,9 +57,9 @@ class TritonSplitScanKernel(TritonKernel):
 
     def initialize_range_tree(self, pid_cache):
         prefixes = ["y", "x", "r0_"]
-        assert len(self.numels) <= len(
-            prefixes
-        ), "z dimension not supported for split scan"
+        assert len(self.numels) <= len(prefixes), (
+            "z dimension not supported for split scan"
+        )
         active_prefixes = prefixes[len(prefixes) - len(self.numels) :]
 
         grid_dims = {"r0_": 0, "x": 1, "y": 2}
@ -184,7 +184,8 @@ def config_of(
|
||||||
if isinstance(x, TensorArg):
|
if isinstance(x, TensorArg):
|
||||||
if include_tensor:
|
if include_tensor:
|
||||||
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
|
offset_aligned = V.graph.sizevars.statically_known_multiple_of(
|
||||||
x.offset * x.dtype.itemsize, alignment # type: ignore[arg-type]
|
x.offset * x.dtype.itemsize,
|
||||||
|
alignment, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
return offset_aligned and not is_unaligned_buffer(x)
|
return offset_aligned and not is_unaligned_buffer(x)
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,7 @@ def can_match_buffer_size(input_buf: BufferLike, output_buf: BufferLike):
|
||||||
# NB: this is symbolic so that we don't try to reuse a buffer
|
# NB: this is symbolic so that we don't try to reuse a buffer
|
||||||
# for s0 for s1, just because they happen to share the same
|
# for s0 for s1, just because they happen to share the same
|
||||||
# size hint
|
# size hint
|
||||||
sympy_str(input_size)
|
sympy_str(input_size) == sympy_str(output_size)
|
||||||
== sympy_str(output_size)
|
|
||||||
) or (
|
) or (
|
||||||
# statically known that 0.95 * input_size <= output_size <= input_size
|
# statically known that 0.95 * input_size <= output_size <= input_size
|
||||||
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
|
V.graph.sizevars.statically_known_geq(output_size, 0.95 * input_size)
|
||||||
|
|
@ -138,9 +137,9 @@ def convert_arg_type(arg: torch.Argument) -> str:
|
||||||
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
|
container_match = re.findall(py_container + r"\[([a-zA-Z_]+)]", python_type)
|
||||||
if len(container_match) == 1:
|
if len(container_match) == 1:
|
||||||
contained_type = container_match[0]
|
contained_type = container_match[0]
|
||||||
assert (
|
assert contained_type in PYTHON_TO_CPP, (
|
||||||
contained_type in PYTHON_TO_CPP
|
f"unsupported {py_container} type in convert_arg_type: {contained_type}"
|
||||||
), f"unsupported {py_container} type in convert_arg_type: {contained_type}"
|
)
|
||||||
cpp_contained_type = PYTHON_TO_CPP[contained_type]
|
cpp_contained_type = PYTHON_TO_CPP[contained_type]
|
||||||
return f"{cpp_container}<{cpp_contained_type}>"
|
return f"{cpp_container}<{cpp_contained_type}>"
|
||||||
|
|
||||||
|
|
@ -367,9 +366,9 @@ class SymbolicCallArg:
|
||||||
class MemoryPlanningState:
|
class MemoryPlanningState:
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.reuse_pool: dict[
|
self.reuse_pool: dict[ReuseKey, list[FreeIfNotReusedLine]] = (
|
||||||
ReuseKey, list[FreeIfNotReusedLine]
|
collections.defaultdict(list)
|
||||||
] = collections.defaultdict(list)
|
)
|
||||||
self.total_allocated_buffer_size: int = 0
|
self.total_allocated_buffer_size: int = 0
|
||||||
|
|
||||||
def __contains__(self, key: ReuseKey) -> bool:
|
def __contains__(self, key: ReuseKey) -> bool:
|
||||||
|
|
@ -431,9 +430,9 @@ class EnterDeviceContextManagerLine(WrapperLine):
|
||||||
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
|
f"{V.graph.device_ops.cpp_aoti_stream_guard()} stream_guard(stream, this->device_idx_);"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
assert (
|
assert self.last_seen_device_guard_index == self.device_idx, (
|
||||||
self.last_seen_device_guard_index == self.device_idx
|
"AOTInductor only supports running on one CUDA device"
|
||||||
), "AOTInductor only supports running on one CUDA device"
|
)
|
||||||
else:
|
else:
|
||||||
if self.last_seen_device_guard_index is None:
|
if self.last_seen_device_guard_index is None:
|
||||||
code.writeline(
|
code.writeline(
|
||||||
|
|
@ -1794,7 +1793,8 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
equals_1 = isinstance(
|
equals_1 = isinstance(
|
||||||
arg, (int, sympy.Integer)
|
arg, (int, sympy.Integer)
|
||||||
) and V.graph.sizevars.statically_known_equals(
|
) and V.graph.sizevars.statically_known_equals(
|
||||||
arg, 1 # type: ignore[arg-type]
|
arg,
|
||||||
|
1, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
|
add_arg(idx, SizeArg(key, arg), equals_1=equals_1)
|
||||||
|
|
||||||
|
|
@@ -2052,9 +2052,9 @@ class PythonWrapperCodegen(CodeGen):
                 buf_name = arg
                 buf = V.graph.get_buffer(arg)
             else:
-                assert (
-                    raw_arg is not None
-                ), "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
+                assert raw_arg is not None, (
+                    "V.graph.get_buffer(arg) and raw_arg can't be None at the same time"
+                )
                 buf_name = f"tmp_arg_{index}"
                 buf = raw_arg
 
@ -2181,9 +2181,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
and kernel_name not in self.kernel_autotune_names
|
and kernel_name not in self.kernel_autotune_names
|
||||||
):
|
):
|
||||||
# Create example args for autotune in a separate epilogue
|
# Create example args for autotune in a separate epilogue
|
||||||
assert arg_types is not None and len(call_args) == len(
|
assert arg_types is not None and len(call_args) == len(arg_types), (
|
||||||
arg_types
|
"call_args and arg_types do not match"
|
||||||
), "call_args and arg_types do not match"
|
)
|
||||||
|
|
||||||
tensor_args = {}
|
tensor_args = {}
|
||||||
all_args = []
|
all_args = []
|
||||||
|
|
@ -2191,9 +2191,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
# create a dummy raw_args for uniform behavior in the following loop
|
# create a dummy raw_args for uniform behavior in the following loop
|
||||||
raw_args = [None] * len(call_args)
|
raw_args = [None] * len(call_args)
|
||||||
else:
|
else:
|
||||||
assert len(raw_args) == len(
|
assert len(raw_args) == len(call_args), (
|
||||||
call_args
|
"call_args and raw_args do not match"
|
||||||
), "call_args and raw_args do not match"
|
)
|
||||||
|
|
||||||
for i, (arg, arg_type, raw_arg) in enumerate(
|
for i, (arg, arg_type, raw_arg) in enumerate(
|
||||||
zip(call_args, arg_types, raw_args)
|
zip(call_args, arg_types, raw_args)
|
||||||
|
|
@ -2411,9 +2411,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
if isinstance(layout, ir.NoneLayout):
|
if isinstance(layout, ir.NoneLayout):
|
||||||
return
|
return
|
||||||
if isinstance(layout, ir.NonOwningLayout):
|
if isinstance(layout, ir.NonOwningLayout):
|
||||||
assert isinstance(
|
assert isinstance(layout.view, ir.ReinterpretView), (
|
||||||
layout.view, ir.ReinterpretView
|
f"unexpected {type(layout.view)}: {layout.view}"
|
||||||
), f"unexpected {type(layout.view)}: {layout.view}"
|
)
|
||||||
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
|
assert isinstance(layout.view.data, ir.StorageBox), type(layout.view.data)
|
||||||
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
|
assert isinstance(layout.view.data.data, ir.Buffer), type(layout.view.data)
|
||||||
self.codegen_allocation(layout.view.data.data)
|
self.codegen_allocation(layout.view.data.data)
|
||||||
|
|
@ -2535,9 +2535,9 @@ class PythonWrapperCodegen(CodeGen):
|
||||||
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
|
def codegen_subgraph_prefix(self, subgraph, outer_inputs, outer_outputs):
|
||||||
# All inputs of hops must be explicitly passed in.
|
# All inputs of hops must be explicitly passed in.
|
||||||
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
|
# Free tensors and basic symbols should have been explicitly lifted as inputs in dynamo.
|
||||||
assert len(outer_inputs) == len(
|
assert len(outer_inputs) == len(subgraph.graph.graph_input_names), (
|
||||||
subgraph.graph.graph_input_names
|
f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
|
||||||
), f"graph_input_names:{subgraph.graph.graph_input_names}, outer_inputs: {outer_inputs}"
|
)
|
||||||
for inner_input, outer_input in zip(
|
for inner_input, outer_input in zip(
|
||||||
subgraph.graph.graph_input_names, outer_inputs
|
subgraph.graph.graph_input_names, outer_inputs
|
||||||
):
|
):
|
||||||
|
|
|
||||||
|
|
@@ -219,8 +219,7 @@ def _schedule_for_comm(
 
     for snode, deps in unmet_deps.items():
         assert len(deps) == 0, (
-            "Detected unscheduled nodes. "
-            f"Nodes with unmet dependencies: {unmet_deps}"
+            f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
         )
     return scheduled
 
@ -354,9 +353,7 @@ def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
|
||||||
node.op == "call_function"
|
node.op == "call_function"
|
||||||
and node.target == torch.ops.inductor.resize_storage_bytes_.default
|
and node.target == torch.ops.inductor.resize_storage_bytes_.default
|
||||||
):
|
):
|
||||||
assert (
|
assert node.args[0].op == "placeholder", f"""\
|
||||||
node.args[0].op == "placeholder"
|
|
||||||
), f"""\
|
|
||||||
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
|
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
|
||||||
"""
|
"""
|
||||||
graph_input = node.args[0]
|
graph_input = node.args[0]
|
||||||
|
|
@ -408,9 +405,7 @@ Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that
|
||||||
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
|
if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
|
||||||
fsdp_copy_node = node
|
fsdp_copy_node = node
|
||||||
unsharded_param = node.args[0]
|
unsharded_param = node.args[0]
|
||||||
assert (
|
assert unsharded_param.op == "placeholder", f"""
|
||||||
unsharded_param.op == "placeholder"
|
|
||||||
), f"""
|
|
||||||
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
|
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
|
||||||
Offending node: {unsharded_param}. Graph: {graph}
|
Offending node: {unsharded_param}. Graph: {graph}
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
|
|
@ -281,9 +281,9 @@ def _unlift_graph(
|
||||||
elif node_name in graph_signature.inputs_to_buffers:
|
elif node_name in graph_signature.inputs_to_buffers:
|
||||||
buffer_name = graph_signature.inputs_to_buffers[node_name]
|
buffer_name = graph_signature.inputs_to_buffers[node_name]
|
||||||
lifted_inputs.append(buffer_name)
|
lifted_inputs.append(buffer_name)
|
||||||
gm.meta[
|
gm.meta[get_cloned_parameter_buffer_name(buffer_name)] = (
|
||||||
get_cloned_parameter_buffer_name(buffer_name)
|
clone_preserve_strides(state_dict[buffer_name])
|
||||||
] = clone_preserve_strides(state_dict[buffer_name])
|
)
|
||||||
else:
|
else:
|
||||||
assert node_name in graph_signature.user_inputs
|
assert node_name in graph_signature.user_inputs
|
||||||
lifted_inputs.append(None)
|
lifted_inputs.append(None)
|
||||||
|
|
@@ -542,7 +542,7 @@ def fake_tensor_prop(
 
 # pass config dict back to user
 def get_patched_config_dict(
-    config_patches: Optional[Union[str, dict[str, Any]]] = None
+    config_patches: Optional[Union[str, dict[str, Any]]] = None,
 ) -> dict[str, Any]:
     with config.patch(config_patches):
         return config.get_config_copy()
@@ -579,8 +579,7 @@ class _CompileFxCallable(Protocol):
         gm: GraphModule,
         example_inputs: Sequence[InputType],
         **kwargs: Unpack[_CompileFxKwargs],
-    ) -> OutputCode:
-        ...
+    ) -> OutputCode: ...
 
 
 def compile_fx_inner(
@ -662,9 +661,9 @@ def _compile_fx_inner(
|
||||||
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
|
static_inputs_log.debug("static input idxs compile_fx_inner: %s", static_input_idxs)
|
||||||
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
|
inputs_to_check = get_input_idxs_to_check(example_inputs, static_input_idxs)
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)), (
|
||||||
next(iter(reversed(gm.graph.nodes))).args[0], (tuple, list)
|
f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
|
||||||
), f"inductor can only compile FX graphs which return a tuple/list, but got {gm.graph}"
|
)
|
||||||
|
|
||||||
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
|
if (cudagraphs := graph_kwargs.get("cudagraphs")) is None:
|
||||||
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
|
graph_kwargs["cudagraphs"] = cudagraphs = BoxedBool(config.triton.cudagraphs)
|
||||||
|
|
@ -679,9 +678,10 @@ def _compile_fx_inner(
|
||||||
|
|
||||||
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
|
fx_graph_remote_cache = should_use_remote_fx_graph_cache()
|
||||||
|
|
||||||
with _WaitCounter(
|
with (
|
||||||
"pytorch.wait_counter.fx_codegen_and_compile"
|
_WaitCounter("pytorch.wait_counter.fx_codegen_and_compile").guard() as _,
|
||||||
).guard() as _, _WaitCounter("pytorch.wait_counter.all_compilation_types").guard():
|
_WaitCounter("pytorch.wait_counter.all_compilation_types").guard(),
|
||||||
|
):
|
||||||
use_cache = (
|
use_cache = (
|
||||||
not config.force_disable_caches
|
not config.force_disable_caches
|
||||||
and (config.fx_graph_cache or fx_graph_remote_cache)
|
and (config.fx_graph_cache or fx_graph_remote_cache)
|
||||||
|
|
@ -865,8 +865,7 @@ class FxCompile(ABC):
|
||||||
example_inputs: Sequence[InputType],
|
example_inputs: Sequence[InputType],
|
||||||
inputs_to_check: Sequence[int],
|
inputs_to_check: Sequence[int],
|
||||||
graph_kwargs: _CompileFxKwargs,
|
graph_kwargs: _CompileFxKwargs,
|
||||||
) -> OutputCode:
|
) -> OutputCode: ...
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
class _InProcessFxCompile(FxCompile):
|
class _InProcessFxCompile(FxCompile):
|
||||||
|
|
@ -890,16 +889,17 @@ class _InProcessFxCompile(FxCompile):
|
||||||
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
|
cpp_wrapper: bool = graph_kwargs.get("cpp_wrapper", False)
|
||||||
aot_mode: bool = V.aot_compilation
|
aot_mode: bool = V.aot_compilation
|
||||||
is_inference: bool = graph_kwargs.get("is_inference", False)
|
is_inference: bool = graph_kwargs.get("is_inference", False)
|
||||||
extern_node_serializer: Optional[
|
extern_node_serializer: Optional[Callable[[list[ExternKernelNode]], Any]] = (
|
||||||
Callable[[list[ExternKernelNode]], Any]
|
graph_kwargs.get("extern_node_serializer", None)
|
||||||
] = graph_kwargs.get("extern_node_serializer", None)
|
)
|
||||||
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
|
boxed_forward_device_index: Optional[BoxedDeviceIndex] = graph_kwargs.get(
|
||||||
"boxed_forward_device_index", None
|
"boxed_forward_device_index", None
|
||||||
)
|
)
|
||||||
|
|
||||||
with _WaitCounter(
|
with (
|
||||||
"pytorch.wait_counter.actual_codegen_and_compile"
|
_WaitCounter("pytorch.wait_counter.actual_codegen_and_compile").guard(),
|
||||||
).guard(), dynamo_utils.preserve_rng_state():
|
dynamo_utils.preserve_rng_state(),
|
||||||
|
):
|
||||||
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
|
if (sleep_sec := config.sleep_sec_TESTING_ONLY) is not None:
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
|
@ -1038,9 +1038,11 @@ class _InProcessFxCompile(FxCompile):
|
||||||
# See details in vllm/compilation/pass_manager.py.
|
# See details in vllm/compilation/pass_manager.py.
|
||||||
log.warning("failed to log pt2_configs")
|
log.warning("failed to log pt2_configs")
|
||||||
|
|
||||||
with V.set_fake_mode(fake_mode), maybe_disable_comprehensive_padding(
|
with (
|
||||||
example_inputs
|
V.set_fake_mode(fake_mode),
|
||||||
), maybe_disable_graph_partition(cpp_wrapper, aot_mode):
|
maybe_disable_comprehensive_padding(example_inputs),
|
||||||
|
maybe_disable_graph_partition(cpp_wrapper, aot_mode),
|
||||||
|
):
|
||||||
const_output_index = None
|
const_output_index = None
|
||||||
const_graph = None
|
const_graph = None
|
||||||
const_code = None
|
const_code = None
|
||||||
|
|
@ -1123,9 +1125,9 @@ class _InProcessFxCompile(FxCompile):
|
||||||
if graph.aot_mode:
|
if graph.aot_mode:
|
||||||
from .codecache import AotCodeCompiler
|
from .codecache import AotCodeCompiler
|
||||||
|
|
||||||
assert (
|
assert graph.cpp_wrapper, (
|
||||||
graph.cpp_wrapper
|
"AOT mode only supports C++ wrapper"
|
||||||
), "AOT mode only supports C++ wrapper"
|
)
|
||||||
code, linemap = graph.codegen_with_cpp_wrapper()
|
code, linemap = graph.codegen_with_cpp_wrapper()
|
||||||
output_code_log.debug("Output code: \n%s", code)
|
output_code_log.debug("Output code: \n%s", code)
|
||||||
|
|
||||||
|
|
@ -1509,10 +1511,13 @@ def cudagraphify(
|
||||||
def run(new_inputs: Sequence[InputType]) -> Any:
|
def run(new_inputs: Sequence[InputType]) -> Any:
|
||||||
nonlocal compiled_fn
|
nonlocal compiled_fn
|
||||||
if compiled_fn is None:
|
if compiled_fn is None:
|
||||||
with dynamo_utils.dynamo_timed(
|
with (
|
||||||
"cudagraphify",
|
dynamo_utils.dynamo_timed(
|
||||||
log_pt2_compile_event=True,
|
"cudagraphify",
|
||||||
), dynamo_utils.preserve_rng_state():
|
log_pt2_compile_event=True,
|
||||||
|
),
|
||||||
|
dynamo_utils.preserve_rng_state(),
|
||||||
|
):
|
||||||
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
|
compiled_fn = cudagraphify_fn(model, new_inputs, static_input_idxs)
|
||||||
return compiled_fn(new_inputs)
|
return compiled_fn(new_inputs)
|
||||||
|
|
||||||
|
|
@ -1669,13 +1674,16 @@ def compile_fx_aot(
|
||||||
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
|
extern_node_serializer = config_patches.pop("extern_node_serializer", None)
|
||||||
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
|
saved_compile_id = model_.meta.get("dynamo_compile_id", None)
|
||||||
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
|
saved_compile_context = torch._guards.CompileContext(saved_compile_id)
|
||||||
with V.set_aot_compilation(True), torch._guards.compile_context(
|
with (
|
||||||
saved_compile_context
|
V.set_aot_compilation(True),
|
||||||
), chromium_event_timed(
|
torch._guards.compile_context(saved_compile_context),
|
||||||
"compile_fx_aot",
|
chromium_event_timed(
|
||||||
log_pt2_compile_event=True,
|
"compile_fx_aot",
|
||||||
reset_event_log_on_exit=True,
|
log_pt2_compile_event=True,
|
||||||
), get_metrics_context():
|
reset_event_log_on_exit=True,
|
||||||
|
),
|
||||||
|
get_metrics_context(),
|
||||||
|
):
|
||||||
compiled_artifacts = compile_fx(
|
compiled_artifacts = compile_fx(
|
||||||
model_,
|
model_,
|
||||||
example_inputs_,
|
example_inputs_,
|
||||||
|
|
@ -1875,12 +1883,15 @@ def compile_fx(
|
||||||
|
|
||||||
# TODO: This probably shouldn't be a recursive call
|
# TODO: This probably shouldn't be a recursive call
|
||||||
if config.cpp_wrapper:
|
if config.cpp_wrapper:
|
||||||
with config.patch(
|
with (
|
||||||
{
|
config.patch(
|
||||||
"cpp_wrapper": False, # reset to break recursive call to compile_fx
|
{
|
||||||
**get_cpp_wrapper_config(),
|
"cpp_wrapper": False, # reset to break recursive call to compile_fx
|
||||||
}
|
**get_cpp_wrapper_config(),
|
||||||
), V.set_real_inputs(example_inputs_):
|
}
|
||||||
|
),
|
||||||
|
V.set_real_inputs(example_inputs_),
|
||||||
|
):
|
||||||
inputs_: Sequence[InputType] = example_inputs_
|
inputs_: Sequence[InputType] = example_inputs_
|
||||||
|
|
||||||
if isinstance(model_, GraphModule):
|
if isinstance(model_, GraphModule):
|
||||||
|
|
@ -1940,10 +1951,10 @@ def compile_fx(
|
||||||
|
|
||||||
# Do the actual work
|
# Do the actual work
|
||||||
|
|
||||||
with _use_lazy_graph_module(
|
with (
|
||||||
dynamo_config.use_lazy_graph_module
|
_use_lazy_graph_module(dynamo_config.use_lazy_graph_module),
|
||||||
), enable_python_dispatcher(), torch.fx.traceback.preserve_node_meta(
|
enable_python_dispatcher(),
|
||||||
config.trace.enabled
|
torch.fx.traceback.preserve_node_meta(config.trace.enabled),
|
||||||
):
|
):
|
||||||
# Pre-grad passes cannot be run if we weren't given a GraphModule.
|
# Pre-grad passes cannot be run if we weren't given a GraphModule.
|
||||||
# Dynamo will always produce a GraphModule, but this handles cases
|
# Dynamo will always produce a GraphModule, but this handles cases
|
||||||
|
|
@ -2085,9 +2096,9 @@ def compile_fx(
|
||||||
boxed_forward_device_index=forward_device,
|
boxed_forward_device_index=forward_device,
|
||||||
)
|
)
|
||||||
|
|
||||||
fw_compiler: Callable[
|
fw_compiler: Callable[[GraphModule, Sequence[InputType]], OutputCode] = (
|
||||||
[GraphModule, Sequence[InputType]], OutputCode
|
functools.partial(fw_compiler_base, is_inference=False)
|
||||||
] = functools.partial(fw_compiler_base, is_inference=False)
|
)
|
||||||
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
|
fw_compiler = SerializableAOTDispatchCompiler(OutputCode, fw_compiler)
|
||||||
|
|
||||||
if config.freezing and not torch.is_grad_enabled():
|
if config.freezing and not torch.is_grad_enabled():
|
||||||
|
|
@ -2124,9 +2135,10 @@ def compile_fx(
|
||||||
) -> OutputCode:
|
) -> OutputCode:
|
||||||
from torch._dynamo.convert_frame import compile_lock
|
from torch._dynamo.convert_frame import compile_lock
|
||||||
|
|
||||||
with dynamo_utils.dynamo_timed(
|
with (
|
||||||
"compile_fx.<locals>.bw_compiler"
|
dynamo_utils.dynamo_timed("compile_fx.<locals>.bw_compiler"),
|
||||||
), compile_lock:
|
compile_lock,
|
||||||
|
):
|
||||||
model_outputs_node = output_node(gm)
|
model_outputs_node = output_node(gm)
|
||||||
if config.bw_outputs_user_visible:
|
if config.bw_outputs_user_visible:
|
||||||
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
|
model_outputs = pytree.arg_tree_leaves(*model_outputs_node.args)
|
||||||
|
|
@ -2194,10 +2206,11 @@ def compile_fx(
|
||||||
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
|
with V.set_fake_mode(fake_mode), compiled_autograd._disable(), context():
|
||||||
return inference_compiler(unlifted_gm, example_inputs_)
|
return inference_compiler(unlifted_gm, example_inputs_)
|
||||||
|
|
||||||
with V.set_fake_mode(fake_mode), torch._guards.tracing(
|
with (
|
||||||
tracing_context
|
V.set_fake_mode(fake_mode),
|
||||||
), compiled_autograd._disable(), functorch_config.patch(
|
torch._guards.tracing(tracing_context),
|
||||||
unlift_effect_tokens=True
|
compiled_autograd._disable(),
|
||||||
|
functorch_config.patch(unlift_effect_tokens=True),
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
return aot_autograd(
|
return aot_autograd(
|
||||||
|
|
|
||||||
|
|
@ -530,7 +530,8 @@ class CompilerBisector:
|
||||||
)
|
)
|
||||||
if result:
|
if result:
|
||||||
curr_subsystem = cls.get_subsystem_object(
|
curr_subsystem = cls.get_subsystem_object(
|
||||||
curr_backend, cls.get_subsystem() # type: ignore[arg-type]
|
curr_backend,
|
||||||
|
cls.get_subsystem(), # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(curr_subsystem, BinarySubsystem):
|
if isinstance(curr_subsystem, BinarySubsystem):
|
||||||
|
|
|
||||||
|
|
@@ -80,9 +80,9 @@ fx_graph_cache: bool = Config(
 fx_graph_remote_cache: Optional[bool] = fx_graph_remote_cache_default()
 
 # should we bundle triton caching into fx graph cache
-bundle_triton_into_fx_graph_cache: Optional[
-    bool
-] = bundle_triton_into_fx_graph_cache_default()
+bundle_triton_into_fx_graph_cache: Optional[bool] = (
+    bundle_triton_into_fx_graph_cache_default()
+)
 
 # Enable autotune local cache.
 #
@@ -1390,12 +1390,12 @@ class halide:
 
     # Halide autoscheduler to use, choices are:
     # "Anderson2021" (gpu-only), "Li2018", "Adams2019" (cpu-only), or "Mullapudi2016" (cpu-only)
-    scheduler_cuda: Literal[
-        "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
-    ] = "Anderson2021"
-    scheduler_cpu: Literal[
-        "Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"
-    ] = "Adams2019"
+    scheduler_cuda: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
+        "Anderson2021"
+    )
+    scheduler_cpu: Literal["Anderson2021", "Li2018", "Adams2019", "Mullapudi2016"] = (
+        "Adams2019"
+    )
 
     # Controls `no_asserts` flag passed to Halide target (warning: can false positive)
     asserts = False
@ -125,7 +125,8 @@ class ConstantFolder(torch.fx.Interpreter):
|
||||||
and is_woq_int8_pattern(next(iter(node.users)))
|
and is_woq_int8_pattern(next(iter(node.users)))
|
||||||
)
|
)
|
||||||
) and is_const_source(
|
) and is_const_source(
|
||||||
node.args[0], self.lifted_constant_names # type: ignore[arg-type]
|
node.args[0], # type: ignore[arg-type]
|
||||||
|
self.lifted_constant_names,
|
||||||
):
|
):
|
||||||
# Case 1: int8_weight -> dq -> bf16_weight
|
# Case 1: int8_weight -> dq -> bf16_weight
|
||||||
# Case 2: int8_weight -> permute -> dq -> bf16_weight
|
# Case 2: int8_weight -> permute -> dq -> bf16_weight
|
||||||
|
|
|
||||||
|
|
@@ -1633,8 +1633,8 @@ class CppBuilder:
         """
         )
 
-        assert os.path.exists(
-            cmake_path
-        ), f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
+        assert os.path.exists(cmake_path), (
+            f"save_link_cmd_to_cmakefile expects {cmake_path} to already exist"
+        )
         with open(cmake_path, "a") as f:
             f.write(contents)
@ -119,6 +119,7 @@ from . import config
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class GraphID:
|
class GraphID:
|
||||||
"Unique counter of a cuda graph recording"
|
"Unique counter of a cuda graph recording"
|
||||||
|
|
||||||
id: int
|
id: int
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -622,11 +623,15 @@ class CUDAWarmupNode:
|
||||||
refs = list(self.path_live_weakrefs())
|
refs = list(self.path_live_weakrefs())
|
||||||
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
|
check_memory_pool(self.device_index, self.cuda_graphs_pool, refs)
|
||||||
|
|
||||||
with torch.cuda.device(
|
with (
|
||||||
self.device_index
|
torch.cuda.device(self.device_index),
|
||||||
), disable_conv_cache_emptying(), clear_cublas_manager(), _use_cuda_memory_pool_manager(
|
disable_conv_cache_emptying(),
|
||||||
self.device_index, self.cuda_graphs_pool, self.stream
|
clear_cublas_manager(),
|
||||||
), get_history_recording():
|
_use_cuda_memory_pool_manager(
|
||||||
|
self.device_index, self.cuda_graphs_pool, self.stream
|
||||||
|
),
|
||||||
|
get_history_recording(),
|
||||||
|
):
|
||||||
out = self.wrapped_function.model(new_inputs)
|
out = self.wrapped_function.model(new_inputs)
|
||||||
|
|
||||||
# We need to know which outputs are allocated within the cudagraph pool
|
# We need to know which outputs are allocated within the cudagraph pool
|
||||||
|
|
@ -713,6 +718,7 @@ UnaliasedStorage = _UnaliasedStorage()
|
||||||
|
|
||||||
class AliasesPriorGraphOutput(OutputAliasInfo):
|
class AliasesPriorGraphOutput(OutputAliasInfo):
|
||||||
"Marks that the graph output aliases an output of a prior graph"
|
"Marks that the graph output aliases an output of a prior graph"
|
||||||
|
|
||||||
__slots__ = ["index"]
|
__slots__ = ["index"]
|
||||||
|
|
||||||
index: PathOutputIndex
|
index: PathOutputIndex
|
||||||
|
|
@ -1200,14 +1206,18 @@ class CUDAGraphNode:
|
||||||
]
|
]
|
||||||
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
|
check_memory_pool(self.device, self.cuda_graphs_pool, memory)
|
||||||
|
|
||||||
with preserve_rng_state(), torch.cuda.device(
|
with (
|
||||||
self.device
|
preserve_rng_state(),
|
||||||
), clear_cublas_manager(), torch.cuda.graph(
|
torch.cuda.device(self.device),
|
||||||
self.graph,
|
clear_cublas_manager(),
|
||||||
stream=self.stream,
|
torch.cuda.graph(
|
||||||
pool=self.cuda_graphs_pool,
|
self.graph,
|
||||||
capture_error_mode="thread_local",
|
stream=self.stream,
|
||||||
), get_history_recording():
|
pool=self.cuda_graphs_pool,
|
||||||
|
capture_error_mode="thread_local",
|
||||||
|
),
|
||||||
|
get_history_recording(),
|
||||||
|
):
|
||||||
static_outputs = model(inputs)
|
static_outputs = model(inputs)
|
||||||
|
|
||||||
# running model should reclaim memory
|
# running model should reclaim memory
|
||||||
|
|
@@ -1247,11 +1257,13 @@ class CUDAGraphNode:
                 self.output_storage_alias.append(UnaliasedStorage)
                 continue

-            torch._check(
-                o.is_cuda or o.untyped_storage().data_ptr() == 0,
-                lambda: (
-                    "Expected all cuda outputs in cuda graph recording. Non cuda output "
-                    f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
-                ),
-            )
+            (
+                torch._check(
+                    o.is_cuda or o.untyped_storage().data_ptr() == 0,
+                    lambda: (
+                        "Expected all cuda outputs in cuda graph recording. Non cuda output "
+                        f"from {self.stack_traces[i] if self.stack_traces else '(unknown)'}"
+                    ),
+                ),
+            )

@@ -1291,9 +1303,9 @@ class CUDAGraphNode:
         if self.stack_traces is None:
             self.stack_traces = [None for _ in range(len(outputs))]
         else:
-            assert len(self.stack_traces) == len(
-                outputs
-            ), "Wrong number of stack traces passed in"
+            assert len(self.stack_traces) == len(outputs), (
+                "Wrong number of stack traces passed in"
+            )

         assert not self.outputs_weakrefs
         for out, static_output_tensor in zip(outputs, self.static_output_tensors):
@@ -1599,12 +1611,14 @@ class CUDAGraphNode:
         self.stream.wait_stream(torch.cuda.current_stream())
         recording_inputs: list[InputType] = []

-        with warnings.catch_warnings(record=True), torch.cuda.device(
-            self.device
-        ), _use_cuda_memory_pool_manager(
-            self.device,
-            mem_pool=self.cuda_graphs_pool,
-            stream=self.stream,
-        ):
+        with (
+            warnings.catch_warnings(record=True),
+            torch.cuda.device(self.device),
+            _use_cuda_memory_pool_manager(
+                self.device,
+                mem_pool=self.cuda_graphs_pool,
+                stream=self.stream,
+            ),
+        ):
             for i, inp in enumerate(inputs):
                 if not isinstance(inp, torch.Tensor):
@@ -1736,12 +1750,8 @@ def check_memory_pool(
     pool_id: tuple[int, int],
     live_storages_ptrs: list[StorageWeakRefWrapper],
 ) -> None:
-    assert all(
-        isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs
-    )  # noqa: C419
-    unique_storages = {
-        stor.data_ptr() for stor in live_storages_ptrs if stor()
-    }  # noqa: set_linter
+    assert all(isinstance(elem, StorageWeakRefWrapper) for elem in live_storages_ptrs)  # noqa: C419
+    unique_storages = {stor.data_ptr() for stor in live_storages_ptrs if stor()}  # noqa: set_linter

     # check if there is a divergence first, then do the expensive snapshot call after
     # we know it will error
@@ -1864,11 +1874,14 @@ class CUDAGraphTreeManager:
         self.graph: Optional[torch.cuda.CUDAGraph] = torch.cuda.CUDAGraph()
         self.cuda_graphs_thread_pool = torch.cuda.graph_pool_handle()

-        with warnings.catch_warnings(record=True), torch.cuda.graph(
-            self.graph,
-            pool=self.cuda_graphs_thread_pool,
-            stream=self.stream,
-            capture_error_mode="thread_local",
-        ):
+        with (
+            warnings.catch_warnings(record=True),
+            torch.cuda.graph(
+                self.graph,
+                pool=self.cuda_graphs_thread_pool,
+                stream=self.stream,
+                capture_error_mode="thread_local",
+            ),
+        ):
             pass

@@ -2230,7 +2243,10 @@ class CUDAGraphTreeManager:
         constants: tuple[torch.Tensor, ...],
         placeholders: tuple[PlaceholderInfo, ...],
         mutated_input_idxs: tuple[int, ...],
-    ) -> tuple[ModelType, OutputType,]:
+    ) -> tuple[
+        ModelType,
+        OutputType,
+    ]:
         id = self.new_func_id()
         self.ids_to_stack_traces[id] = stack_traces
         self.ids_to_funcs[id] = WrappedFunction(
@@ -28,6 +28,7 @@ ModelType = Callable[[list[InputType]], OutputType]
 @dataclasses.dataclass(frozen=True)
 class FunctionID:
     "Unique counter of a function wrapped in cudagraphify_impl"
+
     id: int

@@ -164,7 +165,7 @@ def _get_use_stack_trace(node: torch.fx.Node) -> Optional[str]:


 def check_multiple_devices_or_any_cpu_nodes(
-    device_node_mapping: dict[torch.device, torch.fx.Node]
+    device_node_mapping: dict[torch.device, torch.fx.Node],
 ) -> Optional[str]:
     if cpu_node := device_node_mapping.get(torch.device("cpu")):
         msg = f"cpu device ({cpu_node.name})"
@@ -184,7 +185,7 @@ def check_multiple_devices_or_any_cpu_nodes(


 def check_lowering_disable_cudagraph(
-    device_node_mapping: dict[torch.device, torch.fx.Node]
+    device_node_mapping: dict[torch.device, torch.fx.Node],
 ) -> Optional[str]:
     return check_multiple_devices_or_any_cpu_nodes(device_node_mapping)

@@ -276,9 +277,9 @@ def log_data_ptr_mismatch(
     Logs the mismatch between input data pointers and recorded data pointers.
     This checks only idxs in target_idxs.
     """
-    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(
-        placeholders
-    ), "length mismatch between inputs, recorded_data_ptr, and placeholders"
+    assert len(inputs) == len(recorded_data_ptr) and len(inputs) == len(placeholders), (
+        "length mismatch between inputs, recorded_data_ptr, and placeholders"
+    )

     t_tensors = [inputs[i] for i in target_idxs]
     t_data_ptrs = [recorded_data_ptr[i] for i in target_idxs]
@@ -240,7 +240,7 @@ def update_orig_fx_node_name_to_buf_name(


 def get_node_name_to_buf_meta(
-    node_name_to_buf_name: dict[str, str]
+    node_name_to_buf_name: dict[str, str],
 ) -> dict[str, BufMeta]:
     buf_name_to_n_node = {}
     for node_name, buf_name in node_name_to_buf_name.items():
@@ -123,7 +123,7 @@ remove_decompositions(decompositions, decomps_to_exclude)


 def register_decomposition(
-    ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]]
+    ops: list[Union[torch._ops.OperatorBase, torch._ops.OpOverloadPacket]],
 ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     for op in [ops] if callable(ops) else ops:  # type: ignore[attr-defined]
         if op in decompositions:
@@ -194,7 +194,9 @@ class MemoryDep(Dep):
         )
         new_index = sympy_subs(sympy.expand(self.index), replacement)  # type: ignore[arg-type]  # next PR

-        out = MemoryDep(self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values()))  # type: ignore[arg-type]
+        out = MemoryDep(
+            self.name, new_index, tuple(var_ranges.keys()), tuple(var_ranges.values())
+        )  # type: ignore[arg-type]
         return out

     @property
@@ -649,11 +651,16 @@ def extract_loop_body_with_args(
         inner.load_seed(entry.buffer_name, int(name_to_index[entry.index_name]))  # type: ignore[arg-type]
     for entry in fn.memory_usage[MemoryUsageType.STORE]:
         inner.store(
-            entry.buffer_name, name_to_index[entry.index_name], None, entry.mode  # type: ignore[arg-type]
+            entry.buffer_name,
+            name_to_index[entry.index_name],
+            None,  # type: ignore[arg-type]
+            entry.mode,
         )
     for entry in fn.memory_usage[MemoryUsageType.STORE_REDUCTION]:
         inner.store_reduction(
-            entry.buffer_name, name_to_index[entry.index_name], None  # type: ignore[arg-type]
+            entry.buffer_name,
+            name_to_index[entry.index_name],
+            None,  # type: ignore[arg-type]
         )
     for entry in fn.memory_usage[MemoryUsageType.INDEX_EXPR]:
         inner.index_expr(name_to_index[entry.index_name], None)
@@ -661,7 +668,11 @@ def extract_loop_body_with_args(
         # All that matters is that we record the buffer name, so place it in the
         # "boundaries" name position to ensure that it's recorded.
         inner.bucketize(
-            None, (entry.buffer_name, None, None, None), None, None, None  # type: ignore[arg-type]
+            None,
+            (entry.buffer_name, None, None, None),
+            None,
+            None,  # type: ignore[arg-type]
+            None,  # type: ignore[arg-type]
         )
     # fn.memory_usage[MemoryUsageType.CHECK_BOUNDS] intentionally skipped
     return inner
@@ -801,8 +812,9 @@ def extract_free_unbacked_symbols(
     handler = FreeUnbackedSymbolsOpsHandler()
     # NB: I cargo culted the allow_indexing patch here, I don't understand why
     # people do this all over
-    with V.set_ops_handler(handler), patch.object(
-        FlexibleLayout, "allow_indexing", True
-    ):
+    with (
+        V.set_ops_handler(handler),
+        patch.object(FlexibleLayout, "allow_indexing", True),
+    ):
         fn(*args)
     return handler.symbols
@@ -19,8 +19,7 @@ T = TypeVar("T")

 class DTypeVar(Protocol):
     @property
-    def dtype(self) -> torch.dtype:
-        ...
+    def dtype(self) -> torch.dtype: ...


 DTypeArg = Union[DTypeVar, torch.types.Number, str, OpsValue]
@@ -526,6 +526,7 @@ class ConfigFuzzer:
    ```python
    import torch._inductor.config as cfg

+
    def create_simple_test_model_gpu() -> FactoryOutputType:
        batch_size = 32
        seq_length = 50
@@ -539,6 +540,8 @@ class ConfigFuzzer:
            return True

        return test_fn

+
    fuzzer = ConfigFuzzer(cfg, create_simple_test_model_gpu, seed=2)
+
    # Test every pair of configs:
@@ -550,7 +553,9 @@ class ConfigFuzzer:
    ret = fuzzer.bisect(num_attempts=10)

    # reproduce a failing config
-    fuzzer.reproduce([{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}])
+    fuzzer.reproduce(
+        [{"triton.autotune_pointwise": ..., "coordinate_descent_tuning": ...}]
+    )
    ```

    The list of known failures on inductor config are:
@@ -531,7 +531,11 @@ def tuned_b2b_gemm(
     A.realize()
     B.realize()
     C.realize()
-    layout = FixedLayout(A.get_device_or_error(), A.get_dtype(), [A.shape[0], C.shape[1]])  # type: ignore[index]
+    layout = FixedLayout(
+        A.get_device_or_error(),
+        A.get_dtype(),
+        [A.shape[0], C.shape[1]],  # type: ignore[index]
+    )
     subgraph_buffer = build_subgraph_buffer(
         [create_placeholder("inner_mm", A.get_dtype(), A.get_device_or_error())],
         subgraph,
@@ -545,9 +545,9 @@ def schedule_comm_wait(graph: fx.Graph) -> None:
     node_indices = {node: i for i, node in enumerate(graph.nodes)}
     for allreduce in comm_blocks:
         # Find the earliest/first user -- target_node.
-        assert (
-            len(allreduce.outputs) >= 1
-        ), f"Found a allreduce that has zero outputs/users -- {allreduce}."
+        assert len(allreduce.outputs) >= 1, (
+            f"Found a allreduce that has zero outputs/users -- {allreduce}."
+        )
         # Initialize the target node to avoid typing issues.
         target_node = next(iter(next(iter(allreduce.outputs)).users))
         target_node_index = 2**31
@@ -380,7 +380,9 @@ def efficient_conv_bn_eval_graph_transform(match: Match, *args, **kwargs):
     # argument. `graph.get_attr` and
     # `graph.call_function` does not allow the `name` argument.
     conv_get_node = graph.create_node(
-        op="get_attr", target=conv_node.target, name="get_conv"  # type: ignore[union-attr]
+        op="get_attr",
+        target=conv_node.target,  # type: ignore[union-attr]
+        name="get_conv",
     )
     bn_get_node = graph.create_node(
         op="get_attr", target=bn_node.target, name="get_bn"
@@ -866,15 +866,18 @@ def _get_sfdp_patterns():
             name += "_bs1"

         training_name = name + "_training"
-        yield training_name, {
-            "search_fn": pattern,
-            "replace_fn": replacement,
-            "example_inputs": args,
-            "trace_fn": joint_fwd_bwd,
-            "pass_dicts": patterns,
-            "extra_check": extra_check,
-            "scalar_workaround": workaround,
-        }
+        yield (
+            training_name,
+            {
+                "search_fn": pattern,
+                "replace_fn": replacement,
+                "example_inputs": args,
+                "trace_fn": joint_fwd_bwd,
+                "pass_dicts": patterns,
+                "extra_check": extra_check,
+                "scalar_workaround": workaround,
+            },
+        )

         if workaround:
             assert len(workaround) == 1 and "dropout_p" in workaround
@@ -886,18 +889,21 @@ def _get_sfdp_patterns():
             workaround = {}

         inference_name = name + "_inference"
-        yield inference_name, {
-            "search_fn": pattern,
-            "replace_fn": replacement,
-            "example_inputs": args,
-            "trace_fn": fwd_only,
-            "pass_dicts": patterns,
-            "extra_check": extra_check,
-            "scalar_workaround": workaround,
-            # with dropout turned into clone, we end up with a number of
-            # semantically identical graphs
-            "skip_duplicates": True,
-        }
+        yield (
+            inference_name,
+            {
+                "search_fn": pattern,
+                "replace_fn": replacement,
+                "example_inputs": args,
+                "trace_fn": fwd_only,
+                "pass_dicts": patterns,
+                "extra_check": extra_check,
+                "scalar_workaround": workaround,
+                # with dropout turned into clone, we end up with a number of
+                # semantically identical graphs
+                "skip_duplicates": True,
+            },
+        )


 @functools.lru_cache(None)
@@ -271,7 +271,9 @@ class PostGradBatchLinearFusion(BatchFusion):
                 args=(batch_biases[i],),
                 kwargs={"size": broadcast_shape},
             )
-            broadcast_bias.meta["val"] = aten.broadcast_to(batch_biases_meta[i]["val"], broadcast_shape)  # type: ignore[assignment]
+            broadcast_bias.meta["val"] = aten.broadcast_to(
+                batch_biases_meta[i]["val"], broadcast_shape
+            )  # type: ignore[assignment]
             new_bias_add = graph.call_function(  # type: ignore[operator]
                 aten.add.Tensor, args=((broadcast_bias, new_mm))
             )
@@ -803,9 +805,9 @@ class BatchLayernormFusion(BatchFusion):
             group_biases = None  # type: ignore[assignment]
         if all(weight is None for weight in group_weights):
             group_weights = None  # type: ignore[assignment]
-        assert all(
-            eps == group_epss[0] for eps in group_epss
-        ), "all epsilon values must be equal"
+        assert all(eps == group_epss[0] for eps in group_epss), (
+            "all epsilon values must be equal"
+        )

         with graph.inserting_before(subset[0]):  # type: ignore[operator]
             stack_input = graph.call_function(  # type: ignore[operator]
@@ -996,7 +998,11 @@ class BatchPointwiseOpsPostGradFusion(BatchPointwiseOpsFusionFactory):
             # for relu op, we also use the inplace to construct the key
             # we batch the ops with same parent to enable followup split cat
             parent = node.args[0]
-            parent = parent.target if self.graph_search_options.get("fuse_nodes_with_same_parent", False) else ""  # type: ignore[union-attr]
+            parent = (
+                parent.target  # type: ignore[union-attr]
+                if self.graph_search_options.get("fuse_nodes_with_same_parent", False)
+                else ""
+            )
             group_key = (
                 "batch_aten_" + self.op.__name__.lower().split(".")[0],
                 str(input.meta["val"].shape),
@@ -1293,9 +1299,9 @@ def get_fusion_candidates(
     """
     q: collections.deque[tuple[int, torch.fx.Node]] = collections.deque()

-    candidate_dict: collections.defaultdict[
-        Any, list[torch.fx.Node]
-    ] = collections.defaultdict(list)
+    candidate_dict: collections.defaultdict[Any, list[torch.fx.Node]] = (
+        collections.defaultdict(list)
+    )

     if root_node.target in SEARCH_EXCLUSIONS:
         return candidate_dict
@@ -763,9 +763,7 @@ def _get_node_to_ancestors(
     """
     Compute the ancestors for all nodes in a graph.
     """
-    node_to_ancestors = defaultdict(
-        OrderedSet[torch.fx.Node]
-    )  # type: ignore[var-annotated]
+    node_to_ancestors = defaultdict(OrderedSet[torch.fx.Node])  # type: ignore[var-annotated]
     for node in graph.nodes:
         node_to_ancestors[node] = OrderedSet(node.all_input_nodes)
         for dep in node.all_input_nodes:
@ -558,9 +558,9 @@ if torch._C._has_mkldnn:
|
||||||
binary_nodes = filter_nodes(match.nodes, binary_op)
|
binary_nodes = filter_nodes(match.nodes, binary_op)
|
||||||
|
|
||||||
def _get_compute_node(_binary_node, _other_index):
|
def _get_compute_node(_binary_node, _other_index):
|
||||||
assert (
|
assert len(_binary_node.all_input_nodes) == 2, (
|
||||||
len(_binary_node.all_input_nodes) == 2
|
"Binary node should have 2 input nodes."
|
||||||
), "Binary node should have 2 input nodes."
|
)
|
||||||
_compute_index = 1 if (_other_index == 0) else 0
|
_compute_index = 1 if (_other_index == 0) else 0
|
||||||
return _binary_node.args[_compute_index]
|
return _binary_node.args[_compute_index]
|
||||||
|
|
||||||
|
|
@ -614,9 +614,9 @@ if torch._C._has_mkldnn:
|
||||||
else:
|
else:
|
||||||
computation_args += [1.0, None, [], None]
|
computation_args += [1.0, None, [], None]
|
||||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||||
counters["inductor"][
|
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
|
||||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
len(match.nodes)
|
||||||
] += len(match.nodes)
|
)
|
||||||
return L[fusion_op](*computation_args)
|
return L[fusion_op](*computation_args)
|
||||||
|
|
||||||
return fn
|
return fn
|
||||||
|
|
@ -659,9 +659,9 @@ if torch._C._has_mkldnn:
|
||||||
else:
|
else:
|
||||||
computation_args += [1.0, None, [], None]
|
computation_args += [1.0, None, [], None]
|
||||||
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_count"] += 1
|
||||||
counters["inductor"][
|
counters["inductor"]["mkldnn_conv_binary_unary_fusion_matcher_nodes"] += (
|
||||||
"mkldnn_conv_binary_unary_fusion_matcher_nodes"
|
len(match.nodes)
|
||||||
] += len(match.nodes)
|
)
|
||||||
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
|
# Make sure the other is not an alias or mutation(fx side doesn't has such info).
|
||||||
other.realize()
|
other.realize()
|
||||||
if not _can_be_inplace(other) or other.data.shape != list(
|
if not _can_be_inplace(other) or other.data.shape != list(
|
||||||
|
|
@ -1310,9 +1310,9 @@ if torch._C._has_mkldnn:
|
||||||
)
|
)
|
||||||
batch_size = input.meta.get("val").shape[0]
|
batch_size = input.meta.get("val").shape[0]
|
||||||
if has_free_symbols(batch_size):
|
if has_free_symbols(batch_size):
|
||||||
assert (
|
assert is_lp_weight or mkldnn._is_mkldnn_acl_supported(), (
|
||||||
is_lp_weight or mkldnn._is_mkldnn_acl_supported()
|
f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
||||||
), f"only bf16/fp16 weight prepacking supports dynamic shape inputs but got {weight_dtype}"
|
)
|
||||||
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
|
# For bfloat16 dynamic shape path, using input size hint to pack weight for a better performance.
|
||||||
packed_weight_inputs = (
|
packed_weight_inputs = (
|
||||||
transpose_weight_node,
|
transpose_weight_node,
|
||||||
|
|
|
||||||
|
|
@ -437,7 +437,7 @@ def _should_pad_bench(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def realize_symbols(
|
def realize_symbols(
|
||||||
ds: Union[torch.Size, tuple[torch.SymInt, ...]]
|
ds: Union[torch.Size, tuple[torch.SymInt, ...]],
|
||||||
) -> list[int]:
|
) -> list[int]:
|
||||||
return [d if isinstance(d, int) else d.node.hint for d in ds]
|
return [d if isinstance(d, int) else d.node.hint for d in ds]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,9 +137,9 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
|
||||||
pattern_matcher_pass.apply
|
pattern_matcher_pass.apply
|
||||||
)
|
)
|
||||||
if not is_same_dict(counters["inductor"], inductor_before_change):
|
if not is_same_dict(counters["inductor"], inductor_before_change):
|
||||||
optimus_scuba_log[
|
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_post_grad"] = (
|
||||||
f"{pattern_matcher_pass.pass_name}_post_grad"
|
upload_graph(gm.graph)
|
||||||
] = upload_graph(gm.graph)
|
)
|
||||||
if config.b2b_gemm_pass:
|
if config.b2b_gemm_pass:
|
||||||
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]
|
B2B_GEMM_PASS.apply(gm.graph) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -277,9 +277,9 @@ def pre_grad_passes(
|
||||||
for _ in range(counter):
|
for _ in range(counter):
|
||||||
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
pattern_matcher_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||||
if not is_same_dict(counters["inductor"], inductor_before_change):
|
if not is_same_dict(counters["inductor"], inductor_before_change):
|
||||||
optimus_scuba_log[
|
optimus_scuba_log[f"{pattern_matcher_pass.pass_name}_pre_grad"] = (
|
||||||
f"{pattern_matcher_pass.pass_name}_pre_grad"
|
upload_graph(gm.graph)
|
||||||
] = upload_graph(gm.graph)
|
)
|
||||||
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
|
# TODO: move efficient_conv_bn_eval_pass to the fusions dict too.
|
||||||
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]
|
efficient_conv_bn_eval_pass.apply(gm.graph) # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -763,9 +763,9 @@ def _register_quantized_conv_binary_lowering(
|
||||||
accum.realize()
|
accum.realize()
|
||||||
from .mkldnn_fusion import _can_be_inplace
|
from .mkldnn_fusion import _can_be_inplace
|
||||||
|
|
||||||
assert _can_be_inplace(
|
assert _can_be_inplace(accum), (
|
||||||
accum
|
"QConv Binary Inplace Fusion requires accum is not an alias or mutation."
|
||||||
), "QConv Binary Inplace Fusion requires accum is not an alias or mutation."
|
)
|
||||||
|
|
||||||
computation_args = (
|
computation_args = (
|
||||||
x,
|
x,
|
||||||
|
|
@ -1307,9 +1307,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
||||||
def clone_to_new_node(graph, source_node, user_node):
|
def clone_to_new_node(graph, source_node, user_node):
|
||||||
# Clone the source_node to a new node
|
# Clone the source_node to a new node
|
||||||
# Replace user_node's input from source_node to new_node
|
# Replace user_node's input from source_node to new_node
|
||||||
assert (
|
assert source_node.op == "call_function", (
|
||||||
source_node.op == "call_function"
|
"clone_to_new_node only support node.op call_function"
|
||||||
), "clone_to_new_node only support node.op call_function"
|
)
|
||||||
with graph.inserting_before(user_node):
|
with graph.inserting_before(user_node):
|
||||||
new_node = graph.call_function(
|
new_node = graph.call_function(
|
||||||
source_node.target,
|
source_node.target,
|
||||||
|
|
@ -1343,9 +1343,9 @@ def _register_dequant_promotion_pass(pattern, pass_number, dtype=torch.float32):
|
||||||
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
|
# For a dequant pattern, we expect the start node is a dequantize_per_tensor node
|
||||||
return _node
|
return _node
|
||||||
else:
|
else:
|
||||||
assert (
|
assert len(_node.args) >= 1, (
|
||||||
len(_node.args) >= 1
|
"In in dequant pattern, each node should have more than 1 arg."
|
||||||
), "In in dequant pattern, each node should have more than 1 arg."
|
)
|
||||||
return _find_first_node_in_dequant_pattern(_node.args[0])
|
return _find_first_node_in_dequant_pattern(_node.args[0])
|
||||||
|
|
||||||
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
|
dequant_pattern_start_node = _find_first_node_in_dequant_pattern(
|
||||||
|
|
|
||||||
|
|
@ -616,7 +616,8 @@ def merge_splits(
|
||||||
dim=first_split_dim,
|
dim=first_split_dim,
|
||||||
)
|
)
|
||||||
first_split_num_to_user = {
|
first_split_num_to_user = {
|
||||||
user.args[1]: user for user in first_split.users.keys() # type: ignore[union-attr]
|
user.args[1]: user
|
||||||
|
for user in first_split.users.keys() # type: ignore[union-attr]
|
||||||
}
|
}
|
||||||
|
|
||||||
new_split_num = 0
|
new_split_num = 0
|
||||||
|
|
@ -706,7 +707,11 @@ class SplitCatSimplifier:
|
||||||
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
|
graph, split_node, split_sections, user_inputs_list, simplified_split_ranges
|
||||||
)
|
)
|
||||||
self.replace_cat(
|
self.replace_cat(
|
||||||
graph, split_node, next_users, user_inputs_list_new, transform_params_list # type: ignore[arg-type]
|
graph,
|
||||||
|
split_node,
|
||||||
|
next_users,
|
||||||
|
user_inputs_list_new,
|
||||||
|
transform_params_list, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
|
self.erase_old_nodes(graph, split_node, next_users) # type: ignore[arg-type]
|
||||||
counters["inductor"]["unbind_stack_pass"] += 1
|
counters["inductor"]["unbind_stack_pass"] += 1
|
||||||
|
|
@ -913,7 +918,9 @@ class SplitCatSimplifier:
|
||||||
)
|
)
|
||||||
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
|
if is_node_meta_valid(split_input): # type: ignore[arg-type, union-attr]
|
||||||
new_split.meta["example_value"] = torch.split(
|
new_split.meta["example_value"] = torch.split(
|
||||||
split_input.meta["example_value"], [r[1] - r[0] for r in split_ranges], dim=split_dim # type: ignore[union-attr]
|
split_input.meta["example_value"], # type: ignore[union-attr]
|
||||||
|
[r[1] - r[0] for r in split_ranges],
|
||||||
|
dim=split_dim,
|
||||||
)
|
)
|
||||||
counters["inductor"]["scmerge_split_added"] += 1
|
counters["inductor"]["scmerge_split_added"] += 1
|
||||||
split_items = []
|
split_items = []
|
||||||
|
|
@ -1005,7 +1012,10 @@ class SplitCatSimplifier:
|
||||||
stacked_input = graph.call_function(
|
stacked_input = graph.call_function(
|
||||||
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
||||||
)
|
)
|
||||||
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
|
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
|
||||||
|
to_stack_meta,
|
||||||
|
dim=stack_dim, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
to_stack, to_stack_meta = [], []
|
to_stack, to_stack_meta = [], []
|
||||||
stack_dim = None
|
stack_dim = None
|
||||||
user_inputs_new_transformed.append(stacked_input)
|
user_inputs_new_transformed.append(stacked_input)
|
||||||
|
|
@ -1023,19 +1033,28 @@ class SplitCatSimplifier:
|
||||||
user_input_new = graph.call_function(
|
user_input_new = graph.call_function(
|
||||||
torch.unflatten, args=(user_input_new, *unflatten_params)
|
torch.unflatten, args=(user_input_new, *unflatten_params)
|
||||||
)
|
)
|
||||||
user_input_new.meta["example_value"] = torch.unflatten(user_input_new_meta, *unflatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
user_input_new.meta["example_value"] = torch.unflatten( # type: ignore[arg-type]
|
||||||
|
user_input_new_meta, # type: ignore[arg-type]
|
||||||
|
*unflatten_params, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
if movedim_params:
|
if movedim_params:
|
||||||
user_input_new_meta = user_input_new.meta["example_value"]
|
user_input_new_meta = user_input_new.meta["example_value"]
|
||||||
user_input_new = graph.call_function(
|
user_input_new = graph.call_function(
|
||||||
torch.movedim, args=(user_input_new, *movedim_params)
|
torch.movedim, args=(user_input_new, *movedim_params)
|
||||||
)
|
)
|
||||||
user_input_new.meta["example_value"] = torch.movedim(user_input_new_meta, *movedim_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
user_input_new.meta["example_value"] = torch.movedim( # type: ignore[arg-type]
|
||||||
|
user_input_new_meta, # type: ignore[arg-type]
|
||||||
|
*movedim_params, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
if flatten_params:
|
if flatten_params:
|
||||||
user_input_new_meta = user_input_new.meta["example_value"]
|
user_input_new_meta = user_input_new.meta["example_value"]
|
||||||
user_input_new = graph.call_function(
|
user_input_new = graph.call_function(
|
||||||
torch.flatten, args=(user_input_new, *flatten_params)
|
torch.flatten, args=(user_input_new, *flatten_params)
|
||||||
)
|
)
|
||||||
user_input_new.meta["example_value"] = torch.flatten(user_input_new_meta, *flatten_params) # type: ignore[arg-type, possibly-undefined, union-attr]
|
user_input_new.meta["example_value"] = torch.flatten( # type: ignore[arg-type]
|
||||||
|
user_input_new_meta,
|
||||||
|
*flatten_params, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
user_inputs_new_transformed.append(user_input_new)
|
user_inputs_new_transformed.append(user_input_new)
|
||||||
user_inputs_new_transformed_meta.append(
|
user_inputs_new_transformed_meta.append(
|
||||||
user_input_new.meta["example_value"]
|
user_input_new.meta["example_value"]
|
||||||
|
|
@ -1044,7 +1063,10 @@ class SplitCatSimplifier:
|
||||||
stacked_input = graph.call_function(
|
stacked_input = graph.call_function(
|
||||||
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
torch.stack, args=(to_stack,), kwargs={"dim": stack_dim}
|
||||||
)
|
)
|
||||||
stacked_input.meta["example_value"] = torch.stack(to_stack_meta, dim=stack_dim) # type: ignore[arg-type, union-attr]
|
stacked_input.meta["example_value"] = torch.stack( # type: ignore[arg-type]
|
||||||
|
to_stack_meta,
|
||||||
|
dim=stack_dim, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
user_inputs_new_transformed.append(stacked_input)
|
user_inputs_new_transformed.append(stacked_input)
|
||||||
user_inputs_new_transformed_meta.append(
|
user_inputs_new_transformed_meta.append(
|
||||||
stacked_input.meta["example_value"]
|
stacked_input.meta["example_value"]
|
||||||
|
|
@ -1058,14 +1080,15 @@ class SplitCatSimplifier:
|
||||||
kwargs={"dim": cat_dim},
|
kwargs={"dim": cat_dim},
|
||||||
)
|
)
|
||||||
new_cat_node.meta["example_value"] = torch.cat(
|
new_cat_node.meta["example_value"] = torch.cat(
|
||||||
user_inputs_new_transformed_meta, dim=cat_dim
|
user_inputs_new_transformed_meta,
|
||||||
|
dim=cat_dim,
|
||||||
)
|
)
|
||||||
counters["inductor"]["scmerge_cat_added"] += 1
|
counters["inductor"]["scmerge_cat_added"] += 1
|
||||||
else:
|
else:
|
||||||
new_cat_node = user_inputs_new_transformed[-1]
|
new_cat_node = user_inputs_new_transformed[-1]
|
||||||
new_cat_node.meta[
|
new_cat_node.meta["example_value"] = (
|
||||||
"example_value"
|
user_inputs_new_transformed_meta[-1]
|
||||||
] = user_inputs_new_transformed_meta[-1]
|
)
|
||||||
|
|
||||||
if (
|
if (
|
||||||
user_node.target == torch.cat
|
user_node.target == torch.cat
|
||||||
|
|
@ -1077,7 +1100,11 @@ class SplitCatSimplifier:
|
||||||
new_cat_node = graph.call_function(
|
new_cat_node = graph.call_function(
|
||||||
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
|
torch.flatten, args=(new_cat_node, cat_dim, cat_dim + 1)
|
||||||
)
|
)
|
||||||
new_cat_node.meta["example_value"] = torch.flatten(new_cat_node_meta, cat_dim, cat_dim + 1) # type: ignore[possibly-undefined, union-attr]
|
new_cat_node.meta["example_value"] = torch.flatten(
|
||||||
|
new_cat_node_meta,
|
||||||
|
cat_dim,
|
||||||
|
cat_dim + 1,
|
||||||
|
)
|
||||||
user_node.replace_all_uses_with(new_cat_node)
|
user_node.replace_all_uses_with(new_cat_node)
|
||||||
new_cats.append(new_cat_node)
|
new_cats.append(new_cat_node)
|
||||||
|
|
||||||
|
|
@ -1123,9 +1150,7 @@ class UnbindCatRemover(SplitCatSimplifier):
|
||||||
]
|
]
|
||||||
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
|
if not is_sorted_and_consecutive(getitem_indices) or len( # type: ignore[arg-type]
|
||||||
getitem_indices
|
getitem_indices
|
||||||
) != len(
|
) != len(unbind_node.meta["example_value"]):
|
||||||
unbind_node.meta["example_value"]
|
|
||||||
):
|
|
||||||
return
|
return
|
||||||
num_unbind = len(getitem_indices)
|
num_unbind = len(getitem_indices)
|
||||||
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
|
split_sections = [1 for _ in range(num_unbind)] # type: ignore[operator, arg-type]
|
||||||
|
|
@ -1510,7 +1535,8 @@ def merge_getitem_cat(match: Match, split_sections: list[int], dim: int):
|
||||||
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
|
fused_tensor_size += split_node.args[1][i] # type: ignore[operator, assignment, index]
|
||||||
# update the split sections
|
# update the split sections
|
||||||
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
|
split_sections[indices[0]] = calculate_fused_tensor_size( # type: ignore[index]
|
||||||
split_node, indices # type: ignore[arg-type]
|
split_node,
|
||||||
|
indices, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
# padding others with zeros to keep the same dict size
|
# padding others with zeros to keep the same dict size
|
||||||
for i in indices[1:]:
|
for i in indices[1:]:
|
||||||
|
|
@ -1613,10 +1639,12 @@ def mutate_cat_node(match: Match, split_sections: list[int], dim: int):
|
||||||
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
|
elif is_node_meta_valid(split_node.args[0]): # type: ignore[arg-type]
|
||||||
# check the split dim, and construct the slice tuple
|
# check the split dim, and construct the slice tuple
|
||||||
start_fused_size = calculate_fused_tensor_size(
|
start_fused_size = calculate_fused_tensor_size(
|
||||||
split_node, list(range(indices[0])) # type: ignore[arg-type]
|
split_node,
|
||||||
|
list(range(indices[0])), # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
end_fused_size = start_fused_size + calculate_fused_tensor_size(
|
end_fused_size = start_fused_size + calculate_fused_tensor_size(
|
||||||
split_node, indices # type: ignore[arg-type]
|
split_node,
|
||||||
|
indices, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
slice_list = []
|
slice_list = []
|
||||||
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
|
for i in range(len(split_node.args[0].meta["example_value"].shape)): # type: ignore[union-attr]
|
||||||
|
|
@ -1714,7 +1742,10 @@ def merge_split_cat_aten(match: Match, *args, **kwargs):
|
||||||
continue
|
continue
|
||||||
# check the cat node has consecutive indices
|
# check the cat node has consecutive indices
|
||||||
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
|
indices = [arg.args[1] for arg in cat_node.args[0]] # type: ignore[union-attr]
|
||||||
if not is_sorted_and_consecutive(indices) and len(getitem_nodes) != len(cat_inputs): # type: ignore[arg-type]
|
if (
|
||||||
|
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
|
||||||
|
and len(getitem_nodes) != len(cat_inputs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
# replace the users of the cat node to be the input of the split node
|
# replace the users of the cat node to be the input of the split node
|
||||||
cat_node.replace_all_uses_with(split_input)
|
cat_node.replace_all_uses_with(split_input)
|
||||||
|
|
@ -1764,7 +1795,10 @@ def merge_select_cat_aten(match: Match, *args, **kwargs):
|
||||||
continue
|
continue
|
||||||
# check the cat node has consecutive indices
|
# check the cat node has consecutive indices
|
||||||
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
|
indices = [select.args[2] for select in cat_node.args[0]] # type: ignore[union-attr]
|
||||||
if not is_sorted_and_consecutive(indices) or len(select_nodes) != len(cat_inputs): # type: ignore[arg-type]
|
if (
|
||||||
|
not is_sorted_and_consecutive(indices) # type: ignore[arg-type]
|
||||||
|
or len(select_nodes) != len(cat_inputs)
|
||||||
|
):
|
||||||
continue
|
continue
|
||||||
# check all the select nodes can be merged to the cat node input
|
# check all the select nodes can be merged to the cat node input
|
||||||
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
|
if len(indices) != select_nodes[0].args[0].meta["val"].shape[cat_dim]: # type: ignore[union-attr]
|
||||||
|
|
@ -2318,7 +2352,9 @@ def unbind_cat_to_view(match: Match, unbind_input: torch.fx.Node, dim: int):
|
||||||
args=(new_cat_args,),
|
args=(new_cat_args,),
|
||||||
kwargs={"dim": cat_dim},
|
kwargs={"dim": cat_dim},
|
||||||
)
|
)
|
||||||
new_cat_node.meta["example_value"] = torch.cat(new_cat_args_meta, dim=cat_dim) # type: ignore[arg-type]
|
new_cat_node.meta["example_value"] = torch.cat(
|
||||||
|
new_cat_args_meta, dim=cat_dim
|
||||||
|
) # type: ignore[arg-type]
|
||||||
cat_node.replace_all_uses_with(new_cat_node)
|
cat_node.replace_all_uses_with(new_cat_node)
|
||||||
new_cat_node.meta.update(cat_node.meta)
|
new_cat_node.meta.update(cat_node.meta)
|
||||||
# remove inputs of cat_node if they have no users
|
# remove inputs of cat_node if they have no users
|
||||||
|
|
@ -2411,7 +2447,8 @@ def convert_reshape_cat_arg_to_stack(
|
||||||
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
|
args=(permute_node, tuple(stack_node_shape)), # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
reshape_node.meta["example_value"] = torch.Tensor.view(
|
reshape_node.meta["example_value"] = torch.Tensor.view(
|
||||||
permute_node.meta["example_value"], tuple(stack_node_shape) # type: ignore[arg-type]
|
permute_node.meta["example_value"],
|
||||||
|
tuple(stack_node_shape), # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
return reshape_node
|
return reshape_node
|
||||||
|
|
||||||
|
|
@ -2687,7 +2724,9 @@ def move_reshape_out_of_split_stack(match: Match, *args, **kwargs):
|
||||||
cat_inputs.append(decomposed_stack_node)
|
cat_inputs.append(decomposed_stack_node)
|
||||||
# cat_arg must be the split input
|
# cat_arg must be the split input
|
||||||
view_shape_list = get_view_shape_list(cat_arg, stack_dim)
|
view_shape_list = get_view_shape_list(cat_arg, stack_dim)
|
||||||
stack_node_shape = torch.reshape(cat_arg.meta["example_value"], tuple(view_shape_list)).shape # type: ignore[union-attr]
|
stack_node_shape = torch.reshape(
|
||||||
|
cat_arg.meta["example_value"], tuple(view_shape_list)
|
||||||
|
).shape # type: ignore[union-attr]
|
||||||
cat_inputs.append(
|
cat_inputs.append(
|
||||||
convert_reshape_cat_arg_to_stack(
|
convert_reshape_cat_arg_to_stack(
|
||||||
graph,
|
graph,
|
||||||
|
|
|
||||||
|
|
@ -105,9 +105,9 @@ class FakeTensorUpdater:
|
||||||
if new is None:
|
if new is None:
|
||||||
return old is None
|
return old is None
|
||||||
if not isinstance(new, torch.Tensor):
|
if not isinstance(new, torch.Tensor):
|
||||||
assert isinstance(
|
assert isinstance(new, (torch.SymInt, torch.SymBool, torch.SymFloat)), (
|
||||||
new, (torch.SymInt, torch.SymBool, torch.SymFloat)
|
f"Unknown type {type(new)} in {self.graph}"
|
||||||
), f"Unknown type {type(new)} in {self.graph}"
|
)
|
||||||
return (
|
return (
|
||||||
new.node.shape_env._maybe_evaluate_static(
|
new.node.shape_env._maybe_evaluate_static(
|
||||||
sympy.Eq(new.node.expr, old.node.expr)
|
sympy.Eq(new.node.expr, old.node.expr)
|
||||||
|
|
|
||||||
|
|
@ -136,7 +136,9 @@ else:
|
||||||
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
|
def may_get_constant_buffer_dtype(constant_buffer: sympy.Expr) -> Optional[torch.dtype]:
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
|
constant_buffer, (sympy.Symbol, sympy.Expr, sympy.core.numbers.Integer)
|
||||||
), "get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
|
), (
|
||||||
|
"get_constant_buffer_dtype only supports input of sympy.Symbol, sympy.Expr or sympy.core.numbers.Integer"
|
||||||
|
)
|
||||||
if isinstance(constant_buffer, sympy.core.numbers.Integer):
|
if isinstance(constant_buffer, sympy.core.numbers.Integer):
|
||||||
return torch.int64
|
return torch.int64
|
||||||
|
|
||||||
|
|
@ -308,9 +310,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
self.reuse_shape_env = True
|
self.reuse_shape_env = True
|
||||||
self._shape_env = shape_env
|
self._shape_env = shape_env
|
||||||
# We're going to mutate ras_by_symbol as we finish generating them
|
# We're going to mutate ras_by_symbol as we finish generating them
|
||||||
self.ras_by_symbol: dict[
|
self.ras_by_symbol: dict[Optional[sympy.Symbol], list[RuntimeAssert]] = (
|
||||||
Optional[sympy.Symbol], list[RuntimeAssert]
|
shape_env.deferred_runtime_asserts.copy()
|
||||||
] = shape_env.deferred_runtime_asserts.copy()
|
)
|
||||||
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
|
self.bound_unbacked_symbols = OrderedSet[sympy.Symbol]()
|
||||||
self.sizevars = SizeVarAllocator(shape_env)
|
self.sizevars = SizeVarAllocator(shape_env)
|
||||||
self.graph_input_names: list[str] = []
|
self.graph_input_names: list[str] = []
|
||||||
|
|
@ -400,9 +402,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
self.cache_path: str = "" # This is the path in the filesystem where the compiled artifact is stored
|
||||||
self.cache_linemap: list[
|
self.cache_linemap: list[
|
||||||
tuple[int, str]
|
tuple[int, str]
|
||||||
] = (
|
] = [] # This is the linemap used by the profiler to mark custom compiled kernels getting run
|
||||||
[]
|
|
||||||
) # This is the linemap used by the profiler to mark custom compiled kernels getting run
|
|
||||||
# Used if lowering encounters cases where cudagraphs are not supported
|
# Used if lowering encounters cases where cudagraphs are not supported
|
||||||
self.disable_cudagraphs_reason: Optional[str] = None
|
self.disable_cudagraphs_reason: Optional[str] = None
|
||||||
|
|
||||||
|
|
@ -1012,7 +1012,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
)
|
)
|
||||||
|
|
||||||
def placeholder(
|
def placeholder(
|
||||||
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
self,
|
||||||
|
target: str, # type: ignore[override]
|
||||||
|
args: tuple[object], # type: ignore[override]
|
||||||
|
kwargs: dict[str, object],
|
||||||
) -> Union[Expr, TensorBox, None]:
|
) -> Union[Expr, TensorBox, None]:
|
||||||
self.placeholder_idx += 1
|
self.placeholder_idx += 1
|
||||||
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
example = super().placeholder(target, args, kwargs) # type: ignore[arg-type]
|
||||||
|
|
@ -1118,9 +1121,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
return target(*args, **kwargs)
|
return target(*args, **kwargs)
|
||||||
|
|
||||||
if target not in lowerings:
|
if target not in lowerings:
|
||||||
assert isinstance(
|
assert isinstance(target, torch._ops.OpOverload), (
|
||||||
target, torch._ops.OpOverload
|
f"{target} is not an OpOverload"
|
||||||
), f"{target} is not an OpOverload"
|
)
|
||||||
base_name = target.name().split(".")[0]
|
base_name = target.name().split(".")[0]
|
||||||
if base_name in FALLBACK_ALLOW_LIST:
|
if base_name in FALLBACK_ALLOW_LIST:
|
||||||
make_fallback(target, warn=False, override_decomp=True)
|
make_fallback(target, warn=False, override_decomp=True)
|
||||||
|
|
@ -1189,7 +1192,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
return len(t.shape) == 1 and t.shape[0] <= 8
|
return len(t.shape) == 1 and t.shape[0] <= 8
|
||||||
|
|
||||||
def get_attr(
|
def get_attr(
|
||||||
self, target: str, args: tuple[()], kwargs: dict[str, object] # type: ignore[override]
|
self,
|
||||||
|
target: str, # type: ignore[override]
|
||||||
|
args: tuple[()], # type: ignore[override]
|
||||||
|
kwargs: dict[str, object],
|
||||||
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
) -> Union[Constant, TensorBox, ir.Subgraph, TorchBindObject]:
|
||||||
# this is a constant
|
# this is a constant
|
||||||
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
value = getattr_recursive(self.module, target) # type: ignore[arg-type]
|
||||||
|
|
@ -1241,7 +1247,10 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
raise AssertionError
|
raise AssertionError
|
||||||
|
|
||||||
def output(
|
def output(
|
||||||
self, target: str, args: tuple[object], kwargs: dict[str, object] # type: ignore[override]
|
self,
|
||||||
|
target: str, # type: ignore[override]
|
||||||
|
args: tuple[object], # type: ignore[override]
|
||||||
|
kwargs: dict[str, object],
|
||||||
) -> None:
|
) -> None:
|
||||||
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
result = super().output(target, args, kwargs) # type: ignore[arg-type]
|
||||||
if not isinstance(result, (tuple, list)):
|
if not isinstance(result, (tuple, list)):
|
||||||
|
|
@ -1439,9 +1448,11 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
if is_call_function:
|
if is_call_function:
|
||||||
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
args, kwargs = self.fetch_args_kwargs_from_env(n)
|
||||||
origins |= gather_origins(args, kwargs)
|
origins |= gather_origins(args, kwargs)
|
||||||
with ir.IRNode.current_origins(origins), self.set_current_node(
|
with (
|
||||||
n
|
ir.IRNode.current_origins(origins),
|
||||||
), V.set_current_node(n):
|
self.set_current_node(n),
|
||||||
|
V.set_current_node(n),
|
||||||
|
):
|
||||||
if (
|
if (
|
||||||
n.op == "call_function"
|
n.op == "call_function"
|
||||||
and n.target is not operator.getitem
|
and n.target is not operator.getitem
|
||||||
|
|
@ -1454,7 +1465,8 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
):
|
):
|
||||||
debug("fallback_handler")
|
debug("fallback_handler")
|
||||||
result = fallback_handler(n.target, add_to_fallback_set=False)(
|
result = fallback_handler(n.target, add_to_fallback_set=False)(
|
||||||
*args, **kwargs # type: ignore[possibly-undefined]
|
*args, # type: ignore[possibly-undefined]
|
||||||
|
**kwargs, # type: ignore[possibly-undefined]
|
||||||
)
|
)
|
||||||
elif (
|
elif (
|
||||||
n.op == "call_function"
|
n.op == "call_function"
|
||||||
|
|
@ -1833,9 +1845,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
wrapper_code_gen_cls = get_wrapper_codegen_for_device(
|
wrapper_code_gen_cls = get_wrapper_codegen_for_device(
|
||||||
self.device_type, self.cpp_wrapper
|
self.device_type, self.cpp_wrapper
|
||||||
)
|
)
|
||||||
assert (
|
assert wrapper_code_gen_cls is not None, (
|
||||||
wrapper_code_gen_cls is not None
|
f"Device {self.device_type} not supported"
|
||||||
), f"Device {self.device_type} not supported"
|
)
|
||||||
self.wrapper_code = wrapper_code_gen_cls.create(
|
self.wrapper_code = wrapper_code_gen_cls.create(
|
||||||
is_subgraph,
|
is_subgraph,
|
||||||
subgraph_name,
|
subgraph_name,
|
||||||
|
|
@ -1866,7 +1878,7 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
compiled = self.compile_to_module().call
|
compiled = self.compile_to_module().call
|
||||||
|
|
||||||
def materialize(
|
def materialize(
|
||||||
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor]
|
x: Union[torch.SymInt, torch.SymFloat, torch.Tensor],
|
||||||
) -> Union[int, float, torch.Tensor]:
|
) -> Union[int, float, torch.Tensor]:
|
||||||
if x is None:
|
if x is None:
|
||||||
return None
|
return None
|
||||||
|
|
@ -1876,9 +1888,9 @@ class GraphLowering(torch.fx.Interpreter):
|
||||||
elif isinstance(x, FakeTensor):
|
elif isinstance(x, FakeTensor):
|
||||||
return defake(x)
|
return defake(x)
|
||||||
else:
|
else:
|
||||||
assert isinstance(
|
assert isinstance(x, torch.Tensor), (
|
||||||
x, torch.Tensor
|
"Unknown type when creating real inputs" + str(type(x))
|
||||||
), "Unknown type when creating real inputs" + str(type(x))
|
)
|
||||||
return x
|
return x
|
||||||
|
|
||||||
tracing_context = torch._guards.TracingContext.try_get()
|
tracing_context = torch._guards.TracingContext.try_get()
|
||||||
|
|
|
||||||
|
|
@@ -5,21 +5,22 @@ propagation of sympy expressions downstream of ops.index_expr calls.
 For example, say we have the IR:

     tmp0 = ops.index_expr(x, torch.int32)
     tmp1 = ops.constant(2, torch.int32)
     tmp2 = ops.mul(tmp0, tmp1)
     tmp3 = ops.indirect_indexing(tmp2, x_size)
     tmp4 = ops.load("buf0", tmp3)

 The underlying handler would just see:

     ops.load("buf0", x * 2)

 This is limited by the set of operators handled in the sympy expression
 printers. So simple operations like minimum and maximum cannot be translated to
 SymPy expressions yet, despite sympy.Min and sympy.Max existing.

 """

 import itertools
 from collections.abc import Sequence
 from dataclasses import dataclass

@@ -179,9 +180,9 @@ class IndexPropVar:
         return IndexPropVar(expr, is_symbolic=True)

     def __post_init__(self):
-        assert not self.is_symbolic or isinstance(
-            self.value, TypedExpr
-        ), "Symbolic IndexPropVar must contain a TypedExpr"
+        assert not self.is_symbolic or isinstance(self.value, TypedExpr), (
+            "Symbolic IndexPropVar must contain a TypedExpr"
+        )


 IndexPropResult: TypeAlias = Union[IndexPropVar, tuple["IndexPropResult", ...]]

@@ -251,14 +252,12 @@ class IndexPropagation(DefaultHandler):
         name: Literal["indirect_indexing"],
         args: Sequence[Any],
         kwargs: dict[str, Any],
-    ) -> IndexPropVar:
-        ...
+    ) -> IndexPropVar: ...

     @overload
     def fallback(
         self, name: str, args: Sequence[Any], kwargs: dict[str, Any]
-    ) -> IndexPropResult:
-        ...
+    ) -> IndexPropResult: ...

     def fallback(
         self, name: str, args: Sequence[Any], kwargs: dict[str, Any]

@@ -283,8 +282,7 @@ class IndexPropagation(DefaultHandler):
         is_valid_expr = new_expr is not NotImplemented and (
             # Inductor doesn't expect floating point in sympy expressions, but
             # allow floating point constants to be propagated
-            new_expr.is_constant()
-            or new_expr.expr.is_integer
+            new_expr.is_constant() or new_expr.expr.is_integer
         )
         if not is_valid_expr:
             return self.fallback(name, args, kwargs)
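The docstring in the hunk above describes collapsing a chain of index_expr, constant, and mul ops into a single sympy expression before the load. A rough, self-contained illustration of that folding using plain sympy (this is not Inductor's ops handler, just the algebra it delegates to):

    import sympy

    x = sympy.Symbol("x", integer=True)

    # tmp0 = ops.index_expr(x, torch.int32); tmp1 = ops.constant(2, torch.int32)
    tmp0 = x
    tmp1 = sympy.Integer(2)

    # tmp2 = ops.mul(tmp0, tmp1) folds into the symbolic index the handler sees:
    tmp2 = tmp0 * tmp1
    print(tmp2)  # prints 2*x, i.e. the index passed to ops.load("buf0", ...)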
@@ -211,7 +211,9 @@ def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
                 int,
                 EffectfulKernel,
             ),
-        ), f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
+        ), (
+            f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
+        )

     # Be picky about the accepted data structure (don't use pytree here)
     _check_tensorbox(node_or_nodes)

@@ -298,13 +300,11 @@ def get_stride_order(


 @overload
-def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None:
-    ...
+def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...


 @overload
-def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor:
-    ...
+def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...


 def ir_node_to_tensor(

@@ -346,7 +346,7 @@ def may_convert_to_optional(


 def get_device_type(
-    x: Union[IRNode, OutputSpec, torch.device, None, str]
+    x: Union[IRNode, OutputSpec, torch.device, None, str],
 ) -> Optional[str]:
     if isinstance(x, str) or x is None:
         return x

@@ -698,8 +698,7 @@ class IRNode:
     if TYPE_CHECKING:

         @property
-        def dtype(self) -> torch.dtype:
-            ...
+        def dtype(self) -> torch.dtype: ...


 @ir_dataclass(frozen=False)

@@ -839,8 +838,9 @@ class Loops(IRNode):
     @cache_on_self
     def inner_fn_opcount(self) -> OpCountResult:
         opcounter = OpCounterCSE(V.MockHandler())
-        with V.set_ops_handler(opcounter), patch.object(
-            FlexibleLayout, "allow_indexing", True
+        with (
+            V.set_ops_handler(opcounter),
+            patch.object(FlexibleLayout, "allow_indexing", True),
         ):
             self.inner_fn(*self.inner_fn_args())
             return opcounter.getvalue()

@@ -1364,9 +1364,9 @@ class Reduction(Loops):
             # "all" is desugared to `!any(!val)`
         }

-        assert (
-            reduction_type in rtypes_to_inits.keys()
-        ), f"{reduction_type} not supported for zero-dimension tensors!"
+        assert reduction_type in rtypes_to_inits.keys(), (
+            f"{reduction_type} not supported for zero-dimension tensors!"
+        )

         def const_fn(index: int) -> OpsValue:
             return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)

@@ -1575,9 +1575,9 @@ class Reduction(Loops):
         new_ranges: Sequence[Integer],
         new_reduction_ranges: Sequence[Integer],
     ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
-        assert all(
-            r == 1 for r in original_ranges
-        ), f"Only enabled for numel_hint == 1, found {original_ranges=}"
+        assert all(r == 1 for r in original_ranges), (
+            f"Only enabled for numel_hint == 1, found {original_ranges=}"
+        )
         reindex = View.dynamic_reshape_indexer(
             original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
         )

@@ -1828,7 +1828,7 @@ class WelfordReduction(Reduction):
         if reduction_numel == 1:

             def copy(
-                loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
+                loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
             ) -> TensorBox:
                 def inner_fn(idx: Sequence[Expr]) -> OpsValue:
                     reduction_index = [sympy.S.Zero for _ in reduction_ranges]

@@ -2571,9 +2571,9 @@ class ExpandView(BaseView):
                 # NB: new_size[i] == old_size[i] is expected to already be
                 # guarded because the meta formula was expected to have taught
                 # us this equality.
-                assert (
-                    sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0
-                ), "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
+                assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
+                    "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
+                )
         return new_size

     @classmethod

@@ -3382,9 +3382,9 @@ class Layout(OutputSpec):
         )

     def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
-        assert (
-            FlexibleLayout.allow_indexing
-        ), f"convert {type(self).__name__} to FixedLayout first"
+        assert FlexibleLayout.allow_indexing, (
+            f"convert {type(self).__name__} to FixedLayout first"
+        )
         return self.as_fixed().make_indexer()

     def __eq__(self, other) -> bool:  # type: ignore[no-untyped-def]

@@ -3684,9 +3684,9 @@ class MutationLayoutSHOULDREMOVE(Layout):
             return target

         result = unwrap_views(self.target)
-        assert isinstance(
-            result, Buffer
-        ), "MutationLayoutSHOULDREMOVE must refer to a buffer"
+        assert isinstance(result, Buffer), (
+            "MutationLayoutSHOULDREMOVE must refer to a buffer"
+        )
         return result

     def real_layout(self):  # type: ignore[no-untyped-def]

@@ -3803,7 +3803,9 @@ class Buffer(IRNode):
         assert isinstance(self.layout, FlexibleLayout)
         self.layout = self.layout.as_same_order(stride)

-    def freeze_layout_with_exact_strides(self, exact_strides, allow_padding=False) -> None:  # type: ignore[no-untyped-def]
+    def freeze_layout_with_exact_strides(  # type: ignore[no-untyped-def]
+        self, exact_strides, allow_padding=False
+    ) -> None:
         assert isinstance(self.layout, FlexibleLayout)
         self.layout = self.layout.as_exact_strides(
             exact_strides, allow_padding=allow_padding

@@ -4365,9 +4367,9 @@ class TritonTemplateBuffer(TemplateBuffer):
                 torch.ops.higher_order.flex_attention_backward,
             )
             current_node = V.graph.current_node.target
-            assert (
-                current_node in allowed_set
-            ), f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
+            assert current_node in allowed_set, (
+                f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
+            )
             device = self.inputs[0].get_device()
             self.outputs += [
                 MutationOutput(NoneLayout(device=device), buf, self)

@@ -5106,7 +5108,8 @@ class ExternKernel(InputsKernel):
             x_unwrap_view.freeze_layout()

         index_args, var_ranges = dependencies.index_vars_squeeze(
-            x.get_size(), prefix="r"  # type: ignore[arg-type]
+            x.get_size(),
+            prefix="r",  # type: ignore[arg-type]
         )
         range_vars = index_args[0]
         index = x.make_indexer()(range_vars)

@@ -5404,9 +5407,9 @@ class ExternKernel(InputsKernel):
         # pass in a list of const arg names for arg_properties lookup.
         name_to_arg_properties = None
         if names and self.arg_properties:
-            assert len(self.constant_args) == len(
-                names
-            ), "names passed to codegen_const_args does not match self.constant_args"
+            assert len(self.constant_args) == len(names), (
+                "names passed to codegen_const_args does not match self.constant_args"
+            )
             name_to_arg_properties = {
                 arg.get("name"): arg for arg in self.arg_properties
             }

@@ -5442,9 +5445,9 @@ class ExternKernel(InputsKernel):
         args = []
         for i, x in enumerate(inputs):
             if V.graph.cpp_wrapper:
-                assert self.arg_properties and i < len(
-                    self.arg_properties
-                ), "Invalid access to ExternKernel.arg_properties"
+                assert self.arg_properties and i < len(self.arg_properties), (
+                    "Invalid access to ExternKernel.arg_properties"
+                )
                 type_ = self.arg_properties[i].get("type")
                 args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
             else:

@@ -5914,7 +5917,9 @@ class UserDefinedTritonKernel(ExternKernel):
     def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
         return OrderedSet()

-    def __init__(self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args) -> None:  # type: ignore[no-untyped-def]
+    def __init__(  # type: ignore[no-untyped-def]
+        self, *, kernel_idx, grid, tma_descriptor_metadata, kernel_args
+    ) -> None:
         inputs = []
         kwargs = {}
         constant_args = []

@@ -6835,9 +6840,9 @@ class FallbackKernel(ExternKernelAlloc):
             elif isinstance(output, torch.SymInt):
                 return output.node.expr
             else:
-                assert (
-                    output is None
-                ), f"FallbackKernel output type {type(output)} is not supported"
+                assert output is None, (
+                    f"FallbackKernel output type {type(output)} is not supported"
+                )
                 return None

         outputs = generate_output(example_output, [])

@@ -6919,7 +6924,12 @@ class MultiOutput(ExternKernel):
         )
         self.codegen_size_asserts(wrapper)

-    def __init__(self, layout: OutputSpec, input, indices: list[tuple[Any, ...]]) -> None:  # type: ignore[no-untyped-def]
+    def __init__(  # type: ignore[no-untyped-def]
+        self,
+        layout: OutputSpec,
+        input,
+        indices: list[tuple[Any, ...]],
+    ) -> None:
         super().__init__(None, layout, [input], ())
         self.name = V.graph.register_buffer(self)
         V.graph.register_operation(self)

@@ -7496,9 +7506,9 @@ class WhileLoop(ExternKernel):
             assert p.get_dtype() == torch.bool, p
             assert len(p.get_size()) == 0, p

-        assert (
-            len(all_inputs) > 0
-        ), "torch.while_loop is assumed to have at least one operand."
+        assert len(all_inputs) > 0, (
+            "torch.while_loop is assumed to have at least one operand."
+        )

         device = all_inputs[0].get_device()

@@ -7669,9 +7679,9 @@ class _CollectiveKernel(FallbackKernel):
     # This is identical to FallbackKernel.set_cpp_kernel(), minus the
     # part that checks against input aliasing and mutation.
     def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
-        assert (
-            type(self.op_overload) is torch._ops.OpOverload
-        ), "Setting cpp kernel needs a valid op_overload"
+        assert type(self.op_overload) is torch._ops.OpOverload, (
+            "Setting cpp kernel needs a valid op_overload"
+        )
         kernel = self.op_overload
         self.cpp_kernel_name = kernel._schema.name
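Several of the hunks above also rewrite stacked context managers from the call-wrapping layout into the parenthesized with form. A small sketch of the two layouts, using unittest.mock.patch.object as the diff does; the parenthesized form officially needs Python 3.10 or newer, and the Flags class below is an illustrative stand-in, not from the diff:

    from unittest.mock import patch

    class Flags:
        allow_indexing = False
        debug = False

    # older layout: one long `with`, wrapped inside the call arguments
    with patch.object(Flags, "allow_indexing", True), patch.object(
        Flags, "debug", True
    ):
        assert Flags.allow_indexing and Flags.debug

    # ruff format layout: parenthesized with-items, one per line
    with (
        patch.object(Flags, "allow_indexing", True),
        patch.object(Flags, "debug", True),
    ):
        assert Flags.allow_indexing and Flags.debug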
@@ -1,5 +1,5 @@
 # mypy: allow-untyped-defs
-""" Triton Implementation of the flex_attention Kernel"""
+"""Triton Implementation of the flex_attention Kernel"""

 import copy
 import logging

@@ -60,9 +60,9 @@ def construct_strides(
 ) -> Sequence[int]:
     """From a list of sizes and a fill order, construct the strides of the permuted tensor."""
     # Initialize strides
-    assert len(sizes) == len(
-        fill_order
-    ), "Length of sizes must match the length of the fill order"
+    assert len(sizes) == len(fill_order), (
+        "Length of sizes must match the length of the fill order"
+    )
     strides = [0] * len(sizes)

     # Start with stride 1 for the innermost dimension

@@ -1151,10 +1151,14 @@ def lower_cpu(
     SPARSE_Q_BLOCK_SIZE = V.graph.sizevars.evaluate_static_shape(SPARSE_Q_BLOCK_SIZE)
     assert V.graph.sizevars.evaluate_expr(
         sympy.Le(seq_len_q, sympy.Mul(kv_indices.get_size()[-2], SPARSE_Q_BLOCK_SIZE))
-    ), "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
+    ), (
+        "Q seqlen must be smaller than the block_mask size in the Q dimension, considering pass a larger block_mask."
+    )
     assert V.graph.sizevars.evaluate_expr(
         sympy.Le(seq_len_kv, sympy.Mul(kv_indices.get_size()[-1], SPARSE_KV_BLOCK_SIZE))
-    ), "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
+    ), (
+        "KV seqlen must be smaller than the block_mask size in the KV dimension, considering pass a larger block_mask."
+    )
     CppFlexAttentionTemplate.add_choices(
         choices=_choices,
         input_nodes=input_nodes,

@@ -1364,15 +1368,15 @@ def flex_attention(

     Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
     Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
-    ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Gt(seq_len_q, 0)
-    ), "Query length must be greater than 0"
-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Gt(seq_len_kv, 0)
-    ), "Key length must be greater than 0"
+    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
+        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    )
+    assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_q, 0)), (
+        "Query length must be greater than 0"
+    )
+    assert V.graph.sizevars.evaluate_expr(sympy.Gt(seq_len_kv, 0)), (
+        "Key length must be greater than 0"
+    )

     B = Bq

@@ -2291,9 +2295,9 @@ def process_joint_outputs(
         JointOutputResult containing processed buffers and gradients
     """
     assert isinstance(all_joint_outputs, list)
-    assert (
-        all_joint_outputs[0] is not None
-    ), "joint_subgraph_buffer is None - this is a bug!"
+    assert all_joint_outputs[0] is not None, (
+        "joint_subgraph_buffer is None - this is a bug!"
+    )

     joint_buffer = all_joint_outputs[0]
     other_grads = all_joint_outputs[num_placeholders - 1 :]

@@ -2392,9 +2396,9 @@ def flex_attention_backward(*args, **kwargs):
     Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
     Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()

-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
-    ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
+        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    )

     kernel_options = dict(kernel_options)
     # Mark symbols in custom kernel options as static shapes and add guards.

@@ -2639,9 +2643,11 @@ def flex_attention_backward(*args, **kwargs):
         grad_key = broadcasted_grad_key
         grad_value = broadcasted_grad_value
     else:
-        assert V.graph.sizevars.evaluate_expr(
-            sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)
-        ), f"Bq and Bkv must broadcastable. Got Bq={V.graph.sizevars.evaluate_expr(Bq)} and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"  # noqa: B950
+        assert V.graph.sizevars.evaluate_expr(sympy.Gt(Bq, 1) & sympy.Eq(Bkv, 1)), (
+            f"Bq and Bkv must broadcastable. "
+            f"Got Bq={V.graph.sizevars.evaluate_expr(Bq)} "
+            f"and Bkv={V.graph.sizevars.evaluate_expr(Bkv)}"
+        )
         grad_key = lowerings[aten.sum](broadcasted_grad_key, axis=0, keepdims=True)
         grad_value = lowerings[aten.sum](broadcasted_grad_value, axis=0, keepdims=True)
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
-""" Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""
+"""Triton Implementation of the flex_attention Kernel for short query length (FlexDecoding)"""

 from typing import Any

 import sympy

@@ -367,9 +368,9 @@ def create_flex_decoding_kernel(*args, **kwargs):
     Bq, Hq, seq_len_q, qk_head_dim = query.get_size()
     Bkv, Hkv, seq_len_kv, v_head_dim = value.get_size()

-    assert V.graph.sizevars.evaluate_expr(
-        sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)
-    ), f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    assert V.graph.sizevars.evaluate_expr(sympy.Eq(Bq, Bkv) | sympy.Eq(Bkv, 1)), (
+        f"Bq and Bkv must broadcastable. Got Bq={Bq} and Bkv={Bkv}"
+    )

     B = Bq
     kernel_options = dict(kernel_options)

@@ -481,7 +482,8 @@ def create_flex_decoding_kernel(*args, **kwargs):
         max(
             next_power_of_2(
                 V.graph.sizevars.size_hint(
-                    seq_len_q, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
+                    seq_len_q,
+                    fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
                 )
                 * gqa_shared_heads
             ),
@@ -65,7 +65,8 @@ def filtered_configs(
     m = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                m, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
+                m,
+                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
             )
         ),
         min_block_size,

@@ -73,7 +74,8 @@ def filtered_configs(
     n = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                n, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
+                n,
+                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
             )
         ),
         min_block_size,

@@ -81,7 +83,8 @@ def filtered_configs(
     k = max(
         next_power_of_2(
             V.graph.sizevars.size_hint(
-                k, fallback=torch._inductor.config.unbacked_symint_fallback  # type: ignore[arg-type]
+                k,
+                fallback=torch._inductor.config.unbacked_symint_fallback,  # type: ignore[arg-type]
             )
         ),
         min_block_size_k,

@@ -467,8 +470,7 @@ def mm_options(config, sym_m, sym_n, sym_k, layout):
     """
     even_k_symbolic = (
         # it isn't worth guarding on this
-        sympy.gcd(sym_k, config.kwargs["BLOCK_K"])
-        == config.kwargs["BLOCK_K"]
+        sympy.gcd(sym_k, config.kwargs["BLOCK_K"]) == config.kwargs["BLOCK_K"]
     )
     allow_tf32 = torch.backends.cuda.matmul.allow_tf32 and (
         not inductor_config.force_same_precision
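The size_hint hunks above all add a trailing comma so that each keyword argument lands on its own line, keeping the "# type: ignore[arg-type]" comment attached to the fallback argument rather than to the whole call. A toy sketch of the effect; the size_hint function here is a hypothetical stand-in, not the Inductor API:

    def size_hint(value, *, fallback=None):
        # hypothetical helper: return the value, or the fallback if it is unknown
        return fallback if value is None else value

    # one-line call: the comment can only sit at the end of the whole call
    m = size_hint(None, fallback=128)  # type: ignore[arg-type]

    # with a magic trailing comma, each argument gets its own line and comment
    m = size_hint(
        None,
        fallback=128,  # type: ignore[arg-type]
    )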
@@ -194,11 +194,12 @@ class LoopBody:
         # There is indeed an issue due to symbol name conflicting.
         # y0 maybe reused for the y dimension later.
         (
-            iter_vars,
-            reduce_vars,
-        ), var_ranges = dependencies.index_vars_no_squeeze(
-            iter_sizes, reduce_sizes, prefix="t"
-        )
+            (
+                iter_vars,
+                reduce_vars,
+            ),
+            var_ranges,
+        ) = dependencies.index_vars_no_squeeze(iter_sizes, reduce_sizes, prefix="t")
         new_body = LoopBody(
             old_body,
             [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],

@@ -234,7 +235,8 @@ class LoopBody:
         new_sizes = (new_iter_size, reduce_size)

         (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
-            *new_sizes, prefix="t"  # type: ignore[arg-type]
+            *new_sizes,
+            prefix="t",  # type: ignore[arg-type]
         )

         inverse_order = {b: a for a, b in enumerate(new_order)}

@@ -254,7 +256,8 @@ class LoopBody:

         # use the original symbol prefix so we can do multiple round of reordering
         (iter_vars2, reduce_vars2), var_ranges2 = dependencies.index_vars_no_squeeze(
-            *new_sizes, prefix="p"  # type: ignore[arg-type]
+            *new_sizes,
+            prefix="p",  # type: ignore[arg-type]
         )
         new_body = LoopBody(
             loop_body, (iter_vars2, reduce_vars2), var_ranges2, iter_vars2, reduce_vars2

@@ -385,9 +388,9 @@ class LoopBody:
     def indexing_from_args(self, indices):
         index = [*itertools.chain.from_iterable(indices)]
         assert len(index) == len(self.var_ranges), (index, self.var_ranges)
-        assert all(
-            v not in self.var_ranges for v in index
-        ), f"{self.var_ranges=}, {indices=}"
+        assert all(v not in self.var_ranges for v in index), (
+            f"{self.var_ranges=}, {indices=}"
+        )
         replacements = dict(zip(self.var_ranges.keys(), index))
         return {
             name: sympy_subs(expr, replacements)
@@ -346,7 +346,8 @@ def transform_args(
         # only consider tensor kwargs for promotion, for now
         promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
         dtype = get_promoted_dtype(
-            *promoting_args, type_promotion_kind=type_promotion_kind  # type: ignore[arg-type]
+            *promoting_args,
+            type_promotion_kind=type_promotion_kind,  # type: ignore[arg-type]
         )

         device = (

@@ -448,9 +449,9 @@ def _register_lowering(
         (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
     ):
         # explicitly assert for "out=" ops for better error messages
-        assert not any(
-            x == "out" for x in kwargs.keys()
-        ), "out= ops aren't yet supported"
+        assert not any(x == "out" for x in kwargs.keys()), (
+            "out= ops aren't yet supported"
+        )

     args, kwargs = transform_args(
         args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool

@@ -517,9 +518,9 @@ def broadcast_symbolic_shapes(a, b):


 def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
-    assert (
-        override_return_dtype is None or type_promotion_kind is None
-    ), "only one of override_return_dtype or type_promotion_kind may be given"
+    assert override_return_dtype is None or type_promotion_kind is None, (
+        "only one of override_return_dtype or type_promotion_kind may be given"
+    )

     if override_return_dtype is None and type_promotion_kind is None:
         type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT

@@ -674,9 +675,9 @@ def make_foreach_pointwise(pw_fn, allow_alpha=False):
             if isinstance(input, (list, tuple)):
                 a_list_input = input
                 break
-        assert (
-            a_list_input is not None
-        ), "at least one input must be a list to a foreach op"
+        assert a_list_input is not None, (
+            "at least one input must be a list to a foreach op"
+        )

         # broadcast scalar inputs to match length of list inputs
         broadcast_inputs = []

@@ -1321,12 +1322,12 @@ def quantized_decomposed_quantize_per_channel(

     if input.get_dtype() == torch.bfloat16:
         input = to_dtype(input, torch.float32)
-    assert (
-        input.get_dtype() == torch.float32
-    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
-    assert axis < len(
-        input.get_size()
-    ), f"Expecting axis to be < {len(input.get_size())}"
+    assert input.get_dtype() == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
+    )
+    assert axis < len(input.get_size()), (
+        f"Expecting axis to be < {len(input.get_size())}"
+    )

     input_loader = input.make_loader()
     scales_loader = scales.make_loader()

@@ -1373,12 +1374,12 @@ def quantized_decomposed_dequantize_per_channel(
 ) -> TensorBox:
     assert len(scales.get_size()) == 1, "expect scales 1 dim"
     assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
-    assert (
-        input.get_dtype() == dtype
-    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
-    assert axis < len(
-        input.get_size()
-    ), f"Expecting axis to be < {len(input.get_size())}"
+    assert input.get_dtype() == dtype, (
+        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
+    )
+    assert axis < len(input.get_size()), (
+        f"Expecting axis to be < {len(input.get_size())}"
+    )

     if out_dtype is None:
         out_dtype = torch.float32

@@ -1423,9 +1424,9 @@ def quantized_decomposed_quantize_per_tensor_default(
 ) -> TensorBox:
     if input.get_dtype() == torch.bfloat16:
         input = to_dtype(input, torch.float32)
-    assert (
-        input.get_dtype() == torch.float32
-    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
+    assert input.get_dtype() == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
+    )

     input_loader = input.make_loader()

@@ -1462,9 +1463,9 @@ def quantized_decomposed_dequantize_per_tensor_default(
     *,
     out_dtype: Optional[torch.dtype] = None,
 ) -> TensorBox:
-    assert (
-        input.get_dtype() == dtype
-    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
+    assert input.get_dtype() == dtype, (
+        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
+    )

     if out_dtype is None:
         out_dtype = torch.float32

@@ -1501,9 +1502,9 @@ def quantized_decomposed_quantize_per_tensor_tensor(
 ) -> TensorBox:
     if input.get_dtype() == torch.bfloat16:
         input = to_dtype(input, torch.float32)
-    assert (
-        input.get_dtype() == torch.float32
-    ), f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
+    assert input.get_dtype() == torch.float32, (
+        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
+    )
     assert len(scale.get_size()) == 0 or (
         len(scale.get_size()) == 1 and scale.get_size()[0] == 1
     ), "expect scale as scalar tensor"

@@ -1555,9 +1556,9 @@ def quantized_decomposed_dequantize_per_tensor_tensor(
     assert len(zero_point.get_size()) == 0 or (
         len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
     ), "expect zero_point as scalar tensor"
-    assert (
-        input.get_dtype() == dtype
-    ), f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
+    assert input.get_dtype() == dtype, (
+        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
+    )

     if out_dtype is None:
         out_dtype = torch.float32

@@ -1973,9 +1974,9 @@ def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=


 def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
-    assert (
-        op not in decompositions or override_decomp
-    ), f"both a fallback and a decomp for same op: {op}"
+    assert op not in decompositions or override_decomp, (
+        f"both a fallback and a decomp for same op: {op}"
+    )
     if (
         warn
         and bool(os.getenv("CI"))

@@ -2086,9 +2087,9 @@ def native_dropout(x, p, train):

 @register_lowering(aten.bernoulli_, type_promotion_kind=None)
 def bernoulli_(x, *args):
-    assert config.fallback_random or x.get_device() == torch.device(
-        "cpu"
-    ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
+    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
+        "this should be handled in decomps unless config.fallback_random or the device is CPU"
+    )
     x.realize()
     op_overload = (
         aten.bernoulli_.float

@@ -2101,9 +2102,9 @@ def bernoulli_(x, *args):

 @register_lowering(aten.bernoulli.p, type_promotion_kind=None)
 def bernoulli_p(x, *args):
-    assert config.fallback_random or x.get_device() == torch.device(
-        "cpu"
-    ), "this should be handled in decomps unless config.fallback_random or the device is CPU"
+    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
+        "this should be handled in decomps unless config.fallback_random or the device is CPU"
+    )
     return bernoulli_(clone(x), *args)


@@ -3376,7 +3377,9 @@ def check_and_broadcast_indices(indices, device):
         i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
         for i in indices
         if i is not None
-    ), f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
+    ), (
+        f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
+    )
     if any(
         i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
     ):

@@ -5668,7 +5671,8 @@ def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
         )
         result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
         if isinstance(
-            result.data.data, Reduction  # type: ignore[attr-defined]
+            result.data.data,  # type: ignore[attr-defined]
+            Reduction,
         ):  # Only realize if reduction isn't unrolled
             result.realize()
         return result

@@ -6008,8 +6012,9 @@ def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
         return None

     handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
-    with V.set_ops_handler(handler), patch.object(
-        ir.FlexibleLayout, "allow_indexing", True
+    with (
+        V.set_ops_handler(handler),
+        patch.object(ir.FlexibleLayout, "allow_indexing", True),
     ):
         out = x.inner_fn(*x.inner_fn_args())

@@ -6898,9 +6903,9 @@ def force_fallback(op: torch._ops.OpOverload):
     A context manager to force fallback an op. Used in unit test
     for FallbackKernel.
     """
-    assert isinstance(
-        op, torch._ops.OpOverload
-    ), "Only OpOverload to make the clean up easier"
+    assert isinstance(op, torch._ops.OpOverload), (
+        "Only OpOverload to make the clean up easier"
+    )
    old_handler = lowerings.get(op)
    try:
        register_lowering(op)(fallback_handler(op))
@@ -35,9 +35,9 @@ class MemoryPlanningInfoForBuffer:
 class MemoryPlanningInfoForNode:
     index: int = 0
     size: int = 0
-    pred_buffers: OrderedSet[
-        Union[SchedulerBuffer, FreeableInputBuffer]
-    ] = dataclasses.field(default_factory=OrderedSet)
+    pred_buffers: OrderedSet[Union[SchedulerBuffer, FreeableInputBuffer]] = (
+        dataclasses.field(default_factory=OrderedSet)
+    )
     pred_nodes: OrderedSet[BaseSchedulerNode] = dataclasses.field(
         default_factory=OrderedSet
     )

@@ -87,9 +87,9 @@ def get_freeable_input_buf(

     # get freeable input buffers' successor nodes and their sizes
     # note that different deps can have the same name, so we use name as keys
-    dep_name_to_succ_nodes: dict[
-        str, OrderedSet[BaseSchedulerNode]
-    ] = collections.defaultdict(OrderedSet)
+    dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
+        collections.defaultdict(OrderedSet)
+    )
     dep_name_to_size: dict[str, int] = dict()
     for node in nodes:
         for dep in node.read_writes.reads:

@@ -112,7 +112,7 @@ def get_freeable_input_buf(


 def compute_size_for_scheduler_buffer(
-    name_to_buf: dict[str, SchedulerBuffer]
+    name_to_buf: dict[str, SchedulerBuffer],
 ) -> dict[str, tuple[int, int]]:
     """
     Compute the size of each scheduler buffer, including (1) memory allocated when

@@ -187,9 +187,9 @@ def assign_memory_planning_info_for_scheduler_buffers(

     # get buffer's successor nodes
     # note that different deps can have the same name, so we use name as keys
-    dep_name_to_succ_nodes: dict[
-        str, OrderedSet[BaseSchedulerNode]
-    ] = collections.defaultdict(OrderedSet)
+    dep_name_to_succ_nodes: dict[str, OrderedSet[BaseSchedulerNode]] = (
+        collections.defaultdict(OrderedSet)
+    )
     for node in nodes:
         for dep in node.unmet_dependencies:
             dep_name_to_succ_nodes[dep.name].add(node)
@@ -138,12 +138,12 @@ class MetricTable:
             return

         row_dict = row_fn()
-        assert len(self.column_names) == len(
-            row_dict
-        ), f"{len(self.column_names)} v.s. {len(row_dict)}"
-        assert OrderedSet(self.column_names) == OrderedSet(
-            row_dict.keys()
-        ), f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
+        assert len(self.column_names) == len(row_dict), (
+            f"{len(self.column_names)} v.s. {len(row_dict)}"
+        )
+        assert OrderedSet(self.column_names) == OrderedSet(row_dict.keys()), (
+            f"{OrderedSet(self.column_names)} v.s. {OrderedSet(row_dict.keys())}"
+        )

         bn = get_benchmark_name()
         # assert bn is not None

@@ -433,9 +433,9 @@ def enabled_metric_tables_impl(config_str: str) -> OrderedSet[str]:
         name = name.strip()
         if not name:
             continue
-        assert (
-            name in REGISTERED_METRIC_TABLES
-        ), f"Metric table name {name} is not registered"
+        assert name in REGISTERED_METRIC_TABLES, (
+            f"Metric table name {name} is not registered"
+        )
         enabled.add(name)
     return enabled
@@ -751,9 +751,9 @@ class QConvPointWiseBinaryPT2E(ExternKernelAlloc):
             unary_algorithm,
         ]

-        assert (
-            binary_attr == "sum"
-        ), "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
+        assert binary_attr == "sum", (
+            "For now, only post op sum is supported in QConvPointWiseBinaryPT2E."
+        )

         V.graph.mark_buffer_mutated(qaccum.get_name())
         packed = QConvPointWiseBinaryPT2E(
@@ -575,9 +575,9 @@ def register_onednn_fusion_ops():
         algorithm,
         layout=None,
     ):
-        assert (
-            packed_weight.get_dtype() is torch.int8
-        ), "Only int8 weights are supported by oneDNN qlinear."
+        assert packed_weight.get_dtype() is torch.int8, (
+            "Only int8 weights are supported by oneDNN qlinear."
+        )
         x_size = x.get_size()
         if len(x_size) > 2:
             # GEMM template needs 2D input, normalize input shape here

@@ -928,9 +928,9 @@ def register_onednn_fusion_ops():
                 # we will do accum dtype convertion here.
                 x2 = to_dtype(x2, output_dtype)
             else:
-                assert (
-                    x2.get_dtype() == output_dtype
-                ), "dtype of accum for qlinear post op sum should be the same as output"
+                assert x2.get_dtype() == output_dtype, (
+                    "dtype of accum for qlinear post op sum should be the same as output"
+                )
             x2_dtype = x2.get_dtype()
         bias_dtype = bias.get_dtype() if bias is not None else None
         choices: list[ChoiceCaller] = []
@@ -806,8 +806,8 @@ class DefaultHandler(OpsHandler[Any]):
         assert self_arg == "self"
         code.write(
             f"""
-    def {target}(self, {', '.join(args)}):
-        return self._default({target!r}, ({', '.join(args)}, ), {{}})
+    def {target}(self, {", ".join(args)}):
+        return self._default({target!r}, ({", ".join(args)}, ), {{}})
     """.strip()
         )
         code.write("\n\n")

@@ -994,8 +994,9 @@ class KernelFormatterHandler(DefaultHandler):
         )
         formatter._output.writeline(f"{lhs} = {name}")

-        with V.set_ops_handler(formatter), patch.object(
-            FlexibleLayout, "allow_indexing", True
+        with (
+            V.set_ops_handler(formatter),
+            patch.object(FlexibleLayout, "allow_indexing", True),
         ):
             result = ir_fn(*args)
             return formatter.getvalue(result)
@@ -188,7 +188,9 @@ def package_aoti(
     ) or (
         isinstance(archive_file, (str, os.PathLike))
         and os.fspath(archive_file).endswith(".pt2")
-    ), f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
+    ), (
+        f"Expect archive file to be a file ending in .pt2, or is a buffer. Instead got {archive_file}"
+    )

     # Save using the PT2 packaging format
     # (https://docs.google.com/document/d/1jLPp8MN8Whs0-VW9PmJ93Yg02W85tpujvHrTa1pc5x8/edit#heading=h.v2y2jgnwc56a)

@@ -285,9 +287,9 @@ class AOTICompiledModel:
 def load_package(path: FileLike, model_name: str = "model") -> AOTICompiledModel:  # type: ignore[type-arg]
     assert (
         isinstance(path, (io.IOBase, IO)) and path.readable() and path.seekable()
-    ) or (
-        isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")
-    ), f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
+    ) or (isinstance(path, (str, os.PathLike)) and os.fspath(path).endswith(".pt2")), (
+        f"Unable to load package. Path must be a buffer or a file ending in .pt2. Instead got {path}"
+    )

     if isinstance(path, (io.IOBase, IO)):
         with tempfile.NamedTemporaryFile(suffix=".pt2") as f:
@@ -90,20 +90,17 @@ NodeOrConstant = Union[Constant, torch.fx.Node]
 class SearchFn(Protocol):
     __name__: str

-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        ...
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


 class ReplaceFn(Protocol):
-    def __call__(self, *args: Any, **kwargs: Any) -> Any:
-        ...
+    def __call__(self, *args: Any, **kwargs: Any) -> Any: ...


 class TraceFn(Protocol):
     def __call__(
         self, fn: Union[SearchFn, ReplaceFn], *args: Any, **kwargs: Any
-    ) -> torch.fx.GraphModule:
-        ...
+    ) -> torch.fx.GraphModule: ...


 T = TypeVar("T")

@@ -365,8 +362,7 @@ class PatternExpr(ABC):
     """

     @abstractmethod
-    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult:
-        ...
+    def _match(self, node: torch.fx.Node, ctx: MatchContext) -> MatchResult: ...

     def match(self, node: torch.fx.Node) -> MatchResult:
         try:

@@ -489,8 +485,7 @@ class _TargetExpr(PatternExpr):

     @property
     @abstractmethod
-    def op(self) -> str:
-        ...
+    def op(self) -> str: ...

     def fns_repr(self) -> str:
         first_repr = self.fns[0]

@@ -997,8 +992,9 @@ class PatternPrettyPrinter:


 class _PassDictsType(Protocol):
-    def __getitem__(self, k: tuple[str, torch.fx.node.Target]) -> list[PatternEntry]:
-        ...
+    def __getitem__(
+        self, k: tuple[str, torch.fx.node.Target]
+    ) -> list[PatternEntry]: ...


 @dataclasses.dataclass

@@ -1925,7 +1921,10 @@ def fx_to_pattern(
     get_attr = _not_implemented

     def placeholder(
-        self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
+        self,
+        target: str,  # type: ignore[override]
+        args: Sequence[Any],
+        kwargs: Mapping[str, Any],
     ) -> Union[ExclusiveKeywordArg, KeywordArg]:
         n = next(argnum)
         if n < len(argnames):

@@ -1942,7 +1941,10 @@ def fx_to_pattern(
         return KeywordArg(name)

     def call_function(
-        self, target: str, args: Sequence[Any], kwargs: Mapping[str, Any]  # type: ignore[override]
+        self,
+        target: str,  # type: ignore[override]
+        args: Sequence[Any],
+        kwargs: Mapping[str, Any],
     ) -> PatternExpr:
         process_arg_fn = process_arg
         # Indexing is critical for matching getitem nodes, so we can't ignore int args here
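The Protocol and @overload hunks above all move the "..." placeholder body onto the signature line, which is how ruff format prints stub-only definitions. A compact, runnable sketch of the same style; the names SupportsCall and to_int are illustrative, not from the diff:

    from typing import Protocol, overload

    class SupportsCall(Protocol):
        def __call__(self, *args: object, **kwargs: object) -> object: ...

    @overload
    def to_int(x: str) -> int: ...
    @overload
    def to_int(x: float) -> int: ...
    def to_int(x):
        # the real implementation follows the overload stubs
        return int(x)

    assert to_int("3") == 3 and to_int(2.5) == 2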
@@ -24,7 +24,7 @@ T = TypeVar("T")


 def time_and_count(
-    fn: Callable[Concatenate[Any, P], T]
+    fn: Callable[Concatenate[Any, P], T],
 ) -> Callable[Concatenate[Any, P], T]:
     """Wraps `fn` with `dynamo_timed` context, and increments the appropriate dynamo
     counters. It is expected that `fn` is a method of `Benchmarker` or one of its
@ -77,9 +77,9 @@ def validate_triton_config(cfg: Config) -> None:
|
||||||
# right now, if a pre-hook is attached to the config, it will not be saved;
|
# right now, if a pre-hook is attached to the config, it will not be saved;
|
||||||
# and then it won't be used when the config is loaded from cache.
|
# and then it won't be used when the config is loaded from cache.
|
||||||
# So we assert - if we do get a pre_hook, it might get ignored after caching.
|
# So we assert - if we do get a pre_hook, it might get ignored after caching.
|
||||||
assert (
|
assert getattr(cfg, "pre_hook", None) is None, (
|
||||||
getattr(cfg, "pre_hook", None) is None
|
"triton configs with pre_hooks not supported"
|
||||||
), "triton configs with pre_hooks not supported"
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_bandwidth_info_str(
|
def create_bandwidth_info_str(
|
||||||
|
|
|
||||||
|
|
@ -450,9 +450,9 @@ class CachingAutotuner(KernelInterface):
|
||||||
self.launchers = []
|
self.launchers = []
|
||||||
|
|
||||||
def __getstate__(self) -> dict[str, Any]:
|
def __getstate__(self) -> dict[str, Any]:
|
||||||
assert (
|
assert not self.launchers, (
|
||||||
not self.launchers
|
"pickle should not be called with after make_launchers()"
|
||||||
), "pickle should not be called with after make_launchers()"
|
)
|
||||||
return {
|
return {
|
||||||
**self.__dict__,
|
**self.__dict__,
|
||||||
"lock": None,
|
"lock": None,
|
||||||
|
|
@ -678,7 +678,9 @@ class CachingAutotuner(KernelInterface):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
arg,
|
arg,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
), (
|
||||||
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
||||||
|
)
|
||||||
arg.zero_()
|
arg.zero_()
|
||||||
|
|
||||||
for name, arg in kwargs.items():
|
for name, arg in kwargs.items():
|
||||||
|
|
@ -686,7 +688,9 @@ class CachingAutotuner(KernelInterface):
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
arg,
|
arg,
|
||||||
torch.Tensor,
|
torch.Tensor,
|
||||||
), "self.reset_to_zero_arg_names should only contain valid argument names"
|
), (
|
||||||
|
"self.reset_to_zero_arg_names should only contain valid argument names"
|
||||||
|
)
|
||||||
arg.zero_()
|
arg.zero_()
|
||||||
|
|
||||||
def maybe_clone_args(
|
def maybe_clone_args(
|
||||||
|
|
@@ -866,7 +870,9 @@ class CachingAutotuner(KernelInterface):
         assert not (
             self.heuristic_type == HeuristicType.PERSISTENT_REDUCTION
             and "R0_BLOCK" in launcher.config.kwargs
-        ), "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
+        ), (
+            "Coordinate descent tuner relies on the assumption that persistent reduction's triton config does not have R0_BLOCK"
+        )
         start_time = time.time_ns()
         best_config = self.coordesc_tuner.autotune(
             benchmark_one_config, launcher.config, None
@@ -882,9 +888,7 @@ class CachingAutotuner(KernelInterface):
         )
         return config2launcher.get(best_config)

-    def run(
-        self, *args, grid, stream, benchmark_run=False, **kwargs
-    ):  # type:ignore[override]
+    def run(self, *args, grid, stream, benchmark_run=False, **kwargs):  # type:ignore[override]
         if self.triton_interpret:
             return self.fn[grid](
                 *args,
@@ -1192,12 +1196,12 @@ class TritonCompileResult:

         exec(
             f"""
-            def launcher({', '.join(def_args)}, grid, stream):
+            def launcher({", ".join(def_args)}, grid, stream):
                 if callable(grid):
                     grid_0, grid_1, grid_2 = grid(grid_meta)
                 else:
                     grid_0, grid_1, grid_2 = grid
-                runner({', '.join(runner_args)})
+                runner({", ".join(runner_args)})
                 return bin
             """.lstrip(),
             scope,
@@ -1503,9 +1507,9 @@ def check_max_block(cfg: dict[str, int]):
         if block_suffix in var:
             prefix = var.removesuffix(block_suffix)
             max_block = TRITON_MAX_BLOCK[prefix]
-            assert (
-                val <= max_block
-            ), f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
+            assert val <= max_block, (
+                f"'{var}' too large. Maximum: {max_block}. Actual: {val}."
+            )


 def _num_warps(num_warps, max_num_warps=8, min_num_warps=2, register_intensive=False):
@@ -1657,20 +1661,20 @@ def _get_nd_reduction_numels(r: int, size_hints: dict[str, int]) -> dict[str, int]:
         prefix = f"r{idx}_"
         max_size = min(size_hints[prefix], TRITON_MAX_BLOCK[prefix.upper()])
         dim = min(max_size, remaining)
-        assert (
-            remaining % dim == 0
-        ), f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
+        assert remaining % dim == 0, (
+            f"Expected dimension '{dim}' to divide remaining size '{remaining}'"
+        )
         rnumels[prefix] = dim
         remaining //= dim

     # Sanity check the results.
     final_numel = conditional_product(*rnumels.values())
-    assert (
-        r == final_numel
-    ), f"Expected ND reduction size ({rnumels}) to have {r} elements."
-    assert all(
-        rnumels[prefix] <= size_hints[prefix] for prefix in rnumels
-    ), f"rnumels exceed size_hints. {rnumels} > {size_hints}"
+    assert r == final_numel, (
+        f"Expected ND reduction size ({rnumels}) to have {r} elements."
+    )
+    assert all(rnumels[prefix] <= size_hints[prefix] for prefix in rnumels), (
+        f"rnumels exceed size_hints. {rnumels} > {size_hints}"
+    )

     return rnumels
@@ -1967,9 +1971,9 @@ def cooperative_reduction(
         size_hints["x"] = 1

     # Cooperative reductions currently only support a single reduction dimension.
-    assert (
-        len(size_hints) == 2
-    ), "Cooperative reductions don't support tiling reduction dims"
+    assert len(size_hints) == 2, (
+        "Cooperative reductions don't support tiling reduction dims"
+    )
     xnumel, rnumel = size_hints["x"], size_hints["r0_"]

     # TODO(jansel): we should base target on the SM count of the local GPU
@@ -2274,9 +2278,9 @@ def grid_combo_kernels(
             assert min_blocks_d is not None
             min_blocks = min_blocks_d
         else:
-            assert (
-                min_blocks_d is None or min_blocks == min_blocks_d
-            ), f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
+            assert min_blocks_d is None or min_blocks == min_blocks_d, (
+                f"inconsistent min_blocks {min_blocks} vs x grid {numels[-1]}"
+            )
     else:
         # sequential dispatch
         seq_numels = list(numels)
@@ -200,9 +200,9 @@ class BaseSchedulerNode:

     def __init__(self, scheduler: Scheduler) -> None:
         self.scheduler: Scheduler = scheduler
-        self.debug_device_str: Callable[
-            [BaseSchedulerNode], list[str]
-        ] = lambda *args, **kwargs: []
+        self.debug_device_str: Callable[[BaseSchedulerNode], list[str]] = (
+            lambda *args, **kwargs: []
+        )

     def _init_from_node(self, node: ir.Operation) -> None:
         self.node: Optional[ir.Operation] = node
@@ -232,7 +232,7 @@ class BaseSchedulerNode:
         buf = IndentedBuffer()
         buf.splice(
             f"""\
-{name}: {type(self).__name__}({type(getattr(self, 'node', None)).__name__})
+{name}: {type(self).__name__}({type(getattr(self, "node", None)).__name__})
 {name}.writes = {pformat(self.read_writes.writes)}
 {name}.unmet_dependencies = {pformat(self.unmet_dependencies)}
 {name}.met_dependencies = {pformat(self.read_writes.reads - self.unmet_dependencies)}
@@ -525,9 +525,9 @@ class BaseSchedulerNode:
                 V.kernel.mutations.add(input_buf.get_name())
                 V.kernel.mutations.add(buf.get_name())

-                V.kernel.inplace_update_buffers[
-                    buf.get_name()
-                ] = input_buf.get_name()
+                V.kernel.inplace_update_buffers[buf.get_name()] = (
+                    input_buf.get_name()
+                )
                 break

     def codegen_originating_info(
@@ -693,7 +693,7 @@ class BaseSchedulerNode:
                 continue

         def get_buf_bytes(
-            buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]]
+            buf: Optional[Union[ir.Buffer, ir.TensorBox, ir.TorchBindObject]],
         ) -> int:
             if not buf:
                 return 0
@@ -794,12 +794,11 @@ class BaseSchedulerNode:
            # runtime for that today
            return 0

-        with FakeTensorMode() as fake_mode, FlopCounterMode(
-            display=False
-        ) as flop_counter_mode, V.set_current_node(
-            self.node.fx_node
-        ), V.set_fake_mode(
-            fake_mode
+        with (
+            FakeTensorMode() as fake_mode,
+            FlopCounterMode(display=False) as flop_counter_mode,
+            V.set_current_node(self.node.fx_node),
+            V.set_fake_mode(fake_mode),
         ):
             from .ir import ir_node_to_tensor

@@ -1123,15 +1122,15 @@ class SchedulerNode(BaseSchedulerNode):
         return self._sizes

     def is_reduction(self) -> bool:
-        assert isinstance(
-            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
-        ), f"{type(self.node)=}"
+        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
+            f"{type(self.node)=}"
+        )
         return bool(self.node.get_reduction_type())

     def is_split_scan(self) -> bool:
-        assert isinstance(
-            self.node, (ir.ComputedBuffer, ir.TemplateBuffer)
-        ), f"{type(self.node)=}"
+        assert isinstance(self.node, (ir.ComputedBuffer, ir.TemplateBuffer)), (
+            f"{type(self.node)=}"
+        )
         return isinstance(self.node, ir.ComputedBuffer) and isinstance(
             self.node.data, ir.SplitScan
         )
@@ -1163,9 +1162,10 @@ class SchedulerNode(BaseSchedulerNode):
     def codegen(self, index_vars: Sequence[Sequence[sympy.Expr]]) -> None:
         var_ranges = self.ranges_from_index_vars(index_vars)
         try:
-            with V.set_ops_handler(
-                SimplifyIndexing(V.get_ops_handler(), var_ranges)
-            ), V.kernel.set_current_node(self):
+            with (
+                V.set_ops_handler(SimplifyIndexing(V.get_ops_handler(), var_ranges)),
+                V.kernel.set_current_node(self),
+            ):
                 self._body(*index_vars)
         except Exception:
             log.fatal("Error in codegen for %s", self.node)
@@ -1231,7 +1231,7 @@ class SchedulerNode(BaseSchedulerNode):


 def refresh_group_node_dependencies(
-    group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode]
+    group_snode: Union[FusedSchedulerNode, GroupedSchedulerNode],
 ) -> None:
     snodes = group_snode.snodes
     group_snode.set_read_writes(
@@ -1754,7 +1754,7 @@ class ForeachKernelSchedulerNode(FusedSchedulerNode):

     @staticmethod
     def set_group_algorithm_for_combo_kernels(
-        custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]]
+        custom_group_algorithm: Callable[[Scheduler], list[list[BaseSchedulerNode]]],
     ) -> None:
         ForeachKernelSchedulerNode.group_algorithm_for_combo_kernels = (
             custom_group_algorithm
@@ -1975,9 +1975,9 @@ class Scheduler:
         for node in self.nodes:
             node.prune_deps()

-        self.name_to_donated_buffer: dict[
-            str, SchedulerDonatedBuffer
-        ] = self.get_donated_buffers()
+        self.name_to_donated_buffer: dict[str, SchedulerDonatedBuffer] = (
+            self.get_donated_buffers()
+        )
         self.name_to_node: dict[str, BaseSchedulerNode] = {
             n.get_name(): n for n in self.nodes
         }
@@ -2099,9 +2099,9 @@ class Scheduler:
             node.log_details()

     def create_scheduler_node(self, node: ir.Operation) -> BaseSchedulerNode:
-        assert (
-            node.get_origins() is not None
-        ), "All nodes passed to scheduling must have an origin"
+        assert node.get_origins() is not None, (
+            "All nodes passed to scheduling must have an origin"
+        )
         if node.is_no_op():
             return NopKernelSchedulerNode(self, node)
         elif isinstance(node, (ir.ComputedBuffer, ir.TemplateBuffer)):
@@ -2260,9 +2260,9 @@ class Scheduler:
             )
             # if a kernel takes unbacked symints, register dependencies
             for s in unbacked_symbol_uses:
-                assert (
-                    s in unbacked_symbol_to_origin_node
-                ), f"{s} not in {unbacked_symbol_to_origin_node}"
+                assert s in unbacked_symbol_to_origin_node, (
+                    f"{s} not in {unbacked_symbol_to_origin_node}"
+                )
                 if (r := unbacked_symbol_to_origin_node[s]) is not None:
                     for buf in self.name_to_node[r].get_outputs():
                         node.add_fake_dep(StarDep(buf.get_name()))
@@ -2310,9 +2310,9 @@ class Scheduler:
                 for alt_name in buf.get_mutations():
                     self.mutation_renames[rename(alt_name)] = buf.get_name()
                     self.mutation_renames[alt_name] = buf.get_name()
-                    self.mutation_real_name[
-                        buf.get_name()
-                    ] = self.mutation_real_name.get(alt_name, alt_name)
+                    self.mutation_real_name[buf.get_name()] = (
+                        self.mutation_real_name.get(alt_name, alt_name)
+                    )

         # make sure outputs aren't dead-code-eliminated
         for buf_name in V.graph.get_output_names():
@@ -2322,9 +2322,9 @@ class Scheduler:
         # make sure unbacked symints aren't dead-code-eliminated
         for out in V.graph.graph_outputs:
             for s in out.get_unbacked_symbol_uses():
-                assert (
-                    s in unbacked_symbol_to_origin_node
-                ), f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
+                assert s in unbacked_symbol_to_origin_node, (
+                    f"{s} not in {unbacked_symbol_to_origin_node.keys()}"
+                )
                 if r := unbacked_symbol_to_origin_node[s]:
                     for buf_name in self.name_to_node[r].get_buffer_names():
                         log.debug(
@@ -3304,15 +3304,15 @@ class Scheduler:
             rhs_dep = node2_name2dep[buf_name]

             if not isinstance(lhs_dep, MemoryDep) or not isinstance(rhs_dep, MemoryDep):
-                reasons[
-                    buf_name
-                ] = f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
+                reasons[buf_name] = (
+                    f"not MemoryDep: {type(lhs_dep)} v.s. {type(rhs_dep)}"
+                )
                 continue

             if lhs_dep.get_numel() != rhs_dep.get_numel():
-                reasons[
-                    buf_name
-                ] = f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
+                reasons[buf_name] = (
+                    f"different numel: {lhs_dep.get_numel()} v.s. {rhs_dep.get_numel()}"
+                )
                 continue

             # same numel but different MemoryDep.size. Should be broadcasting
@@ -3340,9 +3340,9 @@ class Scheduler:
             layout_str = ""
             if not isinstance(buf, ir.TorchBindObject):
                 layout_str = f"Layout: {buf.layout}"
-            reasons[
-                buf_name
-            ] = f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
+            reasons[buf_name] = (
+                f"Unknown reason: {lhs_dep} v.s. {rhs_dep}. {layout_str}"
+            )

         return str(reasons)

@@ -3903,9 +3903,9 @@ class Scheduler:
             self.free_buffers()

     def create_backend(self, device: torch.device) -> BaseScheduling:
-        assert (
-            not is_gpu(device.type) or device.index is not None
-        ), f"{device} should have been normalized in lowering"
+        assert not is_gpu(device.type) or device.index is not None, (
+            f"{device} should have been normalized in lowering"
+        )
         V.graph.add_device_info(device)

         device_scheduling = get_scheduling_for_device(device.type)
@@ -4135,9 +4135,9 @@ class Scheduler:
         partitions, signatures = self.graph_partition()

         for partition, signature in zip(partitions, signatures):
-            assert (
-                len(partition) >= 1
-            ), f"Each partition must have at least one node but found {len(partition)}"
+            assert len(partition) >= 1, (
+                f"Each partition must have at least one node but found {len(partition)}"
+            )

             if signature.skip_cudagraph:
                 self._codegen(partition)
@@ -168,9 +168,9 @@ class PartialRender:
                 )
             else:
                 return
-        assert (
-            self.replacement_hooks[hook_key] is not None
-        ), "hook_key can only be called once"
+        assert self.replacement_hooks[hook_key] is not None, (
+            "hook_key can only be called once"
+        )
         self.code = self.code.replace(hook_key, self.replacement_hooks[hook_key]())
         self.replacement_hooks[hook_key] = None

@@ -257,9 +257,9 @@ class ModificationWrapper(V.WrapperHandler):  # type: ignore[name-defined]
         This is used by flex_attention's backwards grad for captured buffers, see
         zeros_and_scatter lowering
         """
-        assert (
-            self.mask is not None
-        ), "Mask is required for inner stores in modifications"
+        assert self.mask is not None, (
+            "Mask is required for inner stores in modifications"
+        )
         assert mode == "atomic_add", "Only atomic_add is supported for inner stores"

         buf_name = self._add_kernel_input(name)
@@ -573,12 +573,12 @@ class TritonTemplateKernel(TritonKernel):
     def _get_subgraph(self, subgraph_number: int):
         assert isinstance(subgraph_number, int)
         assert isinstance(self.subgraphs, list)
-        assert subgraph_number < len(
-            self.subgraphs
-        ), f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
-        assert (
-            self.body.getvalue() == ""
-        ), "Body should be clear before adding a modification"
+        assert subgraph_number < len(self.subgraphs), (
+            f"Invalid subgraph number provided to create_modification, {subgraph_number} must be < {len(self.subgraphs)}"
+        )
+        assert self.body.getvalue() == "", (
+            "Body should be clear before adding a modification"
+        )
         return self.subgraphs[subgraph_number]

     def _handle_scatter_graph(self, scatter_graph):
@@ -587,9 +587,9 @@ class TritonTemplateKernel(TritonKernel):
         Args:
             scatter_graph: The scatter graph to process
         """
-        assert isinstance(
-            scatter_graph, ir.ComputedBuffer
-        ), f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
+        assert isinstance(scatter_graph, ir.ComputedBuffer), (
+            f"scatter_graph must be an instance of ComputeBuffer but got {type(scatter_graph)}"
+        )

         def contiguous_strides(x):
             # We always create a fresh contiguous grad for scattering into
@@ -597,7 +597,9 @@ class TritonTemplateKernel(TritonKernel):
                 x_i * stride for x_i, stride in zip(x, scatter_graph.get_stride())
             )

-        return scatter_graph.data.store_output(scatter_graph.name, contiguous_strides, [])  # type: ignore[attr-defined]
+        return scatter_graph.data.store_output(  # type: ignore[attr-defined]
+            scatter_graph.name, contiguous_strides, []
+        )

     def modification(
         self,
@@ -626,9 +628,9 @@ class TritonTemplateKernel(TritonKernel):
             self, subgraph_number, fixed_inputs, mask
         )
         with V.set_ops_handler(modification_handler):
-            assert isinstance(
-                subgraph, (ir.ComputedBuffer, list)
-            ), f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
+            assert isinstance(subgraph, (ir.ComputedBuffer, list)), (
+                f"Expected the subgraph to be a ComputedBuffer or a List[ComputedBuffer], got {type(subgraph)}"
+            )
             # Handle scatter stores
             if isinstance(subgraph, list):
                 for scatter_graph in subgraph:
@@ -1123,15 +1125,17 @@ class TritonTemplate(KernelTemplate):
             "subgraphs": subgraphs,
         }

-        with patch.object(
-            V.graph, "get_dtype", self._fake_get_dtype(fake_out)
-        ), V.graph.set_current_device(layout.device), TritonTemplateKernel(
-            kernel_name=kernel_name,
-            output_node=fake_out,
-            workspace_arg=workspace_arg,
-            use_jit=False,
-            **kernel_options,
-        ) as kernel:
+        with (
+            patch.object(V.graph, "get_dtype", self._fake_get_dtype(fake_out)),
+            V.graph.set_current_device(layout.device),
+            TritonTemplateKernel(
+                kernel_name=kernel_name,
+                output_node=fake_out,
+                workspace_arg=workspace_arg,
+                use_jit=False,
+                **kernel_options,
+            ) as kernel,
+        ):
             try:
                 template = kernel.render(self.template, kwargs)
                 with kernel.set_subgraph_body("<STORE_OUTPUT>"):
@@ -1442,9 +1446,9 @@ class ExternKernelCaller(ChoiceCaller):

     def output_node(self):
         if self.choice.use_fallback_kernel:
-            assert (
-                self.choice.op_overload is not None
-            ), "Please provide an op_overload to use ir.FallbackKernel"
+            assert self.choice.op_overload is not None, (
+                "Please provide an op_overload to use ir.FallbackKernel"
+            )
             inner = ir.FallbackKernel.create(
                 self.choice.op_overload, *self.input_nodes, **self.kwargs
             )
@@ -1979,7 +1983,7 @@ class AlgorithmSelectorCache(PersistentCache):
             input_gen_fns = {}

         def get_inputs(
-            choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
+            choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
         ) -> AutotuneArgs:
             # de-duplicate args
             unique_example_inputs = {
@@ -2099,7 +2103,7 @@ class AlgorithmSelectorCache(PersistentCache):
            return timings

        def benchmark_in_sub_process(
-            choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]]
+            choices: Union[list[ExternKernelCaller], list[TritonTemplateCaller]],
        ):
            from . import autotune_process

@@ -2139,7 +2143,8 @@ class AlgorithmSelectorCache(PersistentCache):
                 map(
                     str,
                     V.graph.sizevars.size_hints(
-                        n.get_size(), fallback=config.unbacked_symint_fallback  # type: ignore[arg-type]
+                        n.get_size(),
+                        fallback=config.unbacked_symint_fallback,  # type: ignore[arg-type]
                     ),
                 )
             )
@@ -2313,15 +2318,15 @@ def autotune_select_algorithm(*args, **kwargs):
         _ALGORITHM_SELECTOR_CACHE = AlgorithmSelectorCache()

     if "return_multi_template" not in kwargs:
-        kwargs[
-            "return_multi_template"
-        ] = torch._inductor.config.benchmark_epilogue_fusion
+        kwargs["return_multi_template"] = (
+            torch._inductor.config.benchmark_epilogue_fusion
+        )

     return _ALGORITHM_SELECTOR_CACHE(*args, **kwargs)


 def add_feedback_saver(
-    fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None]
+    fn: Callable[[dict[ChoiceCaller, float], str, list[Any], list[ChoiceCaller]], None],
 ):
     global _ALGORITHM_SELECTOR_CACHE
     if _ALGORITHM_SELECTOR_CACHE is None:
@@ -905,9 +905,9 @@ class SimplifyIndexing(V.WrapperHandler):  # type: ignore[name-defined]
     def __init__(self, inner, var_ranges: VarRanges) -> None:
         super().__init__(inner)
         self.name = "SimplifyIndexing"
-        self._simplify: Callable[
-            [Expr], Expr
-        ] = lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
+        self._simplify: Callable[[Expr], Expr] = (
+            lambda index: V.graph.sizevars.simplify_with_ranges(index, var_ranges)
+        )

     def load(self, name: str, index: sympy.Expr):
         return self._inner.load(name, self._simplify(index))
@@ -283,9 +283,9 @@ def ceildiv(
     # TODO: There is a bug in a call to this function, to repro:
     # python benchmarks/dynamo/huggingface.py --inductor -d cuda --accuracy
     # --amp --only YituTechConvBert --dynamic-shapes
-    assert isinstance(numer, int) and isinstance(
-        denom, int
-    ), f"{numer}: {type(numer)}, {denom}: {type(denom)}"
+    assert isinstance(numer, int) and isinstance(denom, int), (
+        f"{numer}: {type(numer)}, {denom}: {type(denom)}"
+    )
     return runtime_ceildiv(numer, denom)

@@ -325,7 +325,7 @@ def _type_of(key: Optional[torch.dtype]) -> str:


 def convert_shape_to_inductor(
-    lst: Iterable[Union[int, torch.SymInt]]
+    lst: Iterable[Union[int, torch.SymInt]],
 ) -> list[sympy.Expr]:
     """
     Gets the shape and stride of a tensor. For non-symbolic tensors, this is
@@ -502,11 +502,9 @@ RV = TypeVar("RV", covariant=True)

 class CachedMethod(Protocol, Generic[P, RV]):
     @staticmethod
-    def clear_cache(cache: Any) -> None:
-        ...
+    def clear_cache(cache: Any) -> None: ...

-    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV:
-        ...
+    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> RV: ...


 # See https://github.com/python/mypy/issues/13222#issuecomment-1193073470 to understand the type signature
@@ -1359,9 +1357,9 @@ def _rocm_native_device_arch_name(device: str) -> str:


 @functools.lru_cache(None)
-def try_import_ck_lib() -> (
-    tuple[Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]]
-):
+def try_import_ck_lib() -> tuple[
+    Optional[str], Callable[[], list[Any]], Callable[[], list[Any]], type[Any]
+]:
     try:
         import ck4inductor  # type: ignore[import]
         from ck4inductor.universal_gemm.gen_instances import (  # type: ignore[import]
@@ -1610,9 +1608,12 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:

         return DummyModule()

-    with mock.patch.object(
-        GraphLowering, "compile_to_module", patched_compile_to_module
-    ), mock.patch.object(GraphLowering, "save_output_code", save_output_code):
+    with (
+        mock.patch.object(
+            GraphLowering, "compile_to_module", patched_compile_to_module
+        ),
+        mock.patch.object(GraphLowering, "save_output_code", save_output_code),
+    ):
         torch._dynamo.reset()
         # Note the return here is None
         _ = fn(*args, **kwargs)
@@ -1623,18 +1624,18 @@ def get_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> list[str]:
 def get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     source_codes = get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
-    assert (
-        1 <= len(source_codes) <= 2
-    ), f"expected one or two code outputs got {len(source_codes)}"
+    assert 1 <= len(source_codes) <= 2, (
+        f"expected one or two code outputs got {len(source_codes)}"
+    )
     return source_codes[0]


 def run_and_get_triton_code(fn: Callable[..., Any], *args: Any, **kwargs: Any) -> str:
     _, source_codes = run_and_get_code(fn, *args, **kwargs)
     # Can have two outputs if backwards was eagerly compiled
-    assert (
-        1 <= len(source_codes) <= 2
-    ), f"expected one or two code outputs got {len(source_codes)}"
+    assert 1 <= len(source_codes) <= 2, (
+        f"expected one or two code outputs got {len(source_codes)}"
+    )
     return source_codes[0]

@@ -1760,9 +1761,9 @@ def is_cpu_device(inputs: Sequence[torch.Tensor]) -> bool:


 def get_sympy_Expr_dtype(val: sympy.Expr) -> torch.dtype:
-    assert isinstance(
-        val, sympy.Expr
-    ), "only support sympy.Expr as input to get_sympy_Expr_dtype"
+    assert isinstance(val, sympy.Expr), (
+        "only support sympy.Expr as input to get_sympy_Expr_dtype"
+    )
     if val.is_integer:  # type: ignore[attr-defined]
         return torch.int64
     else:
@@ -1932,7 +1933,7 @@ def is_multi_outputs_template(input_buf: Optional[Union[Buffer, Operation]]) -> bool:


 def is_output_of_multi_outputs_template(
-    input_buf: Optional[Union[Buffer, Operation]]
+    input_buf: Optional[Union[Buffer, Operation]],
 ) -> bool:
     """
     Check if input buffer is a output of multi-outputs template buffer
@@ -2633,7 +2634,8 @@ def set_kernel_post_grad_provenance_tracing(
         if node not in (EnableReduction, DisableReduction):
             if node.node is not None:
                 V.debug._inductor_triton_kernel_to_post_grad_node_info[kernel_name] = [
-                    origin.name for origin in node.node.origins  # type: ignore[attr-defined]
+                    origin.name
+                    for origin in node.node.origins  # type: ignore[attr-defined]
                 ]

@@ -314,9 +314,9 @@ class _V:
     KernelFormatterHandler = KernelFormatterHandler
     WrapperHandler = WrapperHandler

-    set_ops_handler: Callable[
-        [OpsHandler[Any]], AbstractContextManager[None]
-    ] = _ops._set_handler
+    set_ops_handler: Callable[[OpsHandler[Any]], AbstractContextManager[None]] = (
+        _ops._set_handler
+    )
     get_ops_handler: Callable[[], OpsHandler[Any]] = _ops._get_handler
     set_graph_handler: Callable[[GraphLowering], Any] = _graph._set_handler
     set_real_inputs: Callable[[Any], Any] = _real_inputs._set_handler
@@ -14,8 +14,7 @@ from .runtime.runtime_utils import create_bandwidth_info_str, get_num_bytes


 class BenchmarkCallableType(Protocol):
-    def __call__(self, times: int, repeat: int) -> float:
-        ...
+    def __call__(self, times: int, repeat: int) -> float: ...


 _kernel_category_choices = [
@@ -138,9 +137,9 @@ def benchmark_all_kernels(
             )
         else:
             ms = benchmarker.benchmark_gpu(lambda: kernel_mod.call(args), rep=40)
-            assert (
-                len(triton_kernel.launchers) == 1
-            ), "Autotuner should have selected the best config"
+            assert len(triton_kernel.launchers) == 1, (
+                "Autotuner should have selected the best config"
+            )
             launcher = triton_kernel.launchers[0]
             print(
                 get_info_str(
@@ -256,9 +255,9 @@ def parse_profile_event_list(
         "triton_unknown",
         "unknown",
     ]
-    assert OrderedSet(all_events.keys()).issubset(
-        OrderedSet(category_list)
-    ), f"{list(all_events.keys())}"
+    assert OrderedSet(all_events.keys()).issubset(OrderedSet(category_list)), (
+        f"{list(all_events.keys())}"
+    )

    per_category_wall_time = {}
    total_device_ms = 0.0
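Most hunks above apply the same mechanical ruff-format conventions: assert messages move into a parenthesized continuation after the condition, multiple context managers are grouped under one parenthesized `with (...)`, and lone parameters in wrapped signatures gain a trailing comma. A minimal, illustrative Python sketch of the before/after layout (placeholder names `value` and `limit`, not taken from the commit):

    value, limit = 3, 8  # placeholder values for the sketch

    # black-era layout: condition wrapped in parens, message after the closing paren
    assert (
        value <= limit
    ), f"value too large: {value} > {limit}"

    # ruff-format layout: condition stays on the assert line, message is parenthesized
    assert value <= limit, (
        f"value too large: {value} > {limit}"
    )

    # ruff-format also groups multiple context managers in one parenthesized `with` block
    from contextlib import nullcontext

    with (
        nullcontext() as a,
        nullcontext() as b,
    ):
        pass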