Mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Delete Lark (#123689)
Now that we are using the MLIR bindings inside Triton, let's delete the Lark parser.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123689
Approved by: https://github.com/jansel
This commit is contained in:
parent 8d9af8b91c
commit a631461eef
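
The gist of the change, as a toy sketch rather than PyTorch's actual code (the real logic is in the torch/_higher_order_ops/triton_kernel_wrap.py hunks further down): the TTIR analysis used to fall back to a Lark-based parse of the textual IR whenever the TTIR object did not expose MLIR bindings, and now the MLIR path is the only one.

# Toy illustration only; argument names mirror the diff below, and the helpers
# are stand-ins for the real ttir_to_functions / parse_ttir.
def extract_functions_before(ttir_module, ttir_to_functions, parse_ttir):
    # old: prefer the MLIR bindings, fall back to Lark-based string parsing
    if hasattr(ttir_module, "walk"):
        return ttir_to_functions(ttir_module)
    return parse_ttir(str(ttir_module), {})


def extract_functions_after(ttir_module, ttir_to_functions):
    # new: Triton always exposes MLIR bindings, so walk the module directly
    return ttir_to_functions(ttir_module)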
@@ -52,11 +52,6 @@ junitparser==2.1.1
 #Pinned versions: 2.1.1
 #test that import:
 
-lark==0.12.0
-#Description: parser
-#Pinned versions: 0.12.0
-#test that import:
-
 librosa>=0.6.2 ; python_version < "3.11"
 #Description: A python package for music and audio analysis
 #Pinned versions: >=0.6.2

@@ -18,4 +18,3 @@ fsspec
 setuptools ; python_version >= "3.12"
 packaging
 optree>=0.9.1
-lark

@@ -7,8 +7,6 @@ import torch
 import torch._dynamo.testing
 
 import torch._inductor.test_case
-from torch._dynamo import config
-from torch._dynamo.testing import make_test_cls_with_patches
 
 from torch._higher_order_ops.triton_kernel_wrap import (
     generate_ttir,

@@ -1225,7 +1223,6 @@ def forward(self, x_1, output_1):
 
 def make_mutation_test(fn):
     @requires_cuda
-    @requires_lark
     @skipIfRocm
     def test_fn(self):
         from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

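For orientation, a simplified sketch (not the file's exact helper) of what the make_mutation_test factory above does: it turns a function returning (kernel, kwargs, expected) into a test method built around identify_mutated_tensors, so once the Lark fallback is gone the requires_lark gate has nothing left to guard.

def make_mutation_test_sketch(fn):
    # Simplified stand-in for the decorator factory shown in the hunk above.
    def test_fn(self):
        from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

        kernel, kwargs, expected = fn()
        self.assertListEqual(identify_mutated_tensors(kernel, kwargs), expected)

    return test_fn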
@@ -1776,7 +1773,7 @@ class MutationTests(torch._inductor.test_case.TestCase):
         )
 
 
-if HAS_CUDA and HAS_LARK:
+if HAS_CUDA:
     t = torch.randn(4)
     tt = torch.randn(4, 1)
     tests = [

@@ -1921,15 +1918,6 @@ if HAS_CUDA and HAS_LARK:
 
 common_utils.instantiate_parametrized_tests(KernelTests)
 
-no_opt_test_class = make_test_cls_with_patches(
-    KernelTests,
-    "NoOptimization",
-    "_no_optimizations",
-    (config, "optimize_user_defined_triton_kernels", False),
-)
-
-globals()[no_opt_test_class.__name__] = no_opt_test_class
-no_opt_test_class.__module__ = __name__
 
 if __name__ == "__main__":
     from torch._inductor.test_case import run_tests

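The removed block used make_test_cls_with_patches to generate a "NoOptimization" copy of KernelTests that ran with optimize_user_defined_triton_kernels turned off; with that flag deleted (next hunk), the copy has no reason to exist. A minimal, generic sketch of the same idea using unittest.mock (this is not the torch._dynamo.testing implementation):

import unittest
from unittest import mock


def make_patched_test_cls(cls, prefix, *patches):
    # Build a renamed subclass whose tests run with each (module, attr, value)
    # patch applied around every test. Note: this sketch overrides any setUp /
    # tearDown already defined on cls.
    def setUp(self):
        self._patchers = [mock.patch.object(m, a, v) for m, a, v in patches]
        for p in self._patchers:
            p.start()

    def tearDown(self):
        for p in self._patchers:
            p.stop()

    return type(prefix + cls.__name__, (cls,), {"setUp": setUp, "tearDown": tearDown})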
@@ -393,9 +393,6 @@ capture_autograd_function = True
 # enable/disable dynamo tracing for `torch.func` transforms
 capture_func_transforms = True
 
-# enable/disable user-defined triton kernel optimizations
-optimize_user_defined_triton_kernels = True
-
 # If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
 log_compilation_metrics = True
 

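The surviving flags in torch._dynamo.config can still be read directly or overridden temporarily via config.patch; optimize_user_defined_triton_kernels is simply gone. A small sketch, assuming a recent PyTorch build where both remaining flags shown above exist:

import torch._dynamo.config as dynamo_config

print(dynamo_config.capture_func_transforms)  # read a flag

with dynamo_config.patch(log_compilation_metrics=False):
    # runs with the flag temporarily overridden; restored on exit
    assert dynamo_config.log_compilation_metrics is False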
@@ -2,7 +2,6 @@ import dataclasses
 import inspect
 import logging
 import threading
-import warnings
 from collections import defaultdict
 from typing import Any, Dict, List, Optional, Union
 

@@ -363,193 +362,6 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
     return functions
 
 
-def parse_ttir(ttir, kwargs):
-    """
-    Given a Triton emitted TTIR text, this function lexes and parses the
-    code using a minimal grammar defined inside. During the lexing/parsing,
-    we drop any constant value and type information as they are not
-    necessary to us.
-    Being able to choose what we need makes this not a general purpose TTIR
-    parser which further makes parsing much simpler.
-    """
-    # TODO(oulgen):
-    # - Support closures (e.g. "tt.reduce")
-
-    try:
-        import lark  # type: ignore[import-not-found]
-        from lark import Lark, Transformer, v_args
-    except ModuleNotFoundError:
-        warnings.warn(
-            "Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
-        )
-        raise
-
-    # Ops looks like one of the following forms:
-    #
-    # %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
-    # tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
-    # %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950
-    grammar = """
-        start: (module_block | loc_line)+
-
-        loc_line: "#loc" /.+/ NEWLINE
-
-        module_block: "module" "{" func_block+ "}" LOC
-
-        func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
-
-        ?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
-
-        if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
-        for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
-        while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
-
-        condition_stmt: "scf.condition" "(" arg ")" args rest
-        label_stmt: LABEL ":" "// pred:" LABEL
-            | LABEL "(" /.+/ NEWLINE
-        cf_stmt: "cf" "." NAME /.+/ NEWLINE
-
-        op: OP_NAME LOC
-            | [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
-
-        ?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
-        divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"
-
-        args: | "(" ")" | "("? arg ("," arg)* ")"?
-
-        ?arg: INTERMEDIATE
-            | INTERMEDIATE_CONSTANT
-            | CONSTANT
-            | PARAM
-            | "[" args "]"
-            | arg_with_index
-
-        ?arg_with_index: arg "#" DIGIT+
-
-        ?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]
-
-        PARAM.5: "%arg" DIGIT+
-        INTERMEDIATE.4: "%" DIGIT+
-        INTERMEDIATE_CONSTANT.3: "%" NAME
-        CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
-        LABEL: "^bb" DIGIT+
-
-        NAME: (LETTER | DIGIT | "_")+
-        NON_CF_NAME: /(?!(cf))/ NAME
-        FN_NAME: "@" (NAME | ESCAPED_STRING)
-        OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
-
-        LOC.5: "loc(#loc" DIGIT* ")"
-
-        %import common.LETTER
-        %import common.DIGIT
-        %import common.WS
-        %import common.NEWLINE
-        %import common.ESCAPED_STRING
-        %import common.FLOAT
-        %ignore WS
-    """
-
-    next_fake_intermediate = 0
-
-    def convert(token):
-        if isinstance(token, lark.tree.Tree):
-            if token.data == "args":
-                res = []
-                for a in token.children:
-                    c = convert(a)
-                    if isinstance(c, list):
-                        res.extend(c)
-                    else:
-                        res.append(c)
-                return res
-            elif token.data in {"assign_lhs", "arg_with_index"}:
-                # Drop length/index qualifier
-                return convert(token.children[0])
-            else:
-                raise AssertionError(f"Tree node with {token.data}")
-
-        if token is None or (
-            isinstance(token, lark.lexer.Token)
-            and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
-        ):
-            nonlocal next_fake_intermediate
-            next_fake_intermediate -= 1
-            return Intermediate(next_fake_intermediate)
-
-        assert isinstance(token, lark.lexer.Token)
-
-        if token.type == "INTERMEDIATE":
-            return Intermediate(int(token.value[len("%") :]))
-        if token.type == "PARAM":
-            return Param(int(token.value[len("%arg") :]))
-
-        raise AssertionError(f"{type(token.type)} => {token.value} invalid")
-
-    # In alternative representation, function names are quoted.
-    # It should be possible to move this into the grammar alltogether.
-    def convert_name(token):
-        if token is None:
-            return None
-        s = token.value
-        if len(s) > 2 and s[0] == '"' and s[-1] == '"':
-            return s[1:-1]
-        return s
-
-    functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
-
-    def extend_dict_list(d1, d2):
-        for key, values in d2.items():
-            d1[key].extend(values)
-
-    @v_args(inline=True)
-    class TransformOps(Transformer):
-        def process_op(self, ret, op_name, fn_name, args, *rest):
-            return Op(
-                convert_name(op_name),
-                convert_name(fn_name),
-                convert(args),
-                convert(ret),
-            )
-
-        def process_func(self, name, _args, *stmts):
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            functions[name.value] = ops
-
-        def _process_scf(self, ret, stmts):
-            ret = convert(ret)
-            ops: Dict[Intermediate, List[Op]] = defaultdict(list)
-            for e in stmts:
-                if isinstance(e, Op):
-                    if e.name == "scf.yield":
-                        ops[ret].append(Op(e.name, None, e.args, ret))
-                    else:
-                        ops[e.ret].append(e)
-                elif isinstance(e, dict):
-                    extend_dict_list(ops, e)
-            return ops
-
-        def process_if(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_for(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-        def process_while(self, ret, _args, _rest, *stmts):
-            return self._process_scf(ret, stmts)
-
-    parser = Lark(
-        grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
-    )
-    parser.parse(ttir)
-    return functions
-
-
 class MemoizeWithCycleCheck:
     def __init__(self, fn):
         self.fn = fn

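For readers who never met the deleted code path: Lark compiles an EBNF-like grammar into a parser, and a Transformer decorated with v_args(inline=True) rewrites matched rules into Python values while parsing, which is exactly the shape of parse_ttir above. A self-contained toy example of the same pattern, unrelated to TTIR (requires pip install lark):

from lark import Lark, Transformer, v_args

grammar = """
    start: assign+
    assign: NAME "=" INT NEWLINE -> process_assign

    %import common.CNAME -> NAME
    %import common.INT
    %import common.NEWLINE
    %import common.WS_INLINE
    %ignore WS_INLINE
"""


@v_args(inline=True)
class CollectAssigns(Transformer):
    # Called as each `assign` rule is reduced; the matched tokens arrive inline.
    def process_assign(self, name, value, _newline):
        return (str(name), int(value))


parser = Lark(grammar, parser="lalr", transformer=CollectAssigns())
tree = parser.parse("x = 1\ny = 2\n")
print(tree.children)  # [('x', 1), ('y', 2)]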
@@ -637,20 +449,10 @@ def identify_mutated_tensors(kernel, kwargs):
     ttir_module = None
     functions = None
     try:
-        from torch._dynamo import config
-
-        if not config.optimize_user_defined_triton_kernels:
-            raise ValueError("optimize_user_defined_triton_kernels is False")
-
         ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
 
-        # extract functions from TTIR
-        if hasattr(ttir_module, "walk"):
-            # use MLIR bindings exposed by Triton code
-            functions = ttir_to_functions(ttir_module)
-        else:
-            # parse string representation of Triton IR
-            functions = parse_ttir(str(ttir_module), kwargs)
+        # extract functions from TTIR using MLIR bindings exposed by Triton code
+        functions = ttir_to_functions(ttir_module)
 
         assert functions is not None
         kernel_name = next(iter(functions.keys()))

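The public entry point is unchanged: identify_mutated_tensors(kernel, kwargs) still returns the names of the tensor arguments the kernel writes to; it just always goes through the MLIR bindings now. A hypothetical usage sketch in the style of the tests above (add_kernel is illustrative, and like those tests it assumes a Triton install on a CUDA-enabled PyTorch build):

import torch
import triton
import triton.language as tl

from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors


@triton.jit
def add_kernel(in_ptr0, in_ptr1, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(in_ptr0 + offsets, mask=mask)
    y = tl.load(in_ptr1 + offsets, mask=mask)
    tl.store(out_ptr + offsets, x + y, mask=mask)  # only out_ptr is mutated


t = torch.randn(4)
kwargs = {
    "in_ptr0": t,
    "in_ptr1": t,
    "out_ptr": torch.empty_like(t),
    "n_elements": 4,
    "BLOCK_SIZE": 4,
}
print(identify_mutated_tensors(add_kernel, kwargs))  # expected: ["out_ptr"]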
@@ -5,18 +5,6 @@ import unittest
 from torch.testing._internal.inductor_utils import HAS_CUDA
 
 
-def has_lark():
-    try:
-        import lark  # noqa: F401
-
-        return True
-    except ModuleNotFoundError:
-        return False
-
-
-HAS_LARK = has_lark()
-
-requires_lark = unittest.skipUnless(HAS_LARK, "requires lark")
 requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
 
 if HAS_CUDA:
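
The deleted helper was a standard optional-import guard. If the same pattern is ever needed for some other optional dependency, a generic sketch looks like this (some_optional_dep is a placeholder name, not a real package):

import importlib.util
import unittest

# Placeholder module name; substitute a real optional dependency.
HAS_SOME_DEP = importlib.util.find_spec("some_optional_dep") is not None
requires_some_dep = unittest.skipUnless(HAS_SOME_DEP, "requires some_optional_dep")


@requires_some_dep
class OptionalDepTests(unittest.TestCase):
    def test_import(self):
        import some_optional_dep  # noqa: F401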