Delete Lark (#123689)

Now that we are using MLIR bindings inside Triton, let's delete the Lark parser.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123689
Approved by: https://github.com/jansel
Authored by Oguz Ulgen on 2024-04-09 17:33:33 -07:00; committed by PyTorch MergeBot
parent 8d9af8b91c
commit a631461eef
6 changed files with 3 additions and 234 deletions
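
For orientation before the diff: a condensed sketch of the extraction step in identify_mutated_tensors that this change simplifies. It is not a verbatim excerpt; extract_functions is an illustrative wrapper, while generate_ttir and ttir_to_functions are the real helpers from torch._higher_order_ops.triton_kernel_wrap (see its hunk below).

from torch._higher_order_ops.triton_kernel_wrap import (
    generate_ttir,
    ttir_to_functions,
)


def extract_functions(kernel, kwargs):
    # Lower the kernel to TTIR; generate_ttir returns the MLIR module plus
    # the ordered tensor argument names (unused in this sketch).
    ttir_module, _ordered_tensor_names = generate_ttir(kernel, kwargs)
    # Previously there was a fallback for Triton builds without MLIR
    # python bindings:
    #     if hasattr(ttir_module, "walk"):
    #         functions = ttir_to_functions(ttir_module)        # MLIR walk
    #     else:
    #         functions = parse_ttir(str(ttir_module), kwargs)  # Lark text parser
    # After this PR the MLIR bindings path is the only path:
    return ttir_to_functions(ttir_module)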


@@ -52,11 +52,6 @@ junitparser==2.1.1
#Pinned versions: 2.1.1
#test that import:
lark==0.12.0
#Description: parser
#Pinned versions: 0.12.0
#test that import:
librosa>=0.6.2 ; python_version < "3.11"
#Description: A python package for music and audio analysis
#Pinned versions: >=0.6.2


@@ -18,4 +18,3 @@ fsspec
setuptools ; python_version >= "3.12"
packaging
optree>=0.9.1
lark


@@ -7,8 +7,6 @@ import torch
import torch._dynamo.testing
import torch._inductor.test_case
from torch._dynamo import config
from torch._dynamo.testing import make_test_cls_with_patches
from torch._higher_order_ops.triton_kernel_wrap import (
generate_ttir,
@@ -1225,7 +1223,6 @@ def forward(self, x_1, output_1):
def make_mutation_test(fn):
@requires_cuda
@requires_lark
@skipIfRocm
def test_fn(self):
from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
@@ -1776,7 +1773,7 @@ class MutationTests(torch._inductor.test_case.TestCase):
)
if HAS_CUDA and HAS_LARK:
if HAS_CUDA:
t = torch.randn(4)
tt = torch.randn(4, 1)
tests = [
@@ -1921,15 +1918,6 @@ if HAS_CUDA and HAS_LARK:
common_utils.instantiate_parametrized_tests(KernelTests)
no_opt_test_class = make_test_cls_with_patches(
KernelTests,
"NoOptimization",
"_no_optimizations",
(config, "optimize_user_defined_triton_kernels", False),
)
globals()[no_opt_test_class.__name__] = no_opt_test_class
no_opt_test_class.__module__ = __name__
if __name__ == "__main__":
from torch._inductor.test_case import run_tests
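
An aside on the scaffolding removed just above: make_test_cls_with_patches re-ran every KernelTests method with a Dynamo config flag flipped to False. A rough per-test equivalent, sketched here with an unrelated flag since optimize_user_defined_triton_kernels itself is deleted by this PR, is torch._dynamo.config.patch:

import torch._dynamo


# suppress_errors is only a stand-in flag for illustration; the deleted
# optimize_user_defined_triton_kernels no longer exists after this PR.
@torch._dynamo.config.patch(suppress_errors=True)
def run_with_patched_config():
    # Code running here observes the patched config value.
    assert torch._dynamo.config.suppress_errors is True


run_with_patched_config()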


@@ -393,9 +393,6 @@ capture_autograd_function = True
# enable/disable dynamo tracing for `torch.func` transforms
capture_func_transforms = True
# enable/disable user-defined triton kernel optimizations
optimize_user_defined_triton_kernels = True
# If to log Dynamo compilation metrics into log files (for OSS) and Scuba tables (for fbcode).
log_compilation_metrics = True


@@ -2,7 +2,6 @@ import dataclasses
import inspect
import logging
import threading
import warnings
from collections import defaultdict
from typing import Any, Dict, List, Optional, Union
@@ -363,193 +362,6 @@ def ttir_to_functions(ttir_module) -> Dict[str, Dict[Intermediate, List[Op]]]:
return functions
def parse_ttir(ttir, kwargs):
"""
Given Triton-emitted TTIR text, this function lexes and parses the
code using a minimal grammar defined inside this function. During
lexing/parsing, we drop any constant value and type information, as
they are not needed.
Being able to pick out only what we need means this is not a
general-purpose TTIR parser, which in turn keeps the parsing much simpler.
"""
# TODO(oulgen):
# - Support closures (e.g. "tt.reduce")
try:
import lark # type: ignore[import-not-found]
from lark import Lark, Transformer, v_args
except ModuleNotFoundError:
warnings.warn(
"Using slow path for user-defined Triton kernels. `pip install lark` to fix this."
)
raise
# Ops looks like one of the following forms:
#
# %14 = tt.addptr %13, %4 : tensor<4x!tt.ptr<f32, 1>>, tensor<4xi32>
# tt.store %14, %12, %5 {cache = 1 : i32, evict = 1 : i32} : tensor<4xf32>
# %15 = "tt.atomic_rmw"(%14, %12, %5) <{atomic_rmw_op = 5 : i32, scope = 1 : i32, sem = 4 : i32}> : (tensor<4x!tt.ptr<f32, 1>>, tensor<4xf32>, tensor<4xi1>) -> tensor<4xf32> # noqa: B950
grammar = """
start: (module_block | loc_line)+
loc_line: "#loc" /.+/ NEWLINE
module_block: "module" "{" func_block+ "}" LOC
func_block: "tt.func" ("public"|"private") FN_NAME "(" /.+/ NEWLINE stmt* "}" LOC -> process_func
?stmt: op | if | for | while | condition_stmt | label_stmt | cf_stmt
if: [assign_lhs "="] "scf.if" args rest stmt* "}" "else" "{" stmt* "}" LOC -> process_if
for: [assign_lhs "="] "scf.for" args rest stmt* "}" divisibility_annot? LOC -> process_for
while: [assign_lhs "="] "scf.while" args rest stmt* "}" "do" "{" stmt* "}" LOC -> process_while
condition_stmt: "scf.condition" "(" arg ")" args rest
label_stmt: LABEL ":" "// pred:" LABEL
| LABEL "(" /.+/ NEWLINE
cf_stmt: "cf" "." NAME /.+/ NEWLINE
op: OP_NAME LOC
| [assign_lhs "="] OP_NAME [FN_NAME] args rest? -> process_op
?rest: (":" | "{" | "\\"" | "->" | "<" | "=") /.+/ NEWLINE
divisibility_annot: "{" "tt.divisibility_arg1" /[^}]+/ "}"
args: | "(" ")" | "("? arg ("," arg)* ")"?
?arg: INTERMEDIATE
| INTERMEDIATE_CONSTANT
| CONSTANT
| PARAM
| "[" args "]"
| arg_with_index
?arg_with_index: arg "#" DIGIT+
?assign_lhs: (INTERMEDIATE | INTERMEDIATE_CONSTANT) [":" DIGIT+]
PARAM.5: "%arg" DIGIT+
INTERMEDIATE.4: "%" DIGIT+
INTERMEDIATE_CONSTANT.3: "%" NAME
CONSTANT: FLOAT | DIGIT+ | NAME ("<" DIGIT+ ">")?
LABEL: "^bb" DIGIT+
NAME: (LETTER | DIGIT | "_")+
NON_CF_NAME: /(?!(cf))/ NAME
FN_NAME: "@" (NAME | ESCAPED_STRING)
OP_NAME: "\\""? NON_CF_NAME ("." NAME)+ "\\""?
LOC.5: "loc(#loc" DIGIT* ")"
%import common.LETTER
%import common.DIGIT
%import common.WS
%import common.NEWLINE
%import common.ESCAPED_STRING
%import common.FLOAT
%ignore WS
"""
next_fake_intermediate = 0
def convert(token):
if isinstance(token, lark.tree.Tree):
if token.data == "args":
res = []
for a in token.children:
c = convert(a)
if isinstance(c, list):
res.extend(c)
else:
res.append(c)
return res
elif token.data in {"assign_lhs", "arg_with_index"}:
# Drop length/index qualifier
return convert(token.children[0])
else:
raise AssertionError(f"Tree node with {token.data}")
if token is None or (
isinstance(token, lark.lexer.Token)
and token.type in ("CONSTANT", "INTERMEDIATE_CONSTANT")
):
nonlocal next_fake_intermediate
next_fake_intermediate -= 1
return Intermediate(next_fake_intermediate)
assert isinstance(token, lark.lexer.Token)
if token.type == "INTERMEDIATE":
return Intermediate(int(token.value[len("%") :]))
if token.type == "PARAM":
return Param(int(token.value[len("%arg") :]))
raise AssertionError(f"{type(token.type)} => {token.value} invalid")
# In the alternative representation, function names are quoted.
# It should be possible to move this into the grammar altogether.
def convert_name(token):
if token is None:
return None
s = token.value
if len(s) > 2 and s[0] == '"' and s[-1] == '"':
return s[1:-1]
return s
functions: Dict[str, Dict[Intermediate, List[Op]]] = {}
def extend_dict_list(d1, d2):
for key, values in d2.items():
d1[key].extend(values)
@v_args(inline=True)
class TransformOps(Transformer):
def process_op(self, ret, op_name, fn_name, args, *rest):
return Op(
convert_name(op_name),
convert_name(fn_name),
convert(args),
convert(ret),
)
def process_func(self, name, _args, *stmts):
ops: Dict[Intermediate, List[Op]] = defaultdict(list)
for e in stmts:
if isinstance(e, Op):
ops[e.ret].append(e)
elif isinstance(e, dict):
extend_dict_list(ops, e)
functions[name.value] = ops
def _process_scf(self, ret, stmts):
ret = convert(ret)
ops: Dict[Intermediate, List[Op]] = defaultdict(list)
for e in stmts:
if isinstance(e, Op):
if e.name == "scf.yield":
ops[ret].append(Op(e.name, None, e.args, ret))
else:
ops[e.ret].append(e)
elif isinstance(e, dict):
extend_dict_list(ops, e)
return ops
def process_if(self, ret, _args, _rest, *stmts):
return self._process_scf(ret, stmts)
def process_for(self, ret, _args, _rest, *stmts):
return self._process_scf(ret, stmts)
def process_while(self, ret, _args, _rest, *stmts):
return self._process_scf(ret, stmts)
parser = Lark(
grammar, parser="lalr", maybe_placeholders=True, transformer=TransformOps()
)
parser.parse(ttir)
return functions
class MemoizeWithCycleCheck:
def __init__(self, fn):
self.fn = fn
@@ -637,20 +449,10 @@ def identify_mutated_tensors(kernel, kwargs):
ttir_module = None
functions = None
try:
from torch._dynamo import config
if not config.optimize_user_defined_triton_kernels:
raise ValueError("optimize_user_defined_triton_kernels is False")
ttir_module, ordered_tensor_names = generate_ttir(kernel, kwargs)
# extract functions from TTIR
if hasattr(ttir_module, "walk"):
# use MLIR bindings exposed by Triton code
functions = ttir_to_functions(ttir_module)
else:
# parse string representation of Triton IR
functions = parse_ttir(str(ttir_module), kwargs)
# extract functions from TTIR using MLIR bindings exposed by Triton code
functions = ttir_to_functions(ttir_module)
assert functions is not None
kernel_name = next(iter(functions.keys()))
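
For readers who have not used the dependency being dropped: the deleted parse_ttir followed the standard Lark pattern of a grammar string, an LALR parser, and a Transformer applied during parsing. Below is a minimal, self-contained sketch of that pattern with a toy grammar (not the TTIR grammar above):

from lark import Lark, Transformer, v_args


# Toy grammar: a sequence of "name = number" assignments.
toy_grammar = r"""
    start: assign+
    assign: NAME "=" NUMBER
    NAME: /[a-z_]+/
    NUMBER: /\d+/
    %import common.WS
    %ignore WS
"""


@v_args(inline=True)
class CollectAssigns(Transformer):
    def __init__(self):
        super().__init__()
        self.seen = {}

    def assign(self, name, number):
        # Callback fires as each assignment is reduced by the LALR parser.
        self.seen[str(name)] = int(number)


collector = CollectAssigns()
parser = Lark(toy_grammar, parser="lalr", transformer=collector)
parser.parse("x = 1\ny = 42")
print(collector.seen)  # {'x': 1, 'y': 42}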


@@ -5,18 +5,6 @@ import unittest
from torch.testing._internal.inductor_utils import HAS_CUDA
def has_lark():
try:
import lark # noqa: F401
return True
except ModuleNotFoundError:
return False
HAS_LARK = has_lark()
requires_lark = unittest.skipUnless(HAS_LARK, "requires lark")
requires_cuda = unittest.skipUnless(HAS_CUDA, "requires cuda")
if HAS_CUDA:
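
The helpers deleted above (has_lark, HAS_LARK, requires_lark) followed the usual optional-dependency gate for tests, and requires_cuda keeps using the same shape. A generic sketch of that pattern, with a placeholder module name since lark is exactly what this PR stops depending on:

import unittest


def _has_module(name):
    try:
        __import__(name)
        return True
    except ModuleNotFoundError:
        return False


# "somedep" is a placeholder; substitute any optional test dependency.
HAS_SOMEDEP = _has_module("somedep")
requires_somedep = unittest.skipUnless(HAS_SOMEDEP, "requires somedep")


class ExampleTest(unittest.TestCase):
    @requires_somedep
    def test_uses_somedep(self):
        import somedep  # noqa: F401


if __name__ == "__main__":
    unittest.main()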