Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/48249

Introduced autograd-related data models at tools.codegen.api.autograd. Migrated load_derivatives.py to produce the new data models from derivatives.yaml; it has a clean mypy-strict result. Changed both gen_autograd_functions.py and gen_variable_type.py to consume the new data model. Added type annotations to gen_autograd_functions.py; it has a clean mypy-strict result except for the .gen_autograd import, so it has not been added to the strict config in this PR.

To limit the scope of the PR, gen_variable_type.py is not refactored, and the main structure of load_derivatives.py / gen_autograd_functions.py is kept. We only make the changes necessary to make it work.

Confirmed byte-for-byte compatible with the old codegen:

```
Run it before and after this PR:
.jenkins/pytorch/codegen-test.sh <baseline_output_dir>
.jenkins/pytorch/codegen-test.sh <test_output_dir>

Then run diff to compare the generated files:
diff -Naur <baseline_output_dir> <test_output_dir>
```

Test Plan: Imported from OSS

Reviewed By: ezyang

Differential Revision: D25086561

Pulled By: ljk53

fbshipit-source-id: 1f43ab0931d9814c24683b9a48ca497c5fc3d729
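The before/after `diff -Naur` check can also be scripted. The following is a minimal sketch under stated assumptions and is not part of the PR. The helper name `diff_trees` is hypothetical. Unlike `diff -Naur`, it reports only which files differ and prints no hunks. It forces a content comparison so that identical file sizes or timestamps cannot hide a change.

```python
import filecmp
import os
from typing import List, Set


def _relative_files(root: str) -> Set[str]:
    """Collect every file under `root` as a path relative to `root`."""
    found: Set[str] = set()
    for dirpath, _dirnames, filenames in os.walk(root):
        for name in filenames:
            found.add(os.path.relpath(os.path.join(dirpath, name), root))
    return found


def diff_trees(baseline_dir: str, test_dir: str) -> List[str]:
    """Hypothetical helper: list relative paths that differ between two output trees.

    A path is reported if it exists on only one side, or if the two copies
    are not byte-for-byte identical. An empty list means the trees match.
    """
    baseline = _relative_files(baseline_dir)
    test = _relative_files(test_dir)
    mismatches = list(baseline ^ test)  # files present on only one side
    for rel in baseline & test:
        # shallow=False makes filecmp read the contents instead of trusting os.stat()
        if not filecmp.cmp(os.path.join(baseline_dir, rel),
                           os.path.join(test_dir, rel), shallow=False):
            mismatches.append(rel)
    return sorted(mismatches)
```

For example, `diff_trees("<baseline_output_dir>", "<test_output_dir>") == []` would correspond to the empty `diff` output the commit message relies on.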
15 lines, 525 B, Python
import re
from typing import Tuple, List

# Matches "foo" in "foo, bar" but not "foobar". Used to search for the
# occurrence of a parameter in the derivative formula
IDENT_REGEX = r'(^|\W){}($|\W)'

# TODO: Use a real parser here; this will get bamboozled
def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name, _, params = m.groups()
    return name, params.split(', ')
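To make the two helpers concrete, here is a small self-contained sketch. It repeats them verbatim and runs them on assumed example schema strings. `uses_param` is a hypothetical wrapper that shows how the `{}` placeholder in `IDENT_REGEX` can be filled with `str.format()`. The `re.escape` call is an addition for this sketch and is not in the file above. The last two `split_name_params` calls show the kind of input the TODO warns about.

```python
import re
from typing import List, Tuple

# Repeated from the listing above so this snippet runs on its own.
IDENT_REGEX = r'(^|\W){}($|\W)'


def split_name_params(schema: str) -> Tuple[str, List[str]]:
    m = re.match(r'(\w+)(\.\w+)?\((.*)\)', schema)
    if m is None:
        raise RuntimeError(f'Unsupported function schema: {schema}')
    name, _, params = m.groups()
    return name, params.split(', ')


def uses_param(formula: str, param: str) -> bool:
    """Hypothetical helper: does `param` occur as a whole identifier in `formula`?"""
    # str.format() substitutes the (escaped) name for the {} in IDENT_REGEX.
    return re.search(IDENT_REGEX.format(re.escape(param)), formula) is not None


if __name__ == '__main__':
    # The overload name ".Tensor" is matched by the second group and discarded.
    print(split_name_params('add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1)'))
    # -> ('add', ['Tensor self', 'Tensor other', '*', 'Scalar alpha=1'])

    # The plain split on ', ' is what the TODO warns about:
    print(split_name_params('foo()'))                    # -> ('foo', [''])
    print(split_name_params('f(int[2] stride=[1, 1])'))  # -> ('f', ['int[2] stride=[1', '1]'])

    print(uses_param('grad * other', 'other'))        # -> True
    print(uses_param('grad * other_value', 'other'))  # -> False
```

The `\W` boundaries in `IDENT_REGEX` are what make `other_value` fail to match `other`: the underscore is a word character, so it cannot serve as the trailing boundary.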
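The TODO asks for a real parser. A minimal step in that direction, which the file above does not do, is to split the parameter list only on commas at bracket depth zero. Commas inside `[...]`, `(...)` or `{...}` then stay part of their parameter. The function name `split_top_level` is hypothetical.

```python
from typing import List


def split_top_level(params: str) -> List[str]:
    """Hypothetical helper: split on commas not nested inside (), [] or {}."""
    if not params.strip():
        return []  # an empty list such as "foo()" has no parameters
    parts: List[str] = []
    current: List[str] = []
    depth = 0
    for ch in params:
        if ch in '([{':
            depth += 1
        elif ch in ')]}':
            depth -= 1
        if ch == ',' and depth == 0:
            # A top-level comma ends the current parameter.
            parts.append(''.join(current).strip())
            current = []
        else:
            current.append(ch)
    parts.append(''.join(current).strip())
    return parts
```

With this approach, `'int[2] stride=[1, 1], bool flag=False'` splits into two parameters rather than three. The sketch still does not understand quoted string defaults that contain commas or brackets; handling those would need an actual tokenizer.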