mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Rewrite of ATen code generator (#42629)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/42629 How to approach reviewing this diff: - The new codegen itself lives in `tools/codegen`. Start with `gen.py`, then read `model.py` and them the `api/` folder. The comments at the top of the files describe what is going on. The CLI interface of the new codegen is similar to the old one, but (1) it is no longer necessary to explicitly specify cwrap inputs (and now we will error if you do so) and (2) the default settings for source and install dir are much better; to the extent that if you run the codegen from the root source directory as just `python -m tools.codegen.gen`, something reasonable will happen. - The old codegen is (nearly) entirely deleted; every Python file in `aten/src/ATen` was deleted except for `common_with_cwrap.py`, which now permanently finds its home in `tools/shared/cwrap_common.py` (previously cmake copied the file there), and `code_template.py`, which now lives in `tools/codegen/code_template.py`. We remove the copying logic for `common_with_cwrap.py`. - All of the inputs to the old codegen are deleted. - Build rules now have to be adjusted to not refer to files that no longer exist, and to abide by the (slightly modified) CLI. - LegacyTHFunctions files have been generated and checked in. We expect these to be deleted as these final functions get ported to ATen. The deletion process is straightforward; just delete the functions of the ones you are porting. There are 39 more functions left to port. Signed-off-by: Edward Z. Yang <ezyang@fb.com> Test Plan: Imported from OSS Reviewed By: bhosmer Differential Revision: D23183978 Pulled By: ezyang fbshipit-source-id: 6073ba432ad182c7284a97147b05f0574a02f763
This commit is contained in:
parent
576880febf
commit
6ea89166bd
|
|
@ -47,16 +47,11 @@ sudo apt-get -y install doxygen
|
|||
# Generate ATen files
|
||||
pushd "${pt_checkout}"
|
||||
pip install -r requirements.txt
|
||||
time python aten/src/ATen/gen.py \
|
||||
time python -m tools.codegen.gen \
|
||||
-s aten/src/ATen \
|
||||
-d build/aten/src/ATen \
|
||||
aten/src/ATen/Declarations.cwrap \
|
||||
aten/src/THCUNN/generic/THCUNN.h \
|
||||
aten/src/ATen/nn.yaml \
|
||||
aten/src/ATen/native/native_functions.yaml
|
||||
-d build/aten/src/ATen
|
||||
|
||||
# Copy some required files
|
||||
cp aten/src/ATen/common_with_cwrap.py tools/shared/cwrap_common.py
|
||||
cp torch/_utils_internal.py tools/shared
|
||||
|
||||
# Generate PyTorch files
|
||||
|
|
|
|||
8
.github/workflows/lint.yml
vendored
8
.github/workflows/lint.yml
vendored
|
|
@ -131,13 +131,9 @@ jobs:
|
|||
time python setup.py --cmake-only build
|
||||
|
||||
# Generate ATen files.
|
||||
time python aten/src/ATen/gen.py \
|
||||
time python -m tools.codegen.gen \
|
||||
-s aten/src/ATen \
|
||||
-d build/aten/src/ATen \
|
||||
aten/src/ATen/Declarations.cwrap \
|
||||
aten/src/THCUNN/generic/THCUNN.h \
|
||||
aten/src/ATen/nn.yaml \
|
||||
aten/src/ATen/native/native_functions.yaml
|
||||
-d build/aten/src/ATen
|
||||
|
||||
# Generate PyTorch files.
|
||||
time python tools/setup_helpers/generate_code.py \
|
||||
|
|
|
|||
3
.gitignore
vendored
3
.gitignore
vendored
|
|
@ -108,9 +108,6 @@ env
|
|||
# macOS dir files
|
||||
.DS_Store
|
||||
|
||||
# Symbolic files
|
||||
tools/shared/cwrap_common.py
|
||||
|
||||
# Ninja files
|
||||
.ninja_deps
|
||||
.ninja_log
|
||||
|
|
|
|||
|
|
@ -248,6 +248,8 @@ else
|
|||
export MAX_JOBS=`expr $(nproc) - 1`
|
||||
fi
|
||||
|
||||
pip install --user dataclasses
|
||||
|
||||
$PYTHON setup.py install --user
|
||||
|
||||
report_compile_cache_stats
|
||||
|
|
|
|||
|
|
@ -20,7 +20,7 @@ if [ ! -d "${WORKSPACE_DIR}/miniconda3" ]; then
|
|||
fi
|
||||
export PATH="${WORKSPACE_DIR}/miniconda3/bin:$PATH"
|
||||
source ${WORKSPACE_DIR}/miniconda3/bin/activate
|
||||
retry conda install -y mkl mkl-include numpy=1.18.5 pyyaml=5.3 setuptools=46.0.0 cmake cffi ninja typing_extensions
|
||||
retry conda install -y mkl mkl-include numpy=1.18.5 pyyaml=5.3 setuptools=46.0.0 cmake cffi ninja typing_extensions dataclasses
|
||||
|
||||
# The torch.hub tests make requests to GitHub.
|
||||
#
|
||||
|
|
|
|||
|
|
@ -21,8 +21,8 @@ call %INSTALLER_DIR%\install_sccache.bat
|
|||
call %INSTALLER_DIR%\install_miniconda3.bat
|
||||
|
||||
|
||||
:: Install ninja
|
||||
if "%REBUILD%"=="" ( pip install -q "ninja==1.9.0" )
|
||||
:: Install ninja and other deps
|
||||
if "%REBUILD%"=="" ( pip install -q "ninja==1.9.0" dataclasses )
|
||||
|
||||
git submodule sync --recursive
|
||||
git submodule update --init --recursive
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ call %CONDA_PARENT_DIR%\Miniconda3\Scripts\activate.bat %CONDA_PARENT_DIR%\Minic
|
|||
if NOT "%BUILD_ENVIRONMENT%"=="" (
|
||||
:: We have to pin Python version to 3.6.7, until mkl supports Python 3.7
|
||||
:: Numba is pinned to 0.44.0 to avoid https://github.com/numba/numba/issues/4352
|
||||
call conda install -y -q python=3.6.7 numpy mkl cffi pyyaml boto3 protobuf numba==0.44.0 scipy==1.5.0 typing_extensions
|
||||
call conda install -y -q python=3.6.7 numpy mkl cffi pyyaml boto3 protobuf numba==0.44.0 scipy==1.5.0 typing_extensions dataclasses
|
||||
if %errorlevel% neq 0 ( exit /b %errorlevel% )
|
||||
call conda install -y -q -c conda-forge cmake
|
||||
if %errorlevel% neq 0 ( exit /b %errorlevel% )
|
||||
|
|
|
|||
21
BUILD.bazel
21
BUILD.bazel
|
|
@ -106,17 +106,19 @@ cc_test(
|
|||
],
|
||||
)
|
||||
|
||||
# TODO: refactor this into its own library (but how to make
|
||||
# a binary based off of a module in a library?)
|
||||
py_binary(
|
||||
name = "gen",
|
||||
srcs = ["aten/src/ATen/gen.py"],
|
||||
srcs = ["tools/setup_helpers/gen.py"],
|
||||
deps = [
|
||||
":tools_codegen"
|
||||
],
|
||||
)
|
||||
|
||||
genrule(
|
||||
name = "generated_cpp",
|
||||
srcs = [
|
||||
"aten/src/ATen/Declarations.cwrap",
|
||||
"aten/src/THCUNN/generic/THCUNN.h",
|
||||
"aten/src/ATen/nn.yaml",
|
||||
"aten/src/ATen/native/native_functions.yaml",
|
||||
] + glob(["aten/src/ATen/templates/**"]),
|
||||
outs = [
|
||||
|
|
@ -126,8 +128,6 @@ genrule(
|
|||
"aten/src/ATen/CPUType.cpp",
|
||||
"aten/src/ATen/Functions.h",
|
||||
"aten/src/ATen/Functions.cpp",
|
||||
"aten/src/ATen/LegacyTHFunctionsCPU.h",
|
||||
"aten/src/ATen/LegacyTHFunctionsCPU.cpp",
|
||||
"aten/src/ATen/NativeFunctions.h",
|
||||
"aten/src/ATen/MkldnnCPUType.h",
|
||||
"aten/src/ATen/MkldnnCPUType.cpp",
|
||||
|
|
@ -141,14 +141,13 @@ genrule(
|
|||
"aten/src/ATen/core/TensorMethods.cpp",
|
||||
"aten/src/ATen/core/ATenOpList.cpp",
|
||||
],
|
||||
cmd = "$(location :gen) --source-path aten/src/ATen --install_dir `dirname $(location aten/src/ATen/Declarations.yaml)` aten/src/ATen/Declarations.cwrap aten/src/THCUNN/generic/THCUNN.h aten/src/ATen/nn.yaml aten/src/ATen/native/native_functions.yaml",
|
||||
cmd = "$(location :gen) --source-path aten/src/ATen --install_dir `dirname $(location aten/src/ATen/Declarations.yaml)`",
|
||||
tools = [":gen"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "code_template",
|
||||
srcs = ["aten/src/ATen/code_template.py"],
|
||||
imports = ["aten"],
|
||||
name = "tools_codegen",
|
||||
srcs = glob(["tools/codegen/**/*.py"]),
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
@ -158,7 +157,7 @@ py_library(
|
|||
"tools/autograd/*.yaml",
|
||||
"tools/autograd/templates/*",
|
||||
]),
|
||||
deps = [":code_template"],
|
||||
deps = [":tools_codegen"],
|
||||
)
|
||||
|
||||
py_library(
|
||||
|
|
|
|||
|
|
@ -169,7 +169,7 @@ If you are building for NVIDIA's Jetson platforms (Jetson Nano, TX1, TX2, AGX Xa
|
|||
|
||||
Common
|
||||
```bash
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests
|
||||
conda install numpy ninja pyyaml mkl mkl-include setuptools cmake cffi typing_extensions future six requests dataclasses
|
||||
```
|
||||
|
||||
On Linux
|
||||
|
|
|
|||
|
|
@ -1,561 +0,0 @@
|
|||
[[
|
||||
name: _th_masked_fill_
|
||||
cuda_bool: True
|
||||
cuda_bfloat16: True
|
||||
cname: maskedFill
|
||||
variants: function
|
||||
backends:
|
||||
- CUDA
|
||||
return: self
|
||||
options:
|
||||
- arguments:
|
||||
- THTensor* self
|
||||
- THByteTensor* mask
|
||||
- real value
|
||||
]]
|
||||
[[
|
||||
name: _th_masked_fill_bool_
|
||||
cuda_bool: True
|
||||
cuda_bfloat16: True
|
||||
cname: maskedFillBool
|
||||
variants: function
|
||||
backends:
|
||||
- CUDA
|
||||
return: self
|
||||
options:
|
||||
- arguments:
|
||||
- THTensor* self
|
||||
- THBoolTensor* mask
|
||||
- real value
|
||||
]]
|
||||
[[
|
||||
name: _th_masked_scatter_
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
cuda_bfloat16: True
|
||||
cname: maskedCopy
|
||||
variants: function
|
||||
return: self
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THByteTensor* mask
|
||||
- THTensor* source
|
||||
]]
|
||||
[[
|
||||
name: _th_masked_scatter_bool_
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
cuda_bfloat16: True
|
||||
cname: maskedCopyBool
|
||||
variants: function
|
||||
return: self
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THBoolTensor* mask
|
||||
- THTensor* source
|
||||
]]
|
||||
[[
|
||||
name: _th_nonzero
|
||||
cname: nonzero
|
||||
cpu_half: True
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cpu_bfloat16: True
|
||||
cuda_bfloat16: True
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THIndexTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
]]
|
||||
[[
|
||||
name: _th_index_copy_
|
||||
cname: indexCopy
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
variants: function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- long dim
|
||||
- THIndexTensor* index
|
||||
- THTensor* source
|
||||
]]
|
||||
[[
|
||||
name: _th_take
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cname: take
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THIndexTensor* index
|
||||
]]
|
||||
[[
|
||||
name: _th_put_
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cname: put
|
||||
variants: function
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
return: argument 0
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THIndexTensor* index
|
||||
- THTensor* source
|
||||
- bool accumulate
|
||||
]]
|
||||
[[
|
||||
name: _th_index_fill_
|
||||
cpu_bool: True
|
||||
cuda_bool: True
|
||||
cname: indexFill
|
||||
variants: function
|
||||
return: argument 0
|
||||
options:
|
||||
- arguments:
|
||||
- THTensor* self
|
||||
- long dim
|
||||
- THIndexTensor* index
|
||||
- real value
|
||||
]]
|
||||
[[
|
||||
name: _th_mode
|
||||
variants: function
|
||||
cname: mode
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* values
|
||||
output: True
|
||||
- arg: THIndexTensor* indices
|
||||
output: True
|
||||
- THTensor* self
|
||||
- long dim
|
||||
- bool keepdim
|
||||
]]
|
||||
[[
|
||||
name: _th_sort
|
||||
cname: sort
|
||||
cpu_half: True
|
||||
variants:
|
||||
- function
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* values
|
||||
output: True
|
||||
- arg: THIndexTensor* indices
|
||||
output: True
|
||||
- THTensor* self
|
||||
- long dim
|
||||
- bool descending
|
||||
]]
|
||||
[[
|
||||
name: _th_topk
|
||||
cname: topk
|
||||
cuda_bfloat16: True
|
||||
backends:
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* values
|
||||
output: True
|
||||
- arg: THIndexTensor* indices
|
||||
output: True
|
||||
- THTensor* self
|
||||
- long k
|
||||
- long dim
|
||||
- bool largest
|
||||
- bool sorted
|
||||
]]
|
||||
[[
|
||||
name: _th_var
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants: function
|
||||
options:
|
||||
- cname: var_all
|
||||
return: accreal
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- bool unbiased
|
||||
]]
|
||||
[[
|
||||
name: _th_std
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants: function
|
||||
options:
|
||||
- cname: std_all
|
||||
return: accreal
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- bool unbiased
|
||||
]]
|
||||
[[
|
||||
name: _th_renorm
|
||||
cname: renorm
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- real p
|
||||
- long dim
|
||||
- real maxnorm
|
||||
]]
|
||||
[[
|
||||
name: _th_renorm_
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
cname: renorm
|
||||
variants: function
|
||||
return: self
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* self
|
||||
- real p
|
||||
- long dim
|
||||
- real maxnorm
|
||||
]]
|
||||
[[
|
||||
name: _th_histc
|
||||
cname: histc
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- long bins
|
||||
- real min
|
||||
- real max
|
||||
]]
|
||||
[[
|
||||
name: _th_trace
|
||||
cname: trace
|
||||
variants:
|
||||
- function
|
||||
return: accreal
|
||||
arguments:
|
||||
- THTensor* self
|
||||
backends:
|
||||
- CPU
|
||||
]]
|
||||
[[
|
||||
name: _th_fmod
|
||||
return: argument 0
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
- CUDA
|
||||
options:
|
||||
- cname: fmod
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- real other
|
||||
- cname: cfmod
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_fmod_
|
||||
return: argument 0
|
||||
variants: function
|
||||
backends:
|
||||
- CUDA
|
||||
options:
|
||||
- cname: fmod
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* self
|
||||
- real other
|
||||
- cname: cfmod
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* self
|
||||
- THTensor* other
|
||||
]]
|
||||
[[
|
||||
name: _th_cross_kernel
|
||||
cname: crossKernel
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
- CUDA
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* other
|
||||
- arg: int64_t dim
|
||||
]]
|
||||
[[
|
||||
name: _th_addr
|
||||
cname: addr
|
||||
cpu_bfloat16: True
|
||||
variants: function
|
||||
return: argument 0
|
||||
backends: [CPU]
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* vec1
|
||||
- THTensor* vec2
|
||||
- real beta
|
||||
- real alpha
|
||||
]]
|
||||
[[
|
||||
name: _th_addr_
|
||||
cpu_bfloat16: True
|
||||
cname: addr
|
||||
return: self
|
||||
variants: function
|
||||
backends: [CPU]
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* self
|
||||
- THTensor* vec1
|
||||
- THTensor* vec2
|
||||
- real beta
|
||||
- real alpha
|
||||
]]
|
||||
[[
|
||||
[[
|
||||
name: _th_bmm
|
||||
cuda_bfloat16: True
|
||||
cname: baddbmm
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
- CUDA
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- argument 0
|
||||
- THTensor* self
|
||||
- THTensor* mat2
|
||||
- CONSTANT AS_REAL(0)
|
||||
- CONSTANT AS_REAL(1)
|
||||
]]
|
||||
[[
|
||||
name: _th_baddbmm
|
||||
cuda_bfloat16: True
|
||||
cname: baddbmm
|
||||
variants:
|
||||
- function
|
||||
backends:
|
||||
- CUDA
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- arg: THTensor* self
|
||||
- THTensor* batch1
|
||||
- THTensor* batch2
|
||||
- real beta
|
||||
- real alpha
|
||||
]]
|
||||
[[
|
||||
name: _th_gels
|
||||
cname: gels
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* res1
|
||||
output: True
|
||||
- arg: THTensor* res2
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* A
|
||||
]]
|
||||
[[
|
||||
name: _th_eig
|
||||
cname: geev
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* res1
|
||||
output: True
|
||||
- arg: THTensor* res2
|
||||
output: True
|
||||
- THTensor* self
|
||||
- bool eigenvectors
|
||||
]]
|
||||
[[
|
||||
name: _th_potri
|
||||
cname: potri
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* output
|
||||
output: True
|
||||
- THTensor* self
|
||||
- bool upper
|
||||
]]
|
||||
[[
|
||||
name: _th_geqrf
|
||||
cname: geqrf
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0,1
|
||||
arguments:
|
||||
- arg: THTensor* res1
|
||||
output: True
|
||||
- arg: THTensor* res2
|
||||
output: True
|
||||
- THTensor* self
|
||||
]]
|
||||
[[
|
||||
name: _th_orgqr
|
||||
cname: orgqr
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* input2
|
||||
]]
|
||||
[[
|
||||
name: _th_ormqr
|
||||
cname: ormqr
|
||||
types:
|
||||
- Float
|
||||
- Double
|
||||
backends:
|
||||
- CPU
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THTensor* result
|
||||
output: True
|
||||
- THTensor* self
|
||||
- THTensor* input2
|
||||
- THTensor* input3
|
||||
- bool left
|
||||
- bool transpose
|
||||
]]
|
||||
[[
|
||||
name: _th_multinomial_alias_setup
|
||||
cname: multinomialAliasSetup
|
||||
variants:
|
||||
- function
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
return: argument 1,2
|
||||
arguments:
|
||||
- arg: THTensor* probs
|
||||
- arg: THIndexTensor* J
|
||||
output: True
|
||||
- arg: THTensor* q
|
||||
output: True
|
||||
]]
|
||||
[[
|
||||
name: _th_multinomial_alias_draw
|
||||
cname: multinomialAliasDraw
|
||||
types:
|
||||
- floating_point
|
||||
backends:
|
||||
- CPU
|
||||
- CUDA
|
||||
variants:
|
||||
- function
|
||||
return: argument 0
|
||||
arguments:
|
||||
- arg: THIndexTensor* result
|
||||
output: True
|
||||
- THTensor* q
|
||||
- THIndexTensor* J
|
||||
- long num_samples
|
||||
- c10::optional<Generator> generator
|
||||
]]
|
||||
[[
|
||||
name: _th_copy_ignoring_overlaps_
|
||||
cname: copyIgnoringOverlaps
|
||||
return: self
|
||||
variants: function
|
||||
backends:
|
||||
- CUDA
|
||||
arguments:
|
||||
- THTensor* self
|
||||
- THTensor* src
|
||||
]]
|
||||
1712
aten/src/ATen/LegacyTHFunctionsCPU.cpp
Normal file
1712
aten/src/ATen/LegacyTHFunctionsCPU.cpp
Normal file
File diff suppressed because it is too large
Load Diff
67
aten/src/ATen/LegacyTHFunctionsCPU.h
Normal file
67
aten/src/ATen/LegacyTHFunctionsCPU.h
Normal file
|
|
@ -0,0 +1,67 @@
|
|||
#pragma once
|
||||
|
||||
// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h
|
||||
|
||||
#include <ATen/Context.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
||||
namespace c10 {
|
||||
class Scalar;
|
||||
}
|
||||
namespace at {
|
||||
struct Generator;
|
||||
class Tensor;
|
||||
struct Type;
|
||||
} // namespace at
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace legacy {
|
||||
namespace cpu {
|
||||
|
||||
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
|
||||
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
|
||||
Tensor & _th_nonzero_out(Tensor & result, const Tensor & self);
|
||||
Tensor _th_nonzero(const Tensor & self);
|
||||
Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source);
|
||||
Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index);
|
||||
Tensor _th_take(const Tensor & self, const Tensor & index);
|
||||
Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate);
|
||||
Tensor & _th_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value);
|
||||
std::tuple<Tensor &,Tensor &> _th_mode_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool keepdim);
|
||||
std::tuple<Tensor,Tensor> _th_mode(const Tensor & self, int64_t dim, bool keepdim);
|
||||
std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending);
|
||||
std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending);
|
||||
Tensor _th_var(const Tensor & self, bool unbiased);
|
||||
Tensor _th_std(const Tensor & self, bool unbiased);
|
||||
Tensor & _th_renorm_out(Tensor & result, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor & _th_histc_out(Tensor & result, const Tensor & self, int64_t bins, Scalar min, Scalar max);
|
||||
Tensor _th_histc(const Tensor & self, int64_t bins, Scalar min, Scalar max);
|
||||
Tensor _th_trace(const Tensor & self);
|
||||
Tensor & _th_addr_out(Tensor & result, const Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
|
||||
Tensor _th_addr(const Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
|
||||
Tensor & _th_addr_(Tensor & self, const Tensor & vec1, const Tensor & vec2, Scalar beta, Scalar alpha);
|
||||
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
|
||||
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
|
||||
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
|
||||
std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors);
|
||||
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
|
||||
Tensor _th_potri(const Tensor & self, bool upper);
|
||||
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
|
||||
std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
|
||||
Tensor & _th_orgqr_out(Tensor & result, const Tensor & self, const Tensor & input2);
|
||||
Tensor _th_orgqr(const Tensor & self, const Tensor & input2);
|
||||
Tensor & _th_ormqr_out(Tensor & result, const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
|
||||
Tensor _th_ormqr(const Tensor & self, const Tensor & input2, const Tensor & input3, bool left, bool transpose);
|
||||
std::tuple<Tensor &,Tensor &> _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs);
|
||||
std::tuple<Tensor,Tensor> _th_multinomial_alias_setup(const Tensor & probs);
|
||||
Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional<Generator> generator);
|
||||
Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional<Generator> generator);
|
||||
|
||||
} // namespace th
|
||||
} // namespace legacy
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
111
aten/src/ATen/LegacyTHFunctionsCUDA.h
Normal file
111
aten/src/ATen/LegacyTHFunctionsCUDA.h
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
#pragma once
|
||||
|
||||
// @generated by aten/src/ATen/gen.py from LegacyTHFunctions.h
|
||||
|
||||
#include <ATen/Context.h>
|
||||
#include <c10/core/ScalarType.h>
|
||||
#include <c10/core/TensorOptions.h>
|
||||
|
||||
namespace c10 {
|
||||
class Scalar;
|
||||
}
|
||||
namespace at {
|
||||
struct Generator;
|
||||
class Tensor;
|
||||
struct Type;
|
||||
} // namespace at
|
||||
|
||||
namespace at {
|
||||
namespace native {
|
||||
namespace legacy {
|
||||
namespace cuda {
|
||||
|
||||
Tensor & _th_masked_fill_(Tensor & self, const Tensor & mask, Scalar value);
|
||||
Tensor & _th_masked_fill_bool_(Tensor & self, const Tensor & mask, Scalar value);
|
||||
Tensor & _th_masked_scatter_(Tensor & self, const Tensor & mask, const Tensor & source);
|
||||
Tensor & _th_masked_scatter_bool_(Tensor & self, const Tensor & mask, const Tensor & source);
|
||||
Tensor & _th_nonzero_out(Tensor & result, const Tensor & self);
|
||||
Tensor _th_nonzero(const Tensor & self);
|
||||
Tensor & _th_index_copy_(Tensor & self, int64_t dim, const Tensor & index, const Tensor & source);
|
||||
Tensor & _th_take_out(Tensor & result, const Tensor & self, const Tensor & index);
|
||||
Tensor _th_take(const Tensor & self, const Tensor & index);
|
||||
Tensor & _th_put_(Tensor & self, const Tensor & index, const Tensor & source, bool accumulate);
|
||||
Tensor & _th_index_fill_(Tensor & self, int64_t dim, const Tensor & index, Scalar value);
|
||||
std::tuple<Tensor &,Tensor &> _th_mode_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool keepdim);
|
||||
std::tuple<Tensor,Tensor> _th_mode(const Tensor & self, int64_t dim, bool keepdim);
|
||||
std::tuple<Tensor &,Tensor &> _th_sort_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t dim, bool descending);
|
||||
std::tuple<Tensor,Tensor> _th_sort(const Tensor & self, int64_t dim, bool descending);
|
||||
std::tuple<Tensor &,Tensor &> _th_topk_out(Tensor & values, Tensor & indices, const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted);
|
||||
std::tuple<Tensor,Tensor> _th_topk(const Tensor & self, int64_t k, int64_t dim, bool largest, bool sorted);
|
||||
Tensor _th_var(const Tensor & self, bool unbiased);
|
||||
Tensor _th_std(const Tensor & self, bool unbiased);
|
||||
Tensor & _th_renorm_out(Tensor & result, const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor _th_renorm(const Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor & _th_renorm_(Tensor & self, Scalar p, int64_t dim, Scalar maxnorm);
|
||||
Tensor & _th_fmod_out(Tensor & result, const Tensor & self, Scalar other);
|
||||
Tensor _th_fmod(const Tensor & self, Scalar other);
|
||||
Tensor & _th_fmod_out(Tensor & result, const Tensor & self, const Tensor & other);
|
||||
Tensor _th_fmod(const Tensor & self, const Tensor & other);
|
||||
Tensor & _th_fmod_(Tensor & self, Scalar other);
|
||||
Tensor & _th_fmod_(Tensor & self, const Tensor & other);
|
||||
Tensor & _th_cross_kernel_out(Tensor & result, const Tensor & self, const Tensor & other, int64_t dim);
|
||||
Tensor _th_cross_kernel(const Tensor & self, const Tensor & other, int64_t dim);
|
||||
Tensor & _th_bmm_out(Tensor & result, const Tensor & self, const Tensor & mat2);
|
||||
Tensor _th_bmm(const Tensor & self, const Tensor & mat2);
|
||||
Tensor & _th_baddbmm_out(Tensor & result, const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
|
||||
Tensor _th_baddbmm(const Tensor & self, const Tensor & batch1, const Tensor & batch2, Scalar beta, Scalar alpha);
|
||||
std::tuple<Tensor &,Tensor &> _th_gels_out(Tensor & res1, Tensor & res2, const Tensor & self, const Tensor & A);
|
||||
std::tuple<Tensor,Tensor> _th_gels(const Tensor & self, const Tensor & A);
|
||||
std::tuple<Tensor &,Tensor &> _th_eig_out(Tensor & res1, Tensor & res2, const Tensor & self, bool eigenvectors);
|
||||
std::tuple<Tensor,Tensor> _th_eig(const Tensor & self, bool eigenvectors);
|
||||
Tensor & _th_potri_out(Tensor & output, const Tensor & self, bool upper);
|
||||
Tensor _th_potri(const Tensor & self, bool upper);
|
||||
std::tuple<Tensor &,Tensor &> _th_geqrf_out(Tensor & res1, Tensor & res2, const Tensor & self);
|
||||
std::tuple<Tensor,Tensor> _th_geqrf(const Tensor & self);
|
||||
std::tuple<Tensor &,Tensor &> _th_multinomial_alias_setup_out(Tensor & J, Tensor & q, const Tensor & probs);
|
||||
std::tuple<Tensor,Tensor> _th_multinomial_alias_setup(const Tensor & probs);
|
||||
Tensor & _th_multinomial_alias_draw_out(Tensor & result, const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional<Generator> generator);
|
||||
Tensor _th_multinomial_alias_draw(const Tensor & q, const Tensor & J, int64_t num_samples, c10::optional<Generator> generator);
|
||||
Tensor & _th_copy_ignoring_overlaps_(Tensor & self, const Tensor & src);
|
||||
Tensor & _thnn_multi_margin_loss_forward_out(Tensor & output, const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction);
|
||||
Tensor _thnn_multi_margin_loss_forward(const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction);
|
||||
Tensor & _thnn_multi_margin_loss_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction);
|
||||
Tensor _thnn_multi_margin_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, Scalar p, Scalar margin, const Tensor & weight, int64_t reduction);
|
||||
std::tuple<Tensor &,Tensor &> _thnn_multilabel_margin_loss_forward_out(Tensor & output, Tensor & is_target, const Tensor & self, const Tensor & target, int64_t reduction);
|
||||
std::tuple<Tensor,Tensor> _thnn_multilabel_margin_loss_forward(const Tensor & self, const Tensor & target, int64_t reduction);
|
||||
Tensor & _thnn_multilabel_margin_loss_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, const Tensor & is_target);
|
||||
Tensor _thnn_multilabel_margin_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, int64_t reduction, const Tensor & is_target);
|
||||
std::tuple<Tensor &,Tensor &> _thnn_nll_loss_forward_out(Tensor & output, Tensor & total_weight, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index);
|
||||
std::tuple<Tensor,Tensor> _thnn_nll_loss_forward(const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index);
|
||||
Tensor & _thnn_nll_loss_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight);
|
||||
Tensor _thnn_nll_loss_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight);
|
||||
std::tuple<Tensor &,Tensor &> _thnn_nll_loss2d_forward_out(Tensor & output, Tensor & total_weight, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index);
|
||||
std::tuple<Tensor,Tensor> _thnn_nll_loss2d_forward(const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index);
|
||||
Tensor & _thnn_nll_loss2d_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight);
|
||||
Tensor _thnn_nll_loss2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & target, const Tensor & weight, int64_t reduction, int64_t ignore_index, const Tensor & total_weight);
|
||||
Tensor & _thnn_glu_forward_out(Tensor & output, const Tensor & self, int64_t dim);
|
||||
Tensor _thnn_glu_forward(const Tensor & self, int64_t dim);
|
||||
Tensor & _thnn_glu_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, int64_t dim);
|
||||
Tensor _thnn_glu_backward(const Tensor & grad_output, const Tensor & self, int64_t dim);
|
||||
std::tuple<Tensor &,Tensor &> _thnn_log_sigmoid_forward_out(Tensor & output, Tensor & buffer, const Tensor & self);
|
||||
std::tuple<Tensor,Tensor> _thnn_log_sigmoid_forward(const Tensor & self);
|
||||
Tensor & _thnn_log_sigmoid_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & buffer);
|
||||
Tensor _thnn_log_sigmoid_backward(const Tensor & grad_output, const Tensor & self, const Tensor & buffer);
|
||||
Tensor & _thnn_rrelu_with_noise_forward_out(Tensor & output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
|
||||
Tensor _thnn_rrelu_with_noise_forward(const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
|
||||
Tensor & _thnn_rrelu_with_noise_backward_out(Tensor & grad_input, const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
|
||||
Tensor _thnn_rrelu_with_noise_backward(const Tensor & grad_output, const Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training);
|
||||
Tensor & _thnn_rrelu_with_noise_forward_(Tensor & self, const Tensor & noise, Scalar lower, Scalar upper, bool training, c10::optional<at::Generator> generator);
|
||||
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_forward_out(Tensor & output, Tensor & columns, Tensor & ones, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding);
|
||||
std::tuple<Tensor,Tensor,Tensor> _thnn_conv2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding);
|
||||
std::tuple<Tensor &,Tensor &,Tensor &> _thnn_conv2d_backward_out(Tensor & grad_input, Tensor & grad_weight, Tensor & grad_bias, const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones);
|
||||
std::tuple<Tensor,Tensor,Tensor> _thnn_conv2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, const Tensor & columns, const Tensor & ones, std::array<bool,3> output_mask);
|
||||
Tensor & _thnn_conv_depthwise2d_forward_out(Tensor & output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation);
|
||||
Tensor _thnn_conv_depthwise2d_forward(const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, const Tensor & bias, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation);
|
||||
std::tuple<Tensor &,Tensor &> _thnn_conv_depthwise2d_backward_out(Tensor & grad_input, Tensor & grad_weight, const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation);
|
||||
std::tuple<Tensor,Tensor> _thnn_conv_depthwise2d_backward(const Tensor & grad_output, const Tensor & self, const Tensor & weight, IntArrayRef kernel_size, IntArrayRef stride, IntArrayRef padding, IntArrayRef dilation, std::array<bool,2> output_mask);
|
||||
|
||||
} // namespace th
|
||||
} // namespace legacy
|
||||
} // namespace native
|
||||
} // namespace at
|
||||
4176
aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
Normal file
4176
aten/src/ATen/cuda/LegacyTHFunctionsCUDA.cpp
Normal file
File diff suppressed because it is too large
Load Diff
|
|
@ -1,38 +0,0 @@
|
|||
import yaml
|
||||
import copy
|
||||
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader
|
||||
|
||||
# follows similar logic to cwrap, ignores !inc, and just looks for [[]]
|
||||
|
||||
|
||||
def parse(filename):
|
||||
with open(filename, 'r') as file:
|
||||
declaration_lines = []
|
||||
declarations = []
|
||||
in_declaration = False
|
||||
for line in file.readlines():
|
||||
line = line.rstrip()
|
||||
if line == '[[':
|
||||
declaration_lines = []
|
||||
in_declaration = True
|
||||
elif line == ']]':
|
||||
in_declaration = False
|
||||
declaration = yaml.load('\n'.join(declaration_lines), Loader=Loader)
|
||||
declarations.append(declaration)
|
||||
elif in_declaration:
|
||||
declaration_lines.append(line)
|
||||
declarations = [process_declaration(declaration) for declaration in declarations]
|
||||
return declarations
|
||||
|
||||
def process_declaration(declaration):
|
||||
declaration = copy.deepcopy(declaration)
|
||||
if "arguments" in declaration:
|
||||
declaration["schema_order_arguments"] = copy.deepcopy(declaration["arguments"])
|
||||
if "options" in declaration:
|
||||
declaration["options"] = [process_declaration(option) for option in declaration["options"]]
|
||||
return declaration
|
||||
File diff suppressed because it is too large
Load Diff
|
|
@ -1,545 +0,0 @@
|
|||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
import yaml
|
||||
from collections import defaultdict
|
||||
from collections import OrderedDict
|
||||
|
||||
import sys
|
||||
from os import path
|
||||
sys.path.append(path.dirname(path.abspath(__file__)))
|
||||
|
||||
import cwrap_parser
|
||||
import nn_parse
|
||||
import native_parse
|
||||
import preprocess_declarations
|
||||
import function_wrapper
|
||||
import gen_backend_select_register
|
||||
|
||||
from code_template import CodeTemplate
|
||||
|
||||
|
||||
# This file is the top-level entry point for code generation in ATen.
|
||||
# It takes an arbitrary number of arguments specifying metadata files to
|
||||
# process (.cwrap, .yaml and .h) and outputs a number generated header
|
||||
# and cpp files in ATen/ (see invocations of 'write' for each file that
|
||||
# is written.) It is invoked from cmake; look for the 'cwrap_files'
|
||||
# variable for an up-to-date list of files which are passed.
|
||||
|
||||
parser = argparse.ArgumentParser(description='Generate ATen source files')
|
||||
parser.add_argument('files', help='cwrap files', nargs='+')
|
||||
|
||||
parser.add_argument(
|
||||
'-s',
|
||||
'--source-path',
|
||||
help='path to source directory for ATen',
|
||||
default='.')
|
||||
parser.add_argument(
|
||||
'-o',
|
||||
'--output-dependencies',
|
||||
help='output a list of dependencies into the given file and exit')
|
||||
parser.add_argument(
|
||||
'-d', '--install_dir', help='output directory', default='ATen')
|
||||
parser.add_argument(
|
||||
'--rocm',
|
||||
action='store_true',
|
||||
help='reinterpret CUDA as ROCm/HIP and adjust filepaths accordingly')
|
||||
parser.add_argument(
|
||||
'--vulkan',
|
||||
action='store_true',
|
||||
help='Generate Vulkan backend functions')
|
||||
parser.add_argument(
|
||||
'--op_registration_whitelist',
|
||||
nargs='*',
|
||||
help='filter op registrations by the whitelist (if set); '
|
||||
'each item is `namespace`::`operator name` without overload name; '
|
||||
'e.g.: aten::empty aten::conv2d ...')
|
||||
parser.add_argument(
|
||||
'--backend_whitelist',
|
||||
nargs='*',
|
||||
help='filter dispatch backend by the whitelist (if set), '
|
||||
'e.g.: CPU CUDA QuantizedCPU ...')
|
||||
parser.add_argument(
|
||||
'--per_op_registration',
|
||||
action='store_true',
|
||||
help='group function registrations by op name and write to separate files; '
|
||||
'must also set --op_registration_whitelist param')
|
||||
parser.add_argument(
|
||||
'--force_schema_registration',
|
||||
action='store_true',
|
||||
help='force it to generate schema-only registrations for all ops, including'
|
||||
'those that are not listed on --op_registration_whitelist')
|
||||
options = parser.parse_args()
|
||||
|
||||
# NB: It is mandatory to NOT use os.path.join here, as the install directory
|
||||
# will eventually be ingested by cmake, which does not respect Windows style
|
||||
# path slashes. If you switch this to use os.path.join, you'll get an error
|
||||
# like:
|
||||
#
|
||||
# Syntax error in cmake code when parsing string
|
||||
#
|
||||
# C:/Jenkins/workspace/pytorch-builds/pytorch-win-ws2016-cuda9-cudnn7-py3-build/build/aten/src/ATen\core/TensorMethods.h
|
||||
#
|
||||
# Invalid character escape '\c'.
|
||||
core_install_dir = options.install_dir + '/core' if options.install_dir is not None else None
|
||||
if options.install_dir is not None and not os.path.exists(options.install_dir):
|
||||
os.makedirs(options.install_dir)
|
||||
if core_install_dir is not None and not os.path.exists(core_install_dir):
|
||||
os.makedirs(core_install_dir)
|
||||
|
||||
|
||||
class FileManager(object):
|
||||
def __init__(self, install_dir=None):
|
||||
self.install_dir = install_dir if install_dir else options.install_dir
|
||||
self.filenames = set()
|
||||
self.outputs_written = False
|
||||
self.undeclared_files = []
|
||||
|
||||
def will_write(self, filename):
|
||||
filename = '{}/{}'.format(self.install_dir, filename)
|
||||
if self.outputs_written:
|
||||
raise Exception("'will_write' can only be called before " +
|
||||
"the call to write_outputs, refactor so outputs are registered " +
|
||||
"before running the generators")
|
||||
self.filenames.add(filename)
|
||||
|
||||
def _write_if_changed(self, filename, contents):
|
||||
try:
|
||||
with open(filename, 'r') as f:
|
||||
old_contents = f.read()
|
||||
except IOError:
|
||||
old_contents = None
|
||||
if contents != old_contents:
|
||||
with open(filename, 'w') as f:
|
||||
f.write(contents)
|
||||
|
||||
def write_outputs(self, filename):
|
||||
"""Write a file containing the list of all outputs which are
|
||||
generated by this script."""
|
||||
self._write_if_changed(
|
||||
filename,
|
||||
''.join(name + ";" for name in sorted(self.filenames)))
|
||||
self.outputs_written = True
|
||||
|
||||
def write(self, filename, s, env=None):
|
||||
filename = '{}/{}'.format(self.install_dir, filename)
|
||||
if isinstance(s, CodeTemplate):
|
||||
assert env is not None
|
||||
comment = "@" + "generated by aten/src/ATen/gen.py"
|
||||
if s.filename:
|
||||
comment += " from {}".format(os.path.basename(s.filename))
|
||||
env['generated_comment'] = comment
|
||||
s = s.substitute(env)
|
||||
self._write_if_changed(filename, s)
|
||||
if filename not in self.filenames:
|
||||
self.undeclared_files.append(filename)
|
||||
else:
|
||||
self.filenames.remove(filename)
|
||||
|
||||
def check_all_files_written(self):
|
||||
if len(self.undeclared_files) > 0:
|
||||
raise Exception(
|
||||
"trying to write files {} which are not ".format(self.undeclared_files) +
|
||||
"in the list of outputs this script produces. " +
|
||||
"use will_write to add them.")
|
||||
if len(self.filenames) > 0:
|
||||
raise Exception("Outputs declared with 'will_write' were " +
|
||||
"never written: {}".format(self.filenames))
|
||||
|
||||
|
||||
TEMPLATE_PATH = options.source_path + "/templates"
|
||||
TYPE_DERIVED_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDerived.cpp")
|
||||
SPARSE_TYPE_DERIVED_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/SparseTypeDerived.cpp")
|
||||
TYPE_DERIVED_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDerived.h")
|
||||
TYPE_DEFAULT_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.h")
|
||||
TYPE_DEFAULT_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TypeDefault.cpp")
|
||||
OPS_ALREADY_MOVED_TO_C10_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/ATenOpList.cpp")
|
||||
BACKEND_SELECT_REGISTER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/BackendSelectRegister.cpp")
|
||||
SCHEMA_REGISTER_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/SchemaRegister.cpp")
|
||||
TENSOR_H = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorBody.h")
|
||||
TENSOR_METHODS_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/TensorMethods.cpp")
|
||||
|
||||
FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/Functions.h")
|
||||
FUNCTIONS_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/Functions.cpp")
|
||||
|
||||
LEGACY_TH_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctions.h")
|
||||
LEGACY_TH_FUNCTIONS_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/LegacyTHFunctions.cpp")
|
||||
|
||||
NATIVE_FUNCTIONS_H = CodeTemplate.from_file(TEMPLATE_PATH + "/NativeFunctions.h")
|
||||
|
||||
PER_OP_REGISTRATION_CPP = CodeTemplate.from_file(TEMPLATE_PATH + "/PerOpRegistration.cpp")
|
||||
|
||||
core_file_manager = FileManager(core_install_dir)
|
||||
file_manager = FileManager()
|
||||
cuda_file_manager = FileManager()
|
||||
|
||||
def backend_to_devicetype(backend):
|
||||
if backend == 'QuantizedCPU':
|
||||
return 'CPU'
|
||||
elif backend == 'QuantizedCUDA':
|
||||
return 'CUDA'
|
||||
return backend
|
||||
|
||||
backends = ['CPU', 'CUDA']
|
||||
densities = ['Dense', 'Sparse', 'Mkldnn'] # TODO: layout instead of densities?
|
||||
|
||||
quantized_backends = ['QuantizedCPU', 'QuantizedCUDA']
|
||||
|
||||
# scalar_name, c_type, accreal, is_floating_type
|
||||
quantized_scalar_types = [
|
||||
('QInt8', 'qint8', 'QInt8AccrealNotDefined', 'QInt8IsFloatingTypeNotDefined'),
|
||||
('QUInt8', 'quint8', 'QUInt8AccrealNotDefined', 'QUInt8IsFloatingTypeNotDefined'),
|
||||
('QInt32', 'qint32', 'QInt32AccrealNotDefined', 'Qint32IsFloatingTypeNotDefined'),
|
||||
]
|
||||
|
||||
# whitelist used to filter op registrations for custom build
|
||||
if options.op_registration_whitelist is not None:
|
||||
op_registration_whitelist = set(options.op_registration_whitelist)
|
||||
else:
|
||||
op_registration_whitelist = None
|
||||
|
||||
# shared environment for non-derived base classes TensorBody.h Storage.h
|
||||
top_env = {
|
||||
'cpu_type_headers': [],
|
||||
'cuda_type_headers': [],
|
||||
'function_registrations': [],
|
||||
'aten_ops': [],
|
||||
'type_method_declarations': [],
|
||||
'type_method_definitions': [],
|
||||
'tensor_method_declarations': [],
|
||||
'tensor_method_definitions': [],
|
||||
'function_declarations': [],
|
||||
'function_definitions': [],
|
||||
'type_ids': [],
|
||||
'native_function_declarations': [],
|
||||
}
|
||||
|
||||
|
||||
def is_whitelisted_backend(backend):
|
||||
return options.backend_whitelist is None or backend in options.backend_whitelist
|
||||
|
||||
def is_cuda_backend(backend):
|
||||
return backend in ("QuantizedCUDA", "CUDA")
|
||||
|
||||
def dict_representer(dumper, data):
|
||||
return dumper.represent_dict(data.items())
|
||||
|
||||
|
||||
def postprocess_output_declarations(output_declarations):
|
||||
# ensure each return has a name associated with it
|
||||
for decl in output_declarations:
|
||||
has_named_ret = False
|
||||
for n, ret in enumerate(decl.returns):
|
||||
if 'name' not in ret:
|
||||
assert not has_named_ret
|
||||
if decl.inplace:
|
||||
ret['name'] = 'self'
|
||||
elif len(decl.returns) == 1:
|
||||
ret['name'] = 'out'
|
||||
else:
|
||||
ret['name'] = 'out' + str(n)
|
||||
else:
|
||||
has_named_ret = True
|
||||
|
||||
def remove_key_if_none(dictionary, key):
|
||||
if key in dictionary.keys() and dictionary[key] is None:
|
||||
del dictionary[key]
|
||||
return dictionary
|
||||
|
||||
return [remove_key_if_none(decl._asdict(), 'buffers')
|
||||
for decl in output_declarations]
|
||||
|
||||
|
||||
def format_yaml(data):
|
||||
if options.output_dependencies:
|
||||
# yaml formatting is slow so don't do it if we will ditch it.
|
||||
return ""
|
||||
noalias_dumper = yaml.dumper.SafeDumper
|
||||
noalias_dumper.ignore_aliases = lambda self, data: True
|
||||
# Support serializing OrderedDict
|
||||
noalias_dumper.add_representer(OrderedDict, dict_representer)
|
||||
# Some yaml parsers (e.g. Haskell's) don't understand line breaks.
|
||||
# width=float('Inf') turns off optional line breaks and improves
|
||||
# the portability of the outputted yaml.
|
||||
return yaml.dump(data, default_flow_style=False, Dumper=noalias_dumper, width=float('Inf'))
|
||||
|
||||
|
||||
def add_op_registrations(per_type_registrations, per_op_registrations, schema_registrations, op_registrations):
|
||||
for op_registration in op_registrations:
|
||||
opname = op_registration.operator_name
|
||||
registration = op_registration.registration_code
|
||||
|
||||
# collect schema registration for all ops (whitelisted or not)
|
||||
if schema_registrations is not None:
|
||||
schema_registrations.append(op_registration.schema_registration_code)
|
||||
|
||||
# apply whitelist
|
||||
if op_registration_whitelist is not None and opname not in op_registration_whitelist:
|
||||
continue
|
||||
if options.per_op_registration:
|
||||
# per op registration
|
||||
per_op_registrations[opname].append(registration)
|
||||
else:
|
||||
# per type registration
|
||||
per_type_registrations.append(registration)
|
||||
|
||||
|
||||
def generate_storage_type_and_tensor(backend, density, declarations, per_op_registrations, schema_registrations):
|
||||
env = {}
|
||||
density_tag = density if density != 'Dense' else ''
|
||||
env['Density'] = density
|
||||
env['Type'] = "{}{}Type".format(density_tag, backend)
|
||||
env['DeviceType'] = backend_to_devicetype(backend)
|
||||
env['Backend'] = density_tag + backend
|
||||
if not is_whitelisted_backend(env['Backend']):
|
||||
return
|
||||
env['storage_tensor_headers'] = []
|
||||
if density != 'Sparse':
|
||||
env['storage_tensor_headers'] = ['#include <c10/core/TensorImpl.h>']
|
||||
|
||||
# used for generating switch logic for external functions
|
||||
tag = density_tag + backend
|
||||
env['TypeID'] = 'TypeID::' + tag
|
||||
top_env['type_ids'].append(tag + ',')
|
||||
|
||||
env['legacy_th_headers'] = []
|
||||
if is_cuda_backend(backend):
|
||||
env['extra_cuda_headers'] = []
|
||||
env['extra_cuda_headers'].append('#include <ATen/DeviceGuard.h>')
|
||||
if options.rocm:
|
||||
env['th_headers'] = [
|
||||
'#include <THH/THH.h>',
|
||||
'#include <THH/THHTensor.hpp>',
|
||||
'#include <THHUNN/THHUNN.h>',
|
||||
'#undef THNN_',
|
||||
'#undef THCIndexTensor_',
|
||||
]
|
||||
env['extra_cuda_headers'].append('#include <ATen/hip/ATenHIPGeneral.h>')
|
||||
env['extra_cuda_headers'].append('#include <ATen/hip/HIPDevice.h>')
|
||||
env['extra_cuda_headers'].append('#include <ATen/hip/HIPContext.h>')
|
||||
else:
|
||||
env['th_headers'] = [
|
||||
'#include <THC/THC.h>',
|
||||
'#include <THC/THCTensor.hpp>',
|
||||
'#include <THCUNN/THCUNN.h>',
|
||||
'#undef THNN_',
|
||||
'#undef THCIndexTensor_',
|
||||
]
|
||||
env['extra_cuda_headers'].append('#include <ATen/cuda/ATenCUDAGeneral.h>')
|
||||
env['extra_cuda_headers'].append('#include <ATen/cuda/CUDADevice.h>')
|
||||
env['extra_cuda_headers'].append('#include <ATen/cuda/CUDAContext.h>')
|
||||
env['state'] = ['globalContext().getTHCState()']
|
||||
env['isCUDA'] = 'true'
|
||||
env['storage_device'] = 'return storage->device;'
|
||||
env['Generator'] = 'CUDAGeneratorImpl'
|
||||
env['allocator'] = 'at::cuda::getCUDADeviceAllocator()'
|
||||
else:
|
||||
env['th_headers'] = [
|
||||
'#include <TH/TH.h>',
|
||||
'#include <TH/THTensor.hpp>',
|
||||
]
|
||||
env['extra_cuda_headers'] = []
|
||||
env['state'] = []
|
||||
env['isCUDA'] = 'false'
|
||||
env['storage_device'] = 'throw std::runtime_error("CPU storage has no device");'
|
||||
env['Generator'] = 'CPUGeneratorImpl'
|
||||
env['allocator'] = 'getCPUAllocator()'
|
||||
|
||||
declarations, definitions, op_registrations, th_declarations, th_definitions = function_wrapper.create_derived(
|
||||
env, declarations)
|
||||
env['type_derived_method_declarations'] = declarations
|
||||
env['type_derived_method_definitions'] = definitions
|
||||
env['legacy_th_declarations'] = th_declarations
|
||||
env['legacy_th_definitions'] = th_definitions
|
||||
env['function_registrations'] = []
|
||||
add_op_registrations(env['function_registrations'], per_op_registrations, schema_registrations, op_registrations)
|
||||
|
||||
fm = file_manager
|
||||
if env['DeviceType'] == 'CUDA':
|
||||
fm = cuda_file_manager
|
||||
|
||||
if env['Backend'] == 'CPU' or env['Backend'] == 'CUDA':
|
||||
env['namespace'] = env['Backend'].lower()
|
||||
env['legacy_th_headers'].append('#include <ATen/LegacyTHFunctions' + env['Backend'] + ".h>")
|
||||
fm.write('LegacyTHFunctions' + env['Backend'] + ".h", LEGACY_TH_FUNCTIONS_H, env)
|
||||
fm.write('LegacyTHFunctions' + env['Backend'] + ".cpp", LEGACY_TH_FUNCTIONS_CPP, env)
|
||||
|
||||
if density != 'Sparse':
|
||||
fm.write(env['Type'] + ".cpp", TYPE_DERIVED_CPP, env)
|
||||
else:
|
||||
fm.write(env['Type'] + ".cpp", SPARSE_TYPE_DERIVED_CPP, env)
|
||||
fm.write(env['Type'] + ".h", TYPE_DERIVED_H, env)
|
||||
|
||||
if env['DeviceType'] == 'CPU' or env['DeviceType'] == 'Vulkan':
|
||||
top_env['cpu_type_headers'].append(
|
||||
'#include <ATen/{}.h>'.format(env['Type']))
|
||||
else:
|
||||
assert env['DeviceType'] == 'CUDA'
|
||||
top_env['cuda_type_headers'].append(
|
||||
'#include <ATen/{}.h>'.format(env['Type']))
|
||||
|
||||
|
||||
# yields (backend, density) tuples
|
||||
def iterate_types():
|
||||
for backend in backends:
|
||||
for density in densities:
|
||||
if density == 'Mkldnn' and backend != 'CPU':
|
||||
continue
|
||||
else:
|
||||
yield (backend, density)
|
||||
for backend in quantized_backends:
|
||||
yield (backend, 'Dense')
|
||||
if options.vulkan:
|
||||
yield('Vulkan', 'Dense')
|
||||
|
||||
|
||||
def gen_per_op_registration_filename(opname):
|
||||
return 'pt_op_register_{}.cpp'.format(opname.replace(':', '-'))
|
||||
|
||||
|
||||
###################
|
||||
# declare what files will be output _before_ we do any work
|
||||
# so that the script runs quickly when we are just querying the
|
||||
# outputs
|
||||
def declare_outputs():
|
||||
core_files = ['TensorBody.h', 'TensorMethods.cpp', 'ATenOpList.cpp']
|
||||
for f in core_files:
|
||||
core_file_manager.will_write(f)
|
||||
files = ['Declarations.yaml', 'TypeDefault.cpp', 'TypeDefault.h',
|
||||
'Functions.h', 'Functions.cpp', 'NativeFunctions.h', 'BackendSelectRegister.cpp']
|
||||
for f in files:
|
||||
file_manager.will_write(f)
|
||||
for backend, density in iterate_types():
|
||||
full_backend = backend if density == "Dense" else density + backend
|
||||
if not is_whitelisted_backend(full_backend):
|
||||
continue
|
||||
fm = file_manager
|
||||
if is_cuda_backend(backend):
|
||||
fm = cuda_file_manager
|
||||
for kind in ["Type"]:
|
||||
if kind != 'Type' and density == "Sparse":
|
||||
# No Storage or Tensor for sparse
|
||||
continue
|
||||
fm.will_write("{}{}.h".format(full_backend, kind))
|
||||
fm.will_write("{}{}.cpp".format(full_backend, kind))
|
||||
if backend == 'CPU' or backend == 'CUDA':
|
||||
fm.will_write("LegacyTHFunctions{}.h".format(backend))
|
||||
fm.will_write("LegacyTHFunctions{}.cpp".format(backend))
|
||||
|
||||
if options.per_op_registration:
|
||||
if op_registration_whitelist is None:
|
||||
raise Exception("Must set --op_registration_whitelist for per-op registration.")
|
||||
for whitelisted_op in op_registration_whitelist:
|
||||
fname = gen_per_op_registration_filename(whitelisted_op)
|
||||
file_manager.will_write(fname)
|
||||
|
||||
if options.force_schema_registration:
|
||||
file_manager.will_write('SchemaRegister.cpp')
|
||||
|
||||
|
||||
def filter_by_extension(files, *extensions):
|
||||
filtered_files = []
|
||||
for file in files:
|
||||
for extension in extensions:
|
||||
if file.endswith(extension):
|
||||
filtered_files.append(file)
|
||||
return filtered_files
|
||||
|
||||
|
||||
def generate_per_op_registration(per_op_registrations):
|
||||
if not options.per_op_registration:
|
||||
return
|
||||
|
||||
# Ensure all whitelisted operators have a corresponding registration file.
|
||||
# Generate an empty placeholder file for nonexistent operators, which might
|
||||
# be registered manually instead of via codegen.
|
||||
# This can simplify the custom BUCK build which consumes the output of this
|
||||
# script, since it can uniformly create per-op build targets and dependencies
|
||||
# without having to know the subtle difference about op registration.
|
||||
# Manually registered operators might call codegen registered operators thus
|
||||
# we cannot simply ignore them when calculating transitive dependencies for
|
||||
# custom build.
|
||||
for whitelisted_op in op_registration_whitelist:
|
||||
if whitelisted_op not in per_op_registrations:
|
||||
per_op_registrations[whitelisted_op] = []
|
||||
|
||||
for opname, function_registrations in per_op_registrations.items():
|
||||
fname = gen_per_op_registration_filename(opname)
|
||||
file_manager.write(fname, PER_OP_REGISTRATION_CPP, {
|
||||
'extra_headers': top_env['cpu_type_headers'] + top_env['cuda_type_headers'],
|
||||
'function_registrations': function_registrations,
|
||||
})
|
||||
|
||||
|
||||
def generate_schema_registration(schema_registrations):
|
||||
if not options.force_schema_registration:
|
||||
return
|
||||
file_manager.write('SchemaRegister.cpp', SCHEMA_REGISTER_CPP, {
|
||||
'schema_registrations': sorted(set(schema_registrations)),
|
||||
})
|
||||
|
||||
|
||||
def generate_outputs():
|
||||
cwrap_files = filter_by_extension(options.files, '.cwrap')
|
||||
nn_files = filter_by_extension(options.files, 'nn.yaml', '.h')
|
||||
native_files = filter_by_extension(options.files, 'native_functions.yaml')
|
||||
|
||||
declarations = [d
|
||||
for file in cwrap_files
|
||||
for d in cwrap_parser.parse(file)]
|
||||
|
||||
declarations += nn_parse.run(nn_files)
|
||||
declarations += native_parse.run(native_files)
|
||||
declarations = preprocess_declarations.run(declarations)
|
||||
|
||||
per_op_registrations = defaultdict(list) if options.per_op_registration else None
|
||||
schema_registrations = [] if options.force_schema_registration else None
|
||||
|
||||
# note: this will fill in top_env['type/tensor_method_declarations/definitions']
|
||||
# and modify the declarations to include any information that will all_backends
|
||||
# be used by function_wrapper.create_derived
|
||||
output_declarations, op_registrations = function_wrapper.create_generic(
|
||||
top_env, declarations)
|
||||
output_declarations = postprocess_output_declarations(output_declarations)
|
||||
file_manager.write("Declarations.yaml", format_yaml(output_declarations))
|
||||
|
||||
gen_backend_select_register.register_backend_select_methods(declarations, BACKEND_SELECT_REGISTER_CPP, file_manager)
|
||||
|
||||
add_op_registrations(
|
||||
top_env['function_registrations'], per_op_registrations, schema_registrations, op_registrations)
|
||||
|
||||
for backend, density in iterate_types():
|
||||
generate_storage_type_and_tensor(
|
||||
backend, density, declarations, per_op_registrations, schema_registrations)
|
||||
|
||||
core_files = {
|
||||
'TensorBody.h': TENSOR_H,
|
||||
'TensorMethods.cpp': TENSOR_METHODS_CPP,
|
||||
'ATenOpList.cpp': OPS_ALREADY_MOVED_TO_C10_CPP,
|
||||
}
|
||||
|
||||
for core_file, core_template_file in core_files.items():
|
||||
core_file_manager.write(core_file, core_template_file, top_env)
|
||||
|
||||
file_manager.write('TypeDefault.h', TYPE_DEFAULT_H, top_env)
|
||||
file_manager.write('TypeDefault.cpp', TYPE_DEFAULT_CPP, top_env)
|
||||
|
||||
file_manager.write('Functions.h', FUNCTIONS_H, top_env)
|
||||
file_manager.write('Functions.cpp', FUNCTIONS_CPP, top_env)
|
||||
|
||||
file_manager.write('NativeFunctions.h', NATIVE_FUNCTIONS_H, top_env)
|
||||
|
||||
generate_per_op_registration(per_op_registrations)
|
||||
generate_schema_registration(schema_registrations)
|
||||
|
||||
file_manager.check_all_files_written()
|
||||
cuda_file_manager.check_all_files_written()
|
||||
|
||||
declare_outputs()
|
||||
if options.output_dependencies is not None:
|
||||
file_manager.write_outputs(options.output_dependencies)
|
||||
core_file_manager.write_outputs(options.output_dependencies + "-core")
|
||||
cuda_file_manager.write_outputs(options.output_dependencies + "-cuda")
|
||||
else:
|
||||
generate_outputs()
|
||||
|
|
@ -1,111 +0,0 @@
|
|||
# This script generates BackendSelectRegister.cpp which is being used for dispatching purposes.
|
||||
#
|
||||
# TLDR: most operators take one or more Tensors as arguments, and dispatch keys extracted from
|
||||
# these Tensors determine which kernel (operator implementation) the dispatcher actually invokes.
|
||||
# E.g., calling add() on two CUDA Tensors will dispatch to the CUDA implementation of add(),
|
||||
# and so on.
|
||||
#
|
||||
# But factory functions don't take Tensors, so we need to get dispatch keys from other arguments.
|
||||
# Rather than teaching the dispatcher how to extract dispatch keys from types besides Tensor, we
|
||||
# register an extra kernel for each factory op, under the `BackendSelect` dispatch key. This key
|
||||
# has higher precedence than dispatch keys for actual backends, so a BackendSelect kernel will
|
||||
# front-run other kernels registered for the same op.
|
||||
#
|
||||
# It's the responsibility of the BackendSelect factory kernels to extract the "real" dispatch
|
||||
# key from non-Tensor arguments, and redispatch using this key. Here, we generate implementations
|
||||
# that obtain the key from the TensorOptions argument that's passed to all Tensor factory ops.
|
||||
#
|
||||
# BackendSelectRegister.cpp will contain both the BackendSelect kernels and registrations for
|
||||
# all factory functions that have 'backend_select' flag in its native_functions.yaml definition.
|
||||
|
||||
from code_template import CodeTemplate
|
||||
from function_wrapper import gen_dispatch_key_init
|
||||
|
||||
GENERATED_COMMENT = CodeTemplate(
|
||||
"@" + "generated from ${filename}")
|
||||
|
||||
# See NOTE[UnboxedOnly] in function_wrapper.py
|
||||
UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\
|
||||
m.impl_UNBOXED("aten::${op_name_with_overload_name}", ${function_name});
|
||||
""")
|
||||
|
||||
FUNCTION_REGISTRATION = CodeTemplate("""\
|
||||
m.impl("aten::${op_name_with_overload_name}",
|
||||
c10::impl::hacky_wrapper_for_legacy_signatures<${schema_order_cpp_signature}>(
|
||||
TORCH_FN(${function_name})));
|
||||
""")
|
||||
|
||||
FUNCTION_DEFINITION = CodeTemplate("""\
|
||||
// ${schema_string}
|
||||
Tensor ${function_name}(${method_formals}) {
|
||||
static auto op = c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("aten::${name}", "${overload_name}")
|
||||
.typed<${function_cpp_signature}>();
|
||||
${dispatch_key_init}
|
||||
return op.callWithDispatchKey(_dk, ${function_actuals});
|
||||
}
|
||||
""")
|
||||
|
||||
|
||||
def needs_backend_select(declaration_option):
|
||||
# We register an op under the BackendSelect dispatch key
|
||||
# if a TensorOptions argument has been gathered from its declared args
|
||||
# We skip all the 'new_*' and '*_like' ops as they are special cased and avoid dispatching.
|
||||
# See TypeDefault.cpp
|
||||
if declaration_option['name'].endswith('_like') or declaration_option['name'].startswith('new_'):
|
||||
return False
|
||||
|
||||
return any(a.get('dynamic_type') == 'TensorOptions' for a in declaration_option['arguments'])
|
||||
|
||||
def register_backend_select_methods(declarations, template_path, file_manager):
|
||||
backend_select_method_definitions = []
|
||||
backend_select_function_registrations = []
|
||||
|
||||
for decl in declarations:
|
||||
for option in decl["options"]:
|
||||
if needs_backend_select(option):
|
||||
name = option['name']
|
||||
op_name_with_overload_name = option['name']
|
||||
if option.get('overload_name', '') != '':
|
||||
name = "{0}_{1}".format(name, option['overload_name'])
|
||||
op_name_with_overload_name = "{0}.{1}".format(op_name_with_overload_name, option['overload_name'])
|
||||
|
||||
if option['use_c10_dispatcher'] == 'full':
|
||||
func_reg = FUNCTION_REGISTRATION.substitute(schema_string=option['schema_string'],
|
||||
op_name_with_overload_name=op_name_with_overload_name,
|
||||
function_name=name,
|
||||
schema_order_cpp_signature=option['schema_order_cpp_signature'])
|
||||
else:
|
||||
assert option['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
|
||||
func_reg = UNBOXEDONLY_FUNCTION_REGISTRATION.substitute(schema_string=option['schema_string'],
|
||||
op_name_with_overload_name=op_name_with_overload_name,
|
||||
function_name=name)
|
||||
|
||||
dispatch_key_init = gen_dispatch_key_init('_dk', option['formals_list'])
|
||||
|
||||
# See NOTE[UnboxedOnly] in function_wrapper.py
|
||||
if option['use_c10_dispatcher'] == 'full':
|
||||
function_cpp_signature = option['schema_order_cpp_signature']
|
||||
function_actuals = option['schema_order_actuals']
|
||||
else:
|
||||
assert option['use_c10_dispatcher'] == 'with_codegenerated_unboxing_wrapper'
|
||||
function_cpp_signature = option['cpp_signature']
|
||||
function_actuals = option['actuals']
|
||||
method_def = FUNCTION_DEFINITION.substitute(function_name=name,
|
||||
schema_string=option['schema_string'],
|
||||
method_formals=option['formals_with_defaults'],
|
||||
name=option['name'],
|
||||
overload_name=option['overload_name'],
|
||||
dispatch_key_init=dispatch_key_init,
|
||||
function_cpp_signature=function_cpp_signature,
|
||||
function_actuals=function_actuals)
|
||||
|
||||
backend_select_function_registrations.append(func_reg)
|
||||
backend_select_method_definitions.append(method_def)
|
||||
|
||||
env = {}
|
||||
env['backend_select_method_definitions'] = backend_select_method_definitions
|
||||
env['backend_select_function_registrations'] = backend_select_function_registrations
|
||||
|
||||
env['generated_comment'] = GENERATED_COMMENT.substitute(filename=template_path)
|
||||
file_manager.write('BackendSelectRegister.cpp', template_path, env)
|
||||
|
|
@ -3166,7 +3166,7 @@
|
|||
CPU: roll_cpu
|
||||
CUDA: roll_cuda
|
||||
|
||||
# default int[] value [0,1] should not add space after comma, since native_parse.py uses ', ' to split args
|
||||
# default int[] value [0,1] should not add space after comma, since codegen parser uses ', ' to split args
|
||||
|
||||
- func: rot90(Tensor self, int k=1, int[] dims=[0,1]) -> Tensor
|
||||
use_c10_dispatcher: full
|
||||
|
|
@ -5773,7 +5773,7 @@
|
|||
CPU: foreach_tensor_add_scalar_kernel_slow
|
||||
CUDA: foreach_tensor_add_scalar_kernel_cuda
|
||||
|
||||
- func: _foreach_add_.Scalar(Tensor[](a!) self, Scalar scalar) -> ()
|
||||
- func: _foreach_add_.Scalar(Tensor(a!)[] self, Scalar scalar) -> ()
|
||||
device_guard: False
|
||||
variants: function
|
||||
dispatch:
|
||||
|
|
|
|||
|
|
@ -1,482 +0,0 @@
|
|||
from __future__ import print_function
|
||||
import re
|
||||
import yaml
|
||||
import pprint
|
||||
import sys
|
||||
import copy
|
||||
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader
|
||||
|
||||
# [temp translations]
|
||||
# We're currently incrementally moving from the custom func schema to the
|
||||
# JIT signature schema incrementally. This will reduce overall complexity
|
||||
# and increase compliance between these components. So for now we do simple
|
||||
# type translations to continue to emit the legacy func schema for further
|
||||
# processing by downstream tools. This will helps us avoid having to prematurely
|
||||
# change all downstream tools to detect these new types.
|
||||
def type_argument_translations(arg):
|
||||
type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
|
||||
name = ''
|
||||
if len(type_and_name) > 1:
|
||||
name = type_and_name[1]
|
||||
t = type_and_name[0]
|
||||
name = name.split('=')
|
||||
default = None
|
||||
nullable = False
|
||||
size = None # Only applies to int[\d+] and Tensor[\d+] arguments
|
||||
if len(name) > 1:
|
||||
default = name[1]
|
||||
name = name[0]
|
||||
|
||||
match = re.match(r'(Tensor.*)\((.+)\)(.*)', t)
|
||||
annotation = None
|
||||
if match:
|
||||
t = match.group(1) + match.group(3)
|
||||
annotation = match.group(2)
|
||||
|
||||
# XXX: is_nullable flag can only annotate entire type as optional type,
|
||||
# need to special case Generator? logic to make ? only available in jit
|
||||
# TODO: deprecate is_nullable global flag, and parse the type
|
||||
# to support annotating complicated types with optional annotation
|
||||
nullable = '?' in t
|
||||
|
||||
# This enables "Generator? x = None and translates to legacy
|
||||
# "Generator x = nullptr". See [temp translations].
|
||||
if t == 'Generator?' and default == 'None':
|
||||
t = 'Generator'
|
||||
default = 'c10::nullopt'
|
||||
# Enables Tensor[] by translating to legacy TensorList.
|
||||
elif t == 'Tensor[]' or t == 'Tensor?[]':
|
||||
t = 'TensorList'
|
||||
# Enables int[] by translating to legacy IntArrayRef.
|
||||
elif t == 'int[]':
|
||||
t = 'IntArrayRef'
|
||||
elif t == 'int[]?':
|
||||
t = 'IntArrayRef?'
|
||||
# Enables int by translating to legacy int64_t.
|
||||
elif t == 'int':
|
||||
t = 'int64_t'
|
||||
elif t == 'int?':
|
||||
t = 'int64_t?'
|
||||
elif t == 'int64_t':
|
||||
raise RuntimeError("Please use int and not int64_t. "
|
||||
"See [temp translations] for details.")
|
||||
elif t == 'int64_t?':
|
||||
raise RuntimeError("Please use int? and not int64_t?. "
|
||||
"See [temp translations] for details.")
|
||||
# Enables Dimname[] by translating to legacy DimnameList.
|
||||
elif t == 'Dimname[]':
|
||||
t = 'DimnameList'
|
||||
elif t == 'Dimname[]?':
|
||||
t = 'DimnameList?'
|
||||
# Enables float by translating to legacy double.
|
||||
elif t == 'float':
|
||||
t = 'double'
|
||||
elif t == 'float?':
|
||||
t = 'double?'
|
||||
elif t == 'float[]':
|
||||
t = 'ArrayRef<double>'
|
||||
elif t == 'float[]?':
|
||||
t = 'ArrayRef<double>?'
|
||||
# Enables str by translating to legacy std::string.
|
||||
elif t == 'str':
|
||||
t = 'std::string'
|
||||
elif t == 'double':
|
||||
raise RuntimeError("Please use float and not double. "
|
||||
"See [temp translations] for details.")
|
||||
# Enables int[x] by translating to legacy IntArrayRef[x]. See [temp translations]
|
||||
elif re.match(r'int\[(\d+)\]\?', t):
|
||||
match = re.match(r'int\[(\d+)\]\?', t)
|
||||
t = 'IntArrayRef'
|
||||
size = int(match.group(1))
|
||||
elif re.match(r'int\[(\d+)\]', t):
|
||||
match = re.match(r'int\[(\d+)\]', t)
|
||||
t = 'IntArrayRef'
|
||||
size = int(match.group(1))
|
||||
# Enables bool[x] by translating to legacy std::array<bool,x>. See [temp translations]
|
||||
elif re.match(r'bool\[(\d+)\]', t):
|
||||
match = re.match(r'bool\[(\d+)\]', t)
|
||||
t = 'std::array<bool,{}>'.format(match.group(1))
|
||||
elif re.match(r'std::array', t):
|
||||
raise RuntimeError("Please use array notation, e.g. bool[3] and not std::array."
|
||||
"See [temp translations] for details.")
|
||||
# Enables Dimname[x] by translating to DimnameList[x]. See [temp translations]
|
||||
elif re.match(r'Dimname\[(\d+)\]', t):
|
||||
match = re.match(r'Dimname\[(\d+)\]', t)
|
||||
t = 'DimnameList'
|
||||
size = int(match.group(1))
|
||||
|
||||
if not default:
|
||||
pass
|
||||
# This enables Tensor? x=None and translates to legacy
|
||||
# "Tensor? x={}". See [temp translations].
|
||||
elif t.startswith('Tensor?') and default == 'None':
|
||||
default = "{}"
|
||||
elif default == 'True':
|
||||
default = True
|
||||
elif default == 'False':
|
||||
default = False
|
||||
elif default == 'true':
|
||||
raise RuntimeError("Please use True and not true. "
|
||||
"See [temp translations] for details.")
|
||||
elif default == 'false':
|
||||
raise RuntimeError("Please use False and not false. "
|
||||
"See [temp translations] for details.")
|
||||
# Enables default argument [] by translating to legacy {}.
|
||||
# See [temp translations]
|
||||
elif default == '[]':
|
||||
default = '{}'
|
||||
# Enables lists by translating to legacy {.*}.
|
||||
# See [temp translations]
|
||||
elif re.match(r'\[.*\]', default):
|
||||
default = "{" + default[1:-1] + "}"
|
||||
elif default == 'None':
|
||||
default = 'c10::nullopt'
|
||||
# The JIT signature schema uses Mean, but in particular C++ needs
|
||||
# the legacy at::Reduction::Mean. So we'll continue emiting that until
|
||||
# we change this at either a JIT schema or C++ level.
|
||||
elif default == 'Mean':
|
||||
default = 'at::Reduction::Mean'
|
||||
elif default == 'contiguous_format':
|
||||
default = 'MemoryFormat::Contiguous'
|
||||
elif default == 'per_tensor_affine':
|
||||
default = 'QScheme::PER_TENSOR_AFFINE'
|
||||
else:
|
||||
try:
|
||||
default = int(default)
|
||||
except ValueError:
|
||||
try:
|
||||
default = float(default)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
return t, name, default, nullable, size, annotation
|
||||
|
||||
|
||||
def parse_arguments(args):
|
||||
arguments = []
|
||||
kwarg_only = False
|
||||
|
||||
if len(args.strip()) == 0:
|
||||
return arguments
|
||||
|
||||
# TODO: Use a real parser here; this will get bamboozled
|
||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||
for arg_idx, arg in enumerate(args.split(', ')):
|
||||
type_and_name = [a.strip() for a in arg.rsplit(' ', 1)]
|
||||
if type_and_name == ['*']:
|
||||
assert not kwarg_only
|
||||
kwarg_only = True
|
||||
continue
|
||||
|
||||
t, name, default, nullable, size, annotation = type_argument_translations(arg)
|
||||
|
||||
argument_dict = {'type': t.rstrip('?'), 'name': name, 'is_nullable': nullable, 'annotation': annotation}
|
||||
if size:
|
||||
argument_dict['size'] = size
|
||||
if default is not None:
|
||||
argument_dict['default'] = default
|
||||
if kwarg_only:
|
||||
argument_dict['kwarg_only'] = True
|
||||
arguments.append(argument_dict)
|
||||
|
||||
return arguments
|
||||
|
||||
def process_arguments(arguments, func_variants, declaration, func_return):
|
||||
is_out_fn = False
|
||||
arguments_out = []
|
||||
arguments_other = []
|
||||
for argument in arguments:
|
||||
if argument['type'] == "Tensor" and \
|
||||
argument['annotation'] and \
|
||||
re.match(r'^(.*!)$', argument['annotation']) and \
|
||||
argument.get('kwarg_only'):
|
||||
argument['output'] = True
|
||||
argument['kwarg_only'] = False
|
||||
arguments_out.append(argument)
|
||||
is_out_fn = True
|
||||
else:
|
||||
arguments_other.append(argument)
|
||||
|
||||
arguments = arguments_out + arguments_other
|
||||
|
||||
name = declaration['name']
|
||||
if is_out_fn:
|
||||
declaration['name'] += "_out"
|
||||
|
||||
# Reverse splat of TensorOptions
|
||||
# As we move towards the JIT function schema for native_functions.yaml we need to support
|
||||
# the expanded version of TensorOptions. For now we discover whether there are three
|
||||
# types and names of keyword arguments: "ScalarType dtype", "Layout layout" and "Device device"
|
||||
# Each, if set, must have default arguments set to long or float, strided and "cpu" respectively.
|
||||
# They must appear in this order and in this order only in order for us to be able to process them.
|
||||
# In the future we will get rid of this specific processing as downstream consumers start relying
|
||||
# less on the content of Declarations.yaml. If you want to support more than this you'll
|
||||
# potentially have to extend the JIT.
|
||||
|
||||
supported_topt_arguments = [
|
||||
[
|
||||
{'name': 'dtype', 'type': 'ScalarType', 'is_nullable': False, 'annotation': None},
|
||||
{'name': 'layout', 'type': 'Layout', 'is_nullable': False, 'annotation': None},
|
||||
{'name': 'device', 'type': 'Device', 'is_nullable': False, 'annotation': None},
|
||||
{'name': 'pin_memory', 'type': 'bool', 'is_nullable': False, 'annotation': None, 'default': False},
|
||||
]
|
||||
]
|
||||
supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[0]))
|
||||
for arg in supported_topt_arguments[1]:
|
||||
arg.update({'kwarg_only': True})
|
||||
supported_topt_arguments.append(copy.deepcopy(supported_topt_arguments[1]))
|
||||
for arg in supported_topt_arguments[2]:
|
||||
arg.update({'default': 'c10::nullopt', 'is_nullable': True})
|
||||
# add explicit support for what is needed for tril_indices / triu_indices
|
||||
supported_topt_arguments.append(
|
||||
[
|
||||
{'name': 'dtype', 'type': 'ScalarType', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'long', 'is_nullable': True},
|
||||
{'name': 'layout', 'type': 'Layout', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
{'name': 'device', 'type': 'Device', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
{'name': 'pin_memory', 'type': 'bool', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
]
|
||||
)
|
||||
supported_topt_arguments.append(
|
||||
[
|
||||
{'name': 'dtype', 'type': 'ScalarType', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
{'name': 'layout', 'type': 'Layout', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
{'name': 'device', 'type': 'Device', 'annotation': None, 'kwarg_only': True,
|
||||
'default': 'c10::nullopt', 'is_nullable': True},
|
||||
{'name': 'pin_memory', 'type': 'bool', 'annotation': None, 'kwarg_only': True,
|
||||
'default': False, 'is_nullable': True},
|
||||
]
|
||||
)
|
||||
|
||||
corresponding_topts = [
|
||||
{'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None},
|
||||
]
|
||||
corresponding_topts.append(corresponding_topts[0].copy())
|
||||
corresponding_topts[1]['kwarg_only'] = True
|
||||
corresponding_topts.append(corresponding_topts[1].copy())
|
||||
corresponding_topts[2]['default'] = '{}'
|
||||
corresponding_topts.append(
|
||||
{'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None,
|
||||
'kwarg_only': True, 'default': 'at::kLong'})
|
||||
corresponding_topts.append(
|
||||
{'type': 'TensorOptions', 'name': 'options', 'is_nullable': False, 'annotation': None,
|
||||
'kwarg_only': True})
|
||||
|
||||
def check_topt_representation(topt_representation):
|
||||
for idx, supported_topt in enumerate(supported_topt_arguments):
|
||||
matches = all(topt_representation[i] == topt for i, topt in enumerate(supported_topt))
|
||||
if matches:
|
||||
return corresponding_topts[idx]
|
||||
return None
|
||||
|
||||
def is_tensor_option(argument):
|
||||
return argument['name'] in ['dtype', 'layout', 'device', 'pin_memory']
|
||||
|
||||
new_arguments = []
|
||||
idx = 0
|
||||
while idx < len(arguments):
|
||||
argument = arguments[idx]
|
||||
number_of_arguments = len(supported_topt_arguments[0])
|
||||
if is_tensor_option(argument) and len(arguments) - idx >= number_of_arguments:
|
||||
topt_representation = []
|
||||
for i in range(number_of_arguments):
|
||||
argument = arguments[idx]
|
||||
if not is_tensor_option(argument):
|
||||
break
|
||||
topt_representation.append(argument)
|
||||
idx += 1
|
||||
if len(topt_representation) == number_of_arguments:
|
||||
merged_argument = check_topt_representation(topt_representation)
|
||||
assert merged_argument, \
|
||||
"Unsupported combination of TensorOptions {}, the only currently supported combinations are {}"\
|
||||
.format(str(topt_representation), str(supported_topt_arguments))
|
||||
new_arguments.append(merged_argument)
|
||||
else:
|
||||
new_arguments += topt_representation
|
||||
else:
|
||||
new_arguments.append(argument)
|
||||
idx += 1
|
||||
|
||||
arguments = new_arguments
|
||||
|
||||
# Sanity checks
|
||||
|
||||
# TODO: convention is that the ith-argument correspond to the i-th return, but it would
|
||||
# be better if we just named everything and matched by name.
|
||||
for arg_idx, argument in enumerate(arguments_out):
|
||||
assert argument['annotation'] == func_return[arg_idx]['annotation'], \
|
||||
"For func {} writeable keyword Tensor arguments need to have a matching return Tensor. Further, " \
|
||||
"the ith-argument needs to correspond to the i-th return.".format(name)
|
||||
|
||||
assert len(arguments_out) <= len(func_return), "func {} must return at least as many Tensors " \
|
||||
"as can be passed as output.".format(name)
|
||||
|
||||
if name.endswith('_out'):
|
||||
raise RuntimeError("Native function {} may not be suffixed with _out as we transition to a unified schema. "
|
||||
"Otherwise you will cause confusion amongst consumers of native functions.".format(name))
|
||||
|
||||
if is_out_fn and func_variants not in [[], 'function', ['function']]:
|
||||
raise RuntimeError("Native functions with output MUST be declared with only the function variant; "
|
||||
"e.g., variants: function; otherwise you will tickle a Python argument binding bug "
|
||||
"(which usually manifests itself as the result variable being undefined.) "
|
||||
"The culprit was: {}".format(name))
|
||||
if not is_out_fn:
|
||||
assert len(arguments_out) == 0, "func {} is not marked as output yet contains output " \
|
||||
"keyword arguments".format(name)
|
||||
|
||||
# TODO: Explicit checking for void is a hack and should disappear after a more
|
||||
# functionally complete implementation of Tensor aliases.
|
||||
if declaration['inplace'] and len(func_return) > 0:
|
||||
found_self = False
|
||||
for arg_idx, argument in enumerate(arguments):
|
||||
if argument['name'] == "self":
|
||||
assert argument['annotation'] and argument['annotation'].endswith("!"), \
|
||||
"Inplace function \"{}\" needs to annotate Tensor argument named self " \
|
||||
"as mutable.".format(name)
|
||||
found_self = True
|
||||
assert argument['annotation'] == func_return[arg_idx]['annotation'], \
|
||||
"Inplace function annotations of function {} need to match between " \
|
||||
"input and correponding output.".format(name)
|
||||
assert argument['name'] == func_return[arg_idx]['name'] or \
|
||||
argument['name'] == func_return[arg_idx]['name'] + "_return"
|
||||
assert argument['type'] == func_return[arg_idx]['type']
|
||||
assert found_self, "Inplace function \"{}\" needs Tensor argument named self.".format(name)
|
||||
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_return_arguments(return_decl, inplace, func_decl):
|
||||
arguments = []
|
||||
if return_decl == '()':
|
||||
return arguments
|
||||
|
||||
# TODO: Use a real parser here; this will get bamboozled
|
||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||
if return_decl[0] == '(' and return_decl[-1] == ')':
|
||||
return_decl = return_decl[1:-1]
|
||||
|
||||
multiple_args = len(return_decl.split(', ')) > 1
|
||||
for arg_idx, arg in enumerate(return_decl.split(', ')):
|
||||
t, name, default, nullable, size, annotation = type_argument_translations(arg)
|
||||
# name of arguments and name of return sometimes have collision
|
||||
# in this case, we rename the return name to <name>_return.
|
||||
return_name = name
|
||||
if name in func_decl['func'].split('->')[0]:
|
||||
return_name = name + "_return"
|
||||
argument_dict = {'type': t, 'name': return_name, 'annotation': annotation}
|
||||
if name:
|
||||
# See Note [field_name versus name]
|
||||
argument_dict['field_name'] = name
|
||||
else:
|
||||
if t == "Tensor" and inplace:
|
||||
assert annotation and annotation.endswith("!"), \
|
||||
"Return Tensor of function \"{}\" flagged as inplace needs to be " \
|
||||
"annotated as mutable".format(func_decl['func'])
|
||||
argument_dict['name'] = 'self'
|
||||
elif t == "TensorList" and inplace:
|
||||
assert annotation and annotation.endswith("!"), \
|
||||
"Return TensorList of function \"{}\" flagged as inplace needs to be " \
|
||||
"annotated as mutable".format(func_decl['func'])
|
||||
argument_dict['name'] = 'self'
|
||||
else:
|
||||
argument_dict['name'] = 'result' if not multiple_args else 'result' + str(arg_idx)
|
||||
argument_dict['output'] = True
|
||||
arguments.append(argument_dict)
|
||||
return arguments
|
||||
|
||||
|
||||
def parse_dispatch(name, dispatch):
|
||||
"""
|
||||
Parse a dictionary like {"CPU, CUDA": "blah"}
|
||||
into {"CPU": "blah", "CUDA": "blah"}
|
||||
"""
|
||||
if not isinstance(dispatch, dict):
|
||||
return dispatch
|
||||
r = {}
|
||||
for old_k, v in dispatch.items():
|
||||
ks = old_k.split(',')
|
||||
for k in ks:
|
||||
k = k.strip()
|
||||
assert k not in r, "{}, {}".format(name, k)
|
||||
r[k] = v
|
||||
return r
|
||||
|
||||
|
||||
def parse_native_yaml(path):
|
||||
with open(path, 'r') as f:
|
||||
return yaml.load(f, Loader=Loader)
|
||||
|
||||
|
||||
def propagate_field_names(output_arguments, return_arguments):
|
||||
if output_arguments:
|
||||
for i, r in enumerate(return_arguments):
|
||||
if 'field_name' in r:
|
||||
output_arguments[i]['field_name'] = r['field_name']
|
||||
|
||||
|
||||
def run(paths):
|
||||
declarations = []
|
||||
for path in paths:
|
||||
for func in parse_native_yaml(path):
|
||||
declaration = {'mode': 'native'}
|
||||
try:
|
||||
declaration['schema_string'] = "aten::" + func['func']
|
||||
if '->' in func['func']:
|
||||
func_decl, return_decl = [x.strip() for x in func['func'].split('->')]
|
||||
else:
|
||||
raise Exception('Expected return declaration')
|
||||
fn_name, arguments = func_decl.split('(', 1)
|
||||
if '.' in fn_name:
|
||||
fn_name, overload_name = fn_name.split('.', 1)
|
||||
else:
|
||||
overload_name = ''
|
||||
assert arguments[-1] == ")", "Expecting closing ) for {}".format(func['func'])
|
||||
arguments = arguments[:-1] # Expect closing )
|
||||
declaration['name'] = func.get('name', fn_name)
|
||||
declaration['operator_name'] = func.get('name', fn_name)
|
||||
declaration['overload_name'] = func.get('overload_name', overload_name)
|
||||
declaration['inplace'] = re.search('(^__i|[^_]_$)', fn_name) is not None
|
||||
return_arguments = parse_return_arguments(return_decl, declaration['inplace'], func)
|
||||
schema_order_arguments = parse_arguments(arguments)
|
||||
arguments = process_arguments(schema_order_arguments, func.get('variants', []), declaration, return_arguments)
|
||||
output_arguments = [x for x in arguments if x.get('output')]
|
||||
propagate_field_names(output_arguments, return_arguments)
|
||||
declaration['return'] = return_arguments if len(output_arguments) == 0 else output_arguments
|
||||
declaration['variants'] = func.get('variants', ['function'])
|
||||
declaration['matches_jit_signature'] = func.get('matches_jit_signature', True)
|
||||
declaration['cpu_half'] = func.get('cpu_half', False)
|
||||
declaration['cpu_bfloat16'] = func.get('cpu_bfloat16', False)
|
||||
declaration['cuda_bfloat16'] = func.get('cuda_bfloat16', False)
|
||||
declaration['cpu_bool'] = func.get('cpu_bool', False)
|
||||
declaration['cuda_bool'] = func.get('cuda_bool', False)
|
||||
declaration['deprecated'] = func.get('deprecated', False)
|
||||
declaration['device_guard'] = func.get('device_guard', True)
|
||||
declaration['use_c10_dispatcher'] = func.get('use_c10_dispatcher', 'with_codegenerated_unboxing_wrapper')
|
||||
assert declaration['use_c10_dispatcher'] in ['with_codegenerated_unboxing_wrapper', 'full']
|
||||
declaration['manual_kernel_registration'] = func.get('manual_kernel_registration', False)
|
||||
declaration['category_override'] = func.get('category_override', '')
|
||||
declaration['arguments'] = func.get('arguments', arguments)
|
||||
declaration['schema_order_arguments'] = func.get('schema_order_arguments', schema_order_arguments)
|
||||
declaration['type_method_definition_dispatch'] = \
|
||||
parse_dispatch(fn_name, func.get('dispatch', declaration['name']))
|
||||
declaration['python_module'] = func.get('python_module', '')
|
||||
declarations.append(declaration)
|
||||
except Exception as e:
|
||||
msg = '''Exception raised in processing function:
|
||||
{func}
|
||||
Generated partial declaration:
|
||||
{decl}'''.format(func=pprint.pformat(func), decl=pprint.pformat(declaration))
|
||||
print(msg, file=sys.stderr)
|
||||
raise e
|
||||
|
||||
return declarations
|
||||
|
|
@ -1,62 +0,0 @@
|
|||
# Loss functions
|
||||
|
||||
- name: _thnn_multi_margin_loss(Tensor self, LongTensor target, Scalar p, Scalar margin, Tensor? weight, int64_t reduction)
|
||||
cname: MultiMarginCriterion
|
||||
|
||||
- name: _thnn_multilabel_margin_loss(Tensor self, LongTensor target, int64_t reduction=at::Reduction::Mean)
|
||||
cname: MultiLabelMarginCriterion
|
||||
buffers: [is_target]
|
||||
CUDA:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
|
||||
- name: _thnn_nll_loss(Tensor self, LongTensor target, Tensor? weight, int64_t reduction, int64_t ignore_index)
|
||||
cname: ClassNLLCriterion
|
||||
buffers: [total_weight]
|
||||
CPU:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
CUDA:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
|
||||
- name: _thnn_nll_loss2d(Tensor self, LongTensor target, Tensor? weight, int64_t reduction, int64_t ignore_index)
|
||||
cname: SpatialClassNLLCriterion
|
||||
buffers: [total_weight]
|
||||
CUDA:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
|
||||
# Activation functions
|
||||
|
||||
- name: _thnn_glu(Tensor self, int64_t dim)
|
||||
cname: GatedLinear
|
||||
|
||||
- name: _thnn_log_sigmoid(Tensor self)
|
||||
cname: LogSigmoid
|
||||
buffers: [buffer]
|
||||
|
||||
# NOTE: we treat noise as an input (it's really a buffer) because the codegen
|
||||
# can't handle in-place functions that have buffers
|
||||
- name: _thnn_rrelu_with_noise(Tensor self, Tensor noise, Scalar lower, Scalar upper, bool training, Generator? generator=None)
|
||||
cname: RReLU
|
||||
has_inplace: True
|
||||
|
||||
# Convolutions
|
||||
|
||||
- name: _thnn_conv2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding)
|
||||
cname: SpatialConvolutionMM
|
||||
buffers: [columns, ones]
|
||||
CPU:
|
||||
forward_scalar_types: ['Float', 'Double', 'Long', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'BFloat16']
|
||||
CUDA:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
|
||||
- name: _thnn_conv_depthwise2d(Tensor self, Tensor weight, IntArrayRef[2] kernel_size, Tensor? bias, IntArrayRef[2] stride, IntArrayRef[2] padding, IntArrayRef[2] dilation)
|
||||
cname: SpatialDepthwiseConvolution
|
||||
buffers: []
|
||||
CUDA:
|
||||
forward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
backward_scalar_types: ['Float', 'Double', 'Half', 'BFloat16']
|
||||
|
|
@ -1,388 +0,0 @@
|
|||
import copy
|
||||
import re
|
||||
import common_with_cwrap
|
||||
import yaml
|
||||
from collections import OrderedDict, defaultdict
|
||||
|
||||
try:
|
||||
# use faster C loader if available
|
||||
from yaml import CLoader as Loader
|
||||
except ImportError:
|
||||
from yaml import Loader
|
||||
|
||||
|
||||
# matches `name`, `params` in `name(params)`
|
||||
NAME_PARAM_REGEX = r'(\w+)\((.*)\)'
|
||||
|
||||
|
||||
def argument_to_declaration(param, func=None):
|
||||
arg = {}
|
||||
arg['type'], name = param.split(' ')
|
||||
if (arg['type'].endswith('?')):
|
||||
arg['is_nullable'] = True
|
||||
arg['type'] = arg['type'].rstrip('?')
|
||||
if arg['type'] == 'Tensor':
|
||||
arg['type'] = 'THTensor*'
|
||||
elif arg['type'] == 'LongTensor':
|
||||
arg['type'] = 'THIndexTensor*'
|
||||
elif arg['type'] == 'Scalar':
|
||||
arg['type'] = 'accreal'
|
||||
elif arg['type'] == 'Generator':
|
||||
arg['type'] = 'c10::optional<at::Generator>'
|
||||
|
||||
match = re.match(r'IntArrayRef\[(\d+)\]', arg['type'])
|
||||
if match:
|
||||
arg['type'] = 'IntArrayRef'
|
||||
arg['size'] = int(match.group(1))
|
||||
|
||||
if '=' in name:
|
||||
name, default = name.split('=')
|
||||
arg['optional'] = True
|
||||
arg['default'] = default
|
||||
arg['name'] = name
|
||||
|
||||
return arg
|
||||
|
||||
|
||||
def output_arguments(thnn_function):
|
||||
cname = thnn_function.name
|
||||
output_args = []
|
||||
|
||||
# function_wrapper expects everything in a declaration to be in
|
||||
# the base type (i.e. THTensor*), but if we pull a THCUNN only
|
||||
# implementation, it will have THCTensor* as the arg type. So we
|
||||
# strip the THC here before returning
|
||||
def map_to_th_type(t):
|
||||
if t.startswith('THC'):
|
||||
t = t.replace('THC', 'TH')
|
||||
return t
|
||||
|
||||
def is_output_arg(arg_name, func_name):
|
||||
if arg_name == 'output' and 'updateOutput' in cname:
|
||||
return True
|
||||
if name in {'gradInput', 'gradWeight', 'gradBias', 'gradGrid'}:
|
||||
return True
|
||||
if arg_name == 'indices' and 'updateOutput' in cname and 'Unpool' not in cname:
|
||||
# indices is an output argument in pooling and an input in unpooling
|
||||
return True
|
||||
return False
|
||||
|
||||
for arg in thnn_function.arguments:
|
||||
name = arg.name
|
||||
if is_output_arg(name, cname):
|
||||
desc = {
|
||||
'type': map_to_th_type(arg.type),
|
||||
'name': camel_to_snake(name),
|
||||
'output': True,
|
||||
}
|
||||
if name.startswith('grad_'):
|
||||
desc['is_nullable'] = True
|
||||
output_args.append(desc)
|
||||
return output_args
|
||||
|
||||
|
||||
def get_return(args):
|
||||
indices = [str(idx) for idx, arg in enumerate(args) if arg.get('output')]
|
||||
return 'argument {}'.format(','.join(indices))
|
||||
|
||||
|
||||
ARGUMENT_MAPPINGS = {
|
||||
'k': 'kernel_size',
|
||||
'd': 'stride',
|
||||
'pad': 'padding',
|
||||
'p': 'padding',
|
||||
'o': 'output_size',
|
||||
'osize': 'output_size',
|
||||
'output': 'output_size', # as a prefix e.g. outputW
|
||||
'isize': 'input_size',
|
||||
'dilation': 'dilation',
|
||||
'adj': 'output_padding',
|
||||
'a': 'output_padding',
|
||||
}
|
||||
|
||||
DIMENSION_OFFSET = {
|
||||
'width': -1,
|
||||
'height': -2,
|
||||
'B': 0,
|
||||
'C': 1,
|
||||
'W': -1,
|
||||
'H': -2,
|
||||
'T': -3,
|
||||
'left': 0,
|
||||
'right': 1,
|
||||
'top': 2,
|
||||
'bottom': 3,
|
||||
'front': 4,
|
||||
'back': 5,
|
||||
}
|
||||
|
||||
SUBSTITUTIONS = {
|
||||
'input': 'self',
|
||||
'weights': 'weight',
|
||||
'train': 'training',
|
||||
'val': 'value',
|
||||
'lambda': 'lambd',
|
||||
'negval': 'negative_slope',
|
||||
}
|
||||
|
||||
|
||||
def camel_to_snake(name):
|
||||
# from https://stackoverflow.com/questions/1175208/elegant-python-function-to-convert-camelcase-to-snake-case
|
||||
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
|
||||
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()
|
||||
|
||||
|
||||
def get_thnn_args(thnn_function, params, inplace):
|
||||
params_by_name = {p['name']: p for p in params}
|
||||
|
||||
def arg_expr(prefix, suffix):
|
||||
# e.g kW, kH
|
||||
name = ARGUMENT_MAPPINGS[prefix]
|
||||
if name not in params_by_name:
|
||||
raise RuntimeError('missing arg "{}" in {}'.format(name, thnn_function.name))
|
||||
param = params_by_name[name]
|
||||
if param['type'] == 'IntArrayRef' and 'size' in param:
|
||||
name = name + '_'
|
||||
# NB: We calculate the dimension based on the name of
|
||||
# the argument, not its positional order. This means
|
||||
# that we may reorder arguments to get them in
|
||||
# the right place; e.g., if a THNN implementation
|
||||
# has arguments in the order kernelW, kernelH, we
|
||||
# will generate a caller that is kernel[1], kernel[0]
|
||||
# to order them in the correct way.
|
||||
index = DIMENSION_OFFSET[suffix]
|
||||
if index < 0:
|
||||
index += param['size']
|
||||
expr = '{}[{}]'.format(name, index)
|
||||
return {'type': 'EXPRESSION', 'name': expr}
|
||||
|
||||
thnn_args = []
|
||||
for arg in thnn_function.arguments:
|
||||
name = arg.name
|
||||
if name == 'state':
|
||||
continue
|
||||
if inplace and name == 'output':
|
||||
name = 'self'
|
||||
aten_name = camel_to_snake(SUBSTITUTIONS.get(name, name))
|
||||
parts = aten_name.split('_')
|
||||
if aten_name in params_by_name:
|
||||
param = params_by_name[aten_name]
|
||||
if arg.is_optional:
|
||||
param['is_nullable'] = True
|
||||
thnn_args.append(copy.deepcopy(param))
|
||||
elif len(parts) == 2 and parts[0] in ARGUMENT_MAPPINGS and parts[1] in DIMENSION_OFFSET:
|
||||
# e.g. pad_left
|
||||
thnn_args.append(arg_expr(parts[0], parts[1]))
|
||||
elif name[-1] in DIMENSION_OFFSET and name[:-1] in ARGUMENT_MAPPINGS:
|
||||
# e.g kW, kH
|
||||
thnn_args.append(arg_expr(name[:-1], name[-1]))
|
||||
elif name == 'owidth' or name == 'oheight':
|
||||
thnn_args.append(arg_expr(name[0], name[1:]))
|
||||
elif name == 'scale':
|
||||
thnn_args.append({'type': 'EXPRESSION', 'name': '1'})
|
||||
elif name == 'inplace':
|
||||
thnn_args.append({'type': 'EXPRESSION', 'name': str(inplace).lower()})
|
||||
else:
|
||||
raise RuntimeError("{}: can't find binding for '{}'"
|
||||
.format(thnn_function.name, name))
|
||||
return thnn_args
|
||||
|
||||
|
||||
def remove_unused_args(args, thnn_args):
|
||||
"""Returns the subset of args whose name appears in thnn_args"""
|
||||
def clean_name(name):
|
||||
name = name[:name.index('[')] if '[' in name else name
|
||||
if name.endswith('_'):
|
||||
name = name[:-1]
|
||||
return name
|
||||
uses = set([clean_name(arg['name']) for arg in thnn_args])
|
||||
uses.add('output_mask')
|
||||
args = [arg for arg in args if arg['name'] in uses]
|
||||
for arg in args:
|
||||
if 'default' in arg:
|
||||
del arg['default']
|
||||
return args
|
||||
|
||||
|
||||
def unique_args(argslist):
|
||||
result = []
|
||||
seen = set()
|
||||
for args in argslist:
|
||||
for arg in args:
|
||||
if arg['name'] in seen:
|
||||
continue
|
||||
seen.add(arg['name'])
|
||||
result.append(arg)
|
||||
return result
|
||||
|
||||
|
||||
def function_info(name, arguments, cimpls, buffers, backends, inplace, backend_types):
|
||||
"""
|
||||
cimpls contains information use to call into THNN:
|
||||
cname: THNN function name
|
||||
arguments: arguments to functional call
|
||||
condition: [optional] guard around call
|
||||
"""
|
||||
return {
|
||||
'mode': 'NN',
|
||||
'name': name,
|
||||
'cpu_bfloat16': True if backend_types is not None and 'CPU' in backend_types and
|
||||
'BFloat16' in backend_types['CPU'] else False,
|
||||
'cuda_bfloat16': True if backend_types is not None and 'CUDA' in backend_types and
|
||||
'BFloat16' in backend_types['CUDA'] else False,
|
||||
'backend_types': backend_types,
|
||||
'arguments': arguments,
|
||||
'schema_order_arguments': copy.deepcopy(arguments),
|
||||
'return': 'argument 0' if inplace else get_return(arguments),
|
||||
'buffers': buffers,
|
||||
'backends': backends,
|
||||
'cimpls': cimpls,
|
||||
'variants': ['function'],
|
||||
}
|
||||
|
||||
def base_declaration(func, thnn_function, backends, backend_types, inplace=False):
|
||||
"""Creates the NN function without any buffers in it's signature"""
|
||||
name, params = re.match(NAME_PARAM_REGEX, func['name']).groups()
|
||||
if inplace:
|
||||
name += '_'
|
||||
params = params.split(', ')
|
||||
arguments = [argument_to_declaration(a, func) for a in params]
|
||||
if not inplace:
|
||||
arguments += output_arguments(thnn_function)
|
||||
buffers = [argument_to_declaration('Tensor ' + buf)
|
||||
for buf in func.get('buffers', [])]
|
||||
|
||||
return function_info(name, arguments, None, buffers, backends, inplace, backend_types)
|
||||
|
||||
def forward_declaration(base, thnn_function, backend_types, inplace=False):
|
||||
name = '{}_forward'.format(base['name'])
|
||||
if inplace:
|
||||
name += '_'
|
||||
|
||||
arguments = [copy.deepcopy(arg) for arg in base['arguments']
|
||||
if not arg.get('output')]
|
||||
|
||||
arguments += output_arguments(thnn_function)
|
||||
for buffer in base['buffers']:
|
||||
buffer = copy.deepcopy(buffer)
|
||||
buffer['output'] = True
|
||||
arguments.append(buffer)
|
||||
|
||||
thnn_args = get_thnn_args(thnn_function, arguments, inplace)
|
||||
arguments = remove_unused_args(arguments, thnn_args)
|
||||
cimpl = {'cname': thnn_function.name, 'arguments': thnn_args}
|
||||
|
||||
return function_info(name, arguments, [cimpl], [], base['backends'], inplace, backend_types)
|
||||
|
||||
def backward_declaration(base, thnn_functions, backend_types):
|
||||
name = '{}_backward'.format(base['name'])
|
||||
|
||||
arguments = []
|
||||
arguments.append({'type': 'THTensor*', 'name': 'grad_output'})
|
||||
arguments += [copy.deepcopy(arg) for arg in base['arguments']
|
||||
if arg['name'] != 'inplace']
|
||||
arguments += base['buffers']
|
||||
|
||||
# outputs from the forward may be inputs to the backwards
|
||||
for arg in arguments:
|
||||
if 'output' in arg:
|
||||
del arg['output']
|
||||
|
||||
arguments += unique_args([output_arguments(f) for f in thnn_functions])
|
||||
|
||||
def initialize_output_arg(arg):
|
||||
# the mask array<bool, N> specifies which return values to compute
|
||||
arg['mask'] = True
|
||||
arg['is_nullable'] = True
|
||||
|
||||
is_batch_norm_backward = '_backward' in thnn_functions[0].name
|
||||
grad_params = []
|
||||
if len(thnn_functions) > 1 or is_batch_norm_backward:
|
||||
for arg in arguments:
|
||||
if arg.get('output', False):
|
||||
initialize_output_arg(arg)
|
||||
if 'Tensor' in arg['type'] and arg['name'].startswith('grad_') and \
|
||||
'input' not in arg['name'] and 'output' not in arg['name']:
|
||||
grad_params.append(arg['name'])
|
||||
|
||||
thnn_args = [get_thnn_args(f, arguments, False) for f in thnn_functions]
|
||||
arguments = remove_unused_args(arguments, unique_args(thnn_args))
|
||||
cimpls = []
|
||||
|
||||
def get_condition(func):
|
||||
# only call into the THNN functions if the output args are not null
|
||||
if '_updateGradInput' in func.name:
|
||||
return 'grad_input_'
|
||||
if '_accGradParameters' in func.name:
|
||||
return ' || '.join(p + '_' for p in grad_params)
|
||||
return None
|
||||
|
||||
for func, args in zip(thnn_functions, thnn_args):
|
||||
cimpl = {'cname': func.name, 'arguments': args}
|
||||
if len(thnn_functions) > 1:
|
||||
cimpl['condition'] = get_condition(func)
|
||||
cimpls.append(cimpl)
|
||||
|
||||
output_args = [arg for arg in arguments if arg.get('output', False)]
|
||||
|
||||
return function_info(name, arguments, cimpls, [], base['backends'], False, backend_types)
|
||||
|
||||
|
||||
def parse_nn_yaml(filename):
|
||||
with open(filename, 'r') as f:
|
||||
return yaml.load(f, Loader=Loader)
|
||||
|
||||
|
||||
include_only = '(updateOutput|updateGradInput|accGradParameters|backward)$'
|
||||
exclude = 'LookupTable'
|
||||
|
||||
|
||||
def run(paths):
|
||||
function_backends = defaultdict(list)
|
||||
header_functions = OrderedDict()
|
||||
|
||||
headers = [p for p in paths if p.endswith('.h')]
|
||||
yamls = [p for p in paths if p.endswith('.yaml')]
|
||||
|
||||
for path in headers:
|
||||
backend = 'CUDA' if re.search('THCU', path) else 'CPU'
|
||||
for func in common_with_cwrap.parse_header(path):
|
||||
if re.search(include_only, func.name) is None or re.search(exclude, func.name) is not None:
|
||||
continue
|
||||
function_backends[func.name].append(backend)
|
||||
if func.name not in header_functions:
|
||||
header_functions[func.name] = func
|
||||
|
||||
bwd_suffixes = ['_updateGradInput', '_accGradParameters', '_backward']
|
||||
|
||||
declarations = []
|
||||
for path in yamls:
|
||||
for func in parse_nn_yaml(path):
|
||||
cname = func['cname']
|
||||
backends = function_backends[cname + '_updateOutput']
|
||||
|
||||
fwd_function = header_functions[cname + '_updateOutput']
|
||||
bwd_functions = []
|
||||
for suffix in bwd_suffixes:
|
||||
if cname + suffix in header_functions:
|
||||
bwd_functions.append(header_functions[cname + suffix])
|
||||
|
||||
default_scalar_types = ['Float', 'Double', 'Half'] # Half will be stripped for CPU backend
|
||||
forward_backend_types = {}
|
||||
backward_backend_types = {}
|
||||
for backend in backends:
|
||||
backend_props = func.get(backend, {})
|
||||
forward_backend_types[backend] = backend_props.get('forward_scalar_types', default_scalar_types)
|
||||
backward_backend_types[backend] = backend_props.get('backward_scalar_types', default_scalar_types)
|
||||
|
||||
base = base_declaration(func, fwd_function, backends, None)
|
||||
declarations.append(forward_declaration(base, fwd_function, forward_backend_types))
|
||||
if bwd_functions:
|
||||
declarations.append(backward_declaration(base, bwd_functions, backward_backend_types))
|
||||
|
||||
|
||||
if func.get('has_inplace', False):
|
||||
declarations.append(base_declaration(func, fwd_function, backends, forward_backend_types, True))
|
||||
declarations.append(forward_declaration(base, fwd_function, forward_backend_types, True))
|
||||
|
||||
return declarations
|
||||
|
|
@ -1,213 +0,0 @@
|
|||
import re
|
||||
from copy import deepcopy
|
||||
from function_wrapper import TYPE_FORMAL_GENERIC
|
||||
import common_with_cwrap
|
||||
|
||||
type_map = {
|
||||
'floating_point': [
|
||||
'Float',
|
||||
'Double',
|
||||
'Half',
|
||||
'BFloat16',
|
||||
],
|
||||
'integral': [
|
||||
'Byte',
|
||||
'Char',
|
||||
'Short',
|
||||
'Int',
|
||||
'Long',
|
||||
'Bool',
|
||||
],
|
||||
'quantized': [
|
||||
'QInt8',
|
||||
'QUInt8',
|
||||
'QInt32',
|
||||
]
|
||||
}
|
||||
|
||||
all_types = type_map['floating_point'] + type_map['integral'] + type_map['quantized']
|
||||
type_map['all'] = all_types
|
||||
|
||||
all_backends = ['CPU', 'CUDA', 'SparseCPU', 'SparseCUDA', 'MkldnnCPU', 'QuantizedCPU', 'QuantizedCUDA', 'Vulkan']
|
||||
default_backends = ['CPU', 'CUDA']
|
||||
|
||||
|
||||
def process_types_and_backends(option):
|
||||
# if specific pairs were not listed, then enumerate them
|
||||
# based on the backend and type attributes
|
||||
# if backend or type is not defined, it is assumed to be all of them
|
||||
if 'backend_types' not in option:
|
||||
backends = option.get('backends', default_backends)
|
||||
if isinstance(option.get('type_method_definition_dispatch'), dict):
|
||||
backends = option.get('type_method_definition_dispatch').keys()
|
||||
backends = set(backends)
|
||||
|
||||
backend_types = {}
|
||||
for backend in backends:
|
||||
if backend in ('QuantizedCPU', 'QuantizedCUDA'):
|
||||
backend_types[backend] = type_map['quantized']
|
||||
else:
|
||||
backend_types[backend] = option.get('types', all_types)
|
||||
else:
|
||||
backend_types = option['backend_types']
|
||||
|
||||
# expand type alias (integral, floating_point, all)
|
||||
def expand(types):
|
||||
ret = []
|
||||
for t in types:
|
||||
if t in type_map:
|
||||
ret.extend(type_map[t])
|
||||
else:
|
||||
assert(t in all_types)
|
||||
ret.append(t)
|
||||
return ret
|
||||
|
||||
for backend in backend_types.keys():
|
||||
assert backend in all_backends, "{} {}".format(backend, option['name'])
|
||||
backend_types[backend] = set(expand(backend_types[backend]))
|
||||
|
||||
# special case remove Half for cpu unless it is explicitly enabled
|
||||
if not option.get('cpu_half', False):
|
||||
if 'CPU' in backend_types:
|
||||
backend_types['CPU'].discard('Half')
|
||||
|
||||
# special case remove BFloat16 for cpu and cuda unless it is explicitly enabled
|
||||
if not option.get('cpu_bfloat16', False):
|
||||
if 'CPU' in backend_types:
|
||||
backend_types['CPU'].discard('BFloat16')
|
||||
|
||||
if not option.get('cuda_bfloat16', False):
|
||||
if 'CUDA' in backend_types:
|
||||
backend_types['CUDA'].discard('BFloat16')
|
||||
|
||||
# special cases remove bool for cpu and cuda unless it is explicitly enabled
|
||||
if not option.get('cpu_bool', False):
|
||||
if 'CPU' in backend_types:
|
||||
backend_types['CPU'].discard('Bool')
|
||||
|
||||
if not option.get('cuda_bool', False):
|
||||
if 'CUDA' in backend_types:
|
||||
backend_types['CUDA'].discard('Bool')
|
||||
|
||||
# sort the result for easy reading
|
||||
for backend in backend_types.keys():
|
||||
backend_types[backend] = sorted(backend_types[backend])
|
||||
option['backend_types'] = backend_types
|
||||
|
||||
|
||||
def exclude(declaration):
|
||||
return 'only_register' in declaration or declaration.get('name') == 'ndimension'
|
||||
|
||||
|
||||
def add_variants(option):
|
||||
option.setdefault('variants', ['method'])
|
||||
|
||||
# if we have 'output' arguments, generate a variant where
|
||||
# we mark oututs as allocate = True, and where the method variant
|
||||
# is disabled...
|
||||
|
||||
|
||||
def handle_outputs_taken_as_arguments(options):
|
||||
new_options = []
|
||||
|
||||
def is_nullable(arg):
|
||||
return (arg['type'] in {'THIntegerTensor*', 'THTensor*'} and
|
||||
arg.get('default', '') in {None, 'NULL', 'nullptr'})
|
||||
|
||||
def should_generate_out_variant(option):
|
||||
if 'function' in option['variants'] and option['mode'] != 'native':
|
||||
# don't generate _out variants for in-place functions
|
||||
return re.search('(^__i|[^_]_$)', option['api_name']) is None
|
||||
return False
|
||||
|
||||
for option in options:
|
||||
for arg in option['arguments']:
|
||||
# mark arguments which can be null
|
||||
if is_nullable(arg):
|
||||
arg['is_nullable'] = True
|
||||
|
||||
if any('output' in arg for arg in option['arguments']):
|
||||
allocate_option = deepcopy(option)
|
||||
# the allocating option needs to be marked
|
||||
for arg in allocate_option['arguments']:
|
||||
if 'output' in arg:
|
||||
arg['allocate'] = True
|
||||
|
||||
# the original option, which takes arguments for the results,
|
||||
# is no longer a method, and has _out added to indicte it takes
|
||||
# output arguments
|
||||
if should_generate_out_variant(option):
|
||||
if 'method' in option['variants']:
|
||||
option['variants'].remove('method')
|
||||
option['api_name'] += '_out'
|
||||
new_options.append(option)
|
||||
|
||||
new_options.append(allocate_option)
|
||||
else:
|
||||
new_options.append(option)
|
||||
return new_options
|
||||
|
||||
|
||||
def sanitize_return(option):
|
||||
ret = option['return']
|
||||
m = re.match(r'argument (\d+(,\d+)*)', ret)
|
||||
if m is not None:
|
||||
arguments = [int(x) for x in m.group(1).split(',')]
|
||||
option['return'] = {'kind': 'arguments', 'arguments': arguments}
|
||||
elif ret == 'self':
|
||||
option['return'] = {'kind': 'arguments', 'arguments': []}
|
||||
for i, x in enumerate(option['arguments']):
|
||||
if x['name'] == 'self':
|
||||
option['return']['arguments'].append(i)
|
||||
break
|
||||
else:
|
||||
option['return'] = {'kind': 'type', 'type': option['return']}
|
||||
|
||||
|
||||
def set_mode(option):
|
||||
option['mode'] = option.get('mode', 'TH')
|
||||
|
||||
|
||||
def is_extended_method(option):
|
||||
if 'method' in option['variants']:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
def run(declarations):
|
||||
declarations = [d for d in declarations if not exclude(d)]
|
||||
non_extended_methods = set()
|
||||
for declaration in declarations:
|
||||
common_with_cwrap.set_declaration_defaults(declaration)
|
||||
declaration['options'] = [deepcopy(o) for o in declaration['options']]
|
||||
declaration['options'] = common_with_cwrap.filter_unique_options(
|
||||
declaration['options'],
|
||||
allow_kwarg=False,
|
||||
type_to_signature=TYPE_FORMAL_GENERIC,
|
||||
remove_self=True)
|
||||
|
||||
common_with_cwrap.sort_by_number_of_args(declaration)
|
||||
|
||||
for option in declaration['options']:
|
||||
set_mode(option)
|
||||
if option['mode'] != 'native':
|
||||
sanitize_return(option)
|
||||
process_types_and_backends(option)
|
||||
add_variants(option)
|
||||
if not is_extended_method(option):
|
||||
non_extended_methods.add(option['api_name'])
|
||||
declaration['options'] = handle_outputs_taken_as_arguments(
|
||||
declaration['options'])
|
||||
# We (very unfortunately) have overloaded virtual methods. Because
|
||||
# of C++'s rules, we cannot move one overload without doing some
|
||||
# extra work to make sure that overload in a superclass and an
|
||||
# overload in a subclass resolve together. I've chosen to resolve
|
||||
# this problem simply by moving ALL overloads of a method which
|
||||
# occurs in Tensor to Type. This is why we have to first compute
|
||||
# which methods *names* go on type, and then move ALL overloads
|
||||
# of this name to Type.
|
||||
for declaration in declarations:
|
||||
for option in declaration['options']:
|
||||
option['extended_method'] = option['api_name'] not in non_extended_methods
|
||||
return declarations
|
||||
|
|
@ -304,10 +304,6 @@ if(NOT INTERN_BUILD_MOBILE OR NOT BUILD_CAFFE2_MOBILE)
|
|||
# Generate files
|
||||
set(TOOLS_PATH "${TORCH_ROOT}/tools")
|
||||
|
||||
configure_file("${TORCH_ROOT}/aten/src/ATen/common_with_cwrap.py"
|
||||
"${TOOLS_PATH}/shared/cwrap_common.py"
|
||||
COPYONLY)
|
||||
|
||||
configure_file("${TORCH_SRC_DIR}/_utils_internal.py"
|
||||
"${TOOLS_PATH}/shared/_utils_internal.py"
|
||||
COPYONLY)
|
||||
|
|
|
|||
|
|
@ -36,10 +36,10 @@ if args.aten_root:
|
|||
if not os.path.exists(args.aten_root):
|
||||
raise ValueError('aten_root ({}) does not exist'.format(
|
||||
args.aten_root))
|
||||
sys.path.append(os.path.join(args.aten_root, 'src', 'ATen'))
|
||||
from code_template import CodeTemplate as CT
|
||||
sys.path.append(os.path.join(args.aten_root, '..')) # TODO: fix this
|
||||
from tools.codegen.code_template import CodeTemplate as CT
|
||||
else:
|
||||
from src.ATen.code_template import CodeTemplate as CT # type: ignore[import,no-redef]
|
||||
from tools.codegen.code_template import CodeTemplate as CT # type: ignore[import,no-redef]
|
||||
|
||||
OP_TEMPLATE = CT.from_file(
|
||||
os.path.join(args.template_dir, 'aten_op_template.h'))
|
||||
|
|
|
|||
|
|
@ -144,13 +144,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
endforeach()
|
||||
list(APPEND ATen_CPU_SRCS ${cpu_kernel_cpp})
|
||||
|
||||
set(cwrap_files
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/Declarations.cwrap
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/THCUNN/generic/THCUNN.h
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/nn.yaml
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml)
|
||||
|
||||
file(GLOB all_python "${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/*.py")
|
||||
file(GLOB all_python "${CMAKE_CURRENT_LIST_DIR}/../tools/codegen/*.py")
|
||||
|
||||
set(GEN_ROCM_FLAG)
|
||||
if(USE_ROCM)
|
||||
|
|
@ -189,11 +183,10 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
endif()
|
||||
|
||||
set(GEN_COMMAND
|
||||
"${PYTHON_EXECUTABLE}" ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/gen.py
|
||||
"${PYTHON_EXECUTABLE}" -m tools.codegen.gen
|
||||
--source-path ${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen
|
||||
--install_dir ${CMAKE_BINARY_DIR}/aten/src/ATen
|
||||
${GEN_ROCM_FLAG}
|
||||
${cwrap_files}
|
||||
${CUSTOM_BUILD_FLAGS}
|
||||
${GEN_VULKAN_FLAGS}
|
||||
)
|
||||
|
|
@ -202,6 +195,7 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
COMMAND ${GEN_COMMAND}
|
||||
--output-dependencies ${CMAKE_BINARY_DIR}/aten/src/ATen/generated_cpp.txt
|
||||
RESULT_VARIABLE RETURN_VALUE
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
|
||||
)
|
||||
if(NOT RETURN_VALUE EQUAL 0)
|
||||
message(STATUS ${generated_cpp})
|
||||
|
|
@ -219,7 +213,10 @@ if(INTERN_BUILD_ATEN_OPS)
|
|||
|
||||
add_custom_command(OUTPUT ${generated_cpp} ${cuda_generated_cpp} ${core_generated_cpp}
|
||||
COMMAND ${GEN_COMMAND}
|
||||
DEPENDS ${all_python} ${all_templates} ${cwrap_files})
|
||||
DEPENDS ${all_python} ${all_templates}
|
||||
${CMAKE_CURRENT_LIST_DIR}/../aten/src/ATen/native/native_functions.yaml
|
||||
WORKING_DIRECTORY ${CMAKE_CURRENT_LIST_DIR}/..
|
||||
)
|
||||
|
||||
# Generated headers used from a CUDA (.cu) file are
|
||||
# not tracked correctly in CMake. We make the libATen.so depend explicitly
|
||||
|
|
|
|||
|
|
@ -14,16 +14,9 @@ command -v doxygen >/dev/null 2>&1 || { echo >&2 "doxygen is not supported. Abor
|
|||
|
||||
pushd "$(dirname "$0")/../../.."
|
||||
|
||||
cp aten/src/ATen/common_with_cwrap.py tools/shared/cwrap_common.py
|
||||
cp torch/_utils_internal.py tools/shared
|
||||
|
||||
python aten/src/ATen/gen.py \
|
||||
-s aten/src/ATen \
|
||||
-d build/aten/src/ATen \
|
||||
aten/src/ATen/Declarations.cwrap \
|
||||
aten/src/THCUNN/generic/THCUNN.h \
|
||||
aten/src/ATen/nn.yaml \
|
||||
aten/src/ATen/native/native_functions.yaml
|
||||
python -m tools.codegen.gen
|
||||
|
||||
python tools/setup_helpers/generate_code.py \
|
||||
--declarations-path build/aten/src/ATen/Declarations.yaml \
|
||||
|
|
|
|||
|
|
@ -29,5 +29,4 @@ warn_return_any = True
|
|||
implicit_reexport = False
|
||||
strict_equality = True
|
||||
|
||||
files =
|
||||
aten/src/ATen/code_template.py
|
||||
files = tools/codegen/gen.py
|
||||
|
|
|
|||
1
mypy.ini
1
mypy.ini
|
|
@ -17,7 +17,6 @@ check_untyped_defs = True
|
|||
files =
|
||||
torch,
|
||||
caffe2,
|
||||
aten/src/ATen/function_wrapper.py,
|
||||
test/test_complex.py,
|
||||
test/test_futures.py,
|
||||
test/test_torch.py,
|
||||
|
|
|
|||
|
|
@ -5,3 +5,4 @@ requests
|
|||
setuptools
|
||||
six
|
||||
typing_extensions
|
||||
dataclasses
|
||||
|
|
|
|||
6
setup.py
6
setup.py
|
|
@ -351,8 +351,8 @@ def build_deps():
|
|||
|
||||
# Use copies instead of symbolic files.
|
||||
# Windows has very poor support for them.
|
||||
sym_files = ['tools/shared/cwrap_common.py', 'tools/shared/_utils_internal.py']
|
||||
orig_files = ['aten/src/ATen/common_with_cwrap.py', 'torch/_utils_internal.py']
|
||||
sym_files = ['tools/shared/_utils_internal.py']
|
||||
orig_files = ['torch/_utils_internal.py']
|
||||
for sym_file, orig_file in zip(sym_files, orig_files):
|
||||
same = False
|
||||
if os.path.exists(sym_file):
|
||||
|
|
@ -368,7 +368,7 @@ def build_deps():
|
|||
################################################################################
|
||||
|
||||
# the list of runtime dependencies required by this built package
|
||||
install_requires = ['future', 'typing_extensions']
|
||||
install_requires = ['future', 'typing_extensions', 'dataclasses']
|
||||
|
||||
missing_pydep = '''
|
||||
Missing build dependency: Unable to `import {importname}`.
|
||||
|
|
|
|||
|
|
@ -25,6 +25,8 @@ from torch._C import parse_schema
|
|||
# 1: date until which the allowlist entry is valid
|
||||
# 2: (optional) function argument regex
|
||||
# ]
|
||||
#
|
||||
# NB: function name DOES NOT include overload name!
|
||||
allow_list = [
|
||||
("c10_experimental", datetime.date(2222, 1, 1)),
|
||||
# We export some functions and classes for test_jit.py directly from libtorch.so,
|
||||
|
|
@ -69,9 +71,11 @@ allow_list = [
|
|||
("aten::gcd", datetime.date(2020, 7, 30)),
|
||||
("aten::unflatten", datetime.date(2020, 8, 14)),
|
||||
("aten::linalg_outer", datetime.date(2020, 8, 30)),
|
||||
# WARNING: overload name here doesn't do anything
|
||||
("aten::linalg_outer.out", datetime.date(2020, 8, 30)),
|
||||
("aten::_compute_linear_combination", datetime.date(2020, 9, 1)),
|
||||
("__getstate__", datetime.date(2020, 9, 1), "Conv[23]dPackedParams"),
|
||||
("aten::_foreach_add_", datetime.date(2020, 10, 1)),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -215,7 +215,7 @@ class TestTypeHints(TestCase):
|
|||
finally:
|
||||
os.chdir(cwd)
|
||||
if result != 0:
|
||||
self.fail("mypy failed: {}".format(stdout))
|
||||
self.fail("mypy failed: {} {}".format(stdout, stderr))
|
||||
|
||||
@unittest.skipIf(not HAVE_MYPY, "need mypy")
|
||||
def test_run_mypy_strict(self):
|
||||
|
|
@ -237,7 +237,7 @@ class TestTypeHints(TestCase):
|
|||
finally:
|
||||
os.chdir(cwd)
|
||||
if result != 0:
|
||||
self.fail("mypy failed: {}".format(stdout))
|
||||
self.fail("mypy failed: {} {}".format(stdout, stderr))
|
||||
|
||||
if __name__ == '__main__':
|
||||
run_tests()
|
||||
|
|
|
|||
|
|
@ -154,7 +154,6 @@ def load_aten_declarations(path):
|
|||
if has_tensoroptions_argument(declaration):
|
||||
declaration['schema_order_args'] = [process_schema_order_arg(arg) for arg in declaration['schema_order_args']]
|
||||
declaration['api_name'] = declaration['name']
|
||||
# NB: keep this in sync with common_with_cwrap.py
|
||||
if declaration.get('overload_name'):
|
||||
declaration['type_wrapper_name'] = "{}_{}".format(
|
||||
declaration['name'], declaration['overload_name'])
|
||||
|
|
|
|||
|
|
@ -35,11 +35,7 @@ import re
|
|||
from .gen_variable_type import should_trace
|
||||
from .utils import write, is_tensor_method
|
||||
|
||||
try:
|
||||
from src.ATen.code_template import CodeTemplate
|
||||
except ImportError:
|
||||
from tools.shared.module_loader import import_module
|
||||
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
||||
#
|
||||
# declarations blocklist
|
||||
|
|
|
|||
|
|
@ -216,7 +216,15 @@ ${return_type} ${type_wrapper_name}(${formals}) {
|
|||
}
|
||||
""")
|
||||
|
||||
# See NOTE[UnboxedOnly] in function_wrapper.py
|
||||
# NOTE[UnboxedOnly] Many of our codegen templates currently exist twice, once
|
||||
# in an _UNBOXEDONLY_ variant and once without _UNBOXEDONLY_. This is because
|
||||
# ops that are `use_c10_dispatcher: full` need different c++ code than ops
|
||||
# that aren't `use_c10_dispatcher: full` yet. The _UNBOXEDONLY_ variants
|
||||
# are for ops that aren't `use_c10_dispatcher: full` yet and those code templates
|
||||
# can be deleted once all ops are `use_c10_dispatcher: full`.
|
||||
# If you update one of the templates, you likely also have to update the other.
|
||||
|
||||
# See NOTE[UnboxedOnly]
|
||||
UNBOXEDONLY_WRAPPER_REGISTRATION = CodeTemplate("""\
|
||||
m.impl_UNBOXED("${unqual_operator_name_with_overload}", &${class_type}::${type_wrapper_name});
|
||||
""")
|
||||
|
|
@ -366,7 +374,7 @@ ${return_type} ${api_name}(${declaration_formals}); // {"schema": "${schema_stri
|
|||
|
||||
# TraceType templates
|
||||
# TODO: change `redispatch` to `NoTracerDispatchMode` + regular `call`.
|
||||
# See NOTE[UnboxedOnly] in function_wrapper.py
|
||||
# See NOTE[UnboxedOnly]
|
||||
UNBOXED_TRACE_DISPATCH = CodeTemplate("""\
|
||||
static auto op = c10::Dispatcher::singleton()
|
||||
.findSchemaOrThrow("aten::${operator_name}", "${overload_name}")
|
||||
|
|
|
|||
|
|
@ -9,11 +9,7 @@ __all__ = [
|
|||
'split_name_params', 'write',
|
||||
]
|
||||
|
||||
try:
|
||||
from src.ATen.code_template import CodeTemplate
|
||||
except ImportError:
|
||||
from tools.shared.module_loader import import_module
|
||||
CodeTemplate = import_module('code_template', 'aten/src/ATen/code_template.py').CodeTemplate
|
||||
from tools.codegen.code_template import CodeTemplate
|
||||
|
||||
# You should use these lines, rather than doing it manually.
|
||||
# Especially if you see this error!
|
||||
|
|
|
|||
0
tools/codegen/__init__.py
Normal file
0
tools/codegen/__init__.py
Normal file
0
tools/codegen/api/__init__.py
Normal file
0
tools/codegen/api/__init__.py
Normal file
241
tools/codegen/api/cpp.py
Normal file
241
tools/codegen/api/cpp.py
Normal file
|
|
@ -0,0 +1,241 @@
|
|||
from tools.codegen.model import *
|
||||
from tools.codegen.api.types import TensorOptionsArguments, CppArgument, ThisArgument
|
||||
import tools.codegen.local as local
|
||||
from typing import Optional, Sequence, Union, Callable, List
|
||||
|
||||
# This file describes the translation of JIT schema to the public C++
|
||||
# API, which is what people use when they call functions like at::add.
|
||||
#
|
||||
# Prominent characteristics of the C++ API:
|
||||
#
|
||||
# - dtype, layout, device and pin_memory are collected into
|
||||
# a single C++ type TensorOptions (the legacy dispatcher API
|
||||
# also has this, but tensor options is really most relevant
|
||||
# for the C++ API; it makes calling kwarg factory functions
|
||||
# pleasant)
|
||||
#
|
||||
# - for 'use_c10_dispatcher: full' functions, optional tensors are
|
||||
# represented explicitly using c10::optional
|
||||
#
|
||||
# - defaulting lives here (in fact, the dispatcher is completely
|
||||
# oblivious of defaults!)
|
||||
#
|
||||
# BTW: policy on name collisions: we try not to have types with
|
||||
# collisions, but functions are fair game to collide
|
||||
|
||||
def name(func: FunctionSchema) -> str:
|
||||
name = str(func.name.name)
|
||||
if func.is_out_fn():
|
||||
name += '_out'
|
||||
return name
|
||||
|
||||
# Translation of "value types" in JIT schema to C++ API type. Value
|
||||
# types look the same no matter if they are argument types are return
|
||||
# types. Returns None if the type in question is not a value type.
|
||||
def valuetype_type(t: Type) -> Optional[str]:
|
||||
if isinstance(t, BaseType):
|
||||
if t.name == BaseTy.Tensor:
|
||||
return None
|
||||
elif t.name == BaseTy.int:
|
||||
return 'int64_t'
|
||||
elif t.name == BaseTy.float:
|
||||
return 'double'
|
||||
elif t.name == BaseTy.str:
|
||||
return 'std::string'
|
||||
elif t.name in [BaseTy.bool, BaseTy.QScheme, BaseTy.Scalar,
|
||||
BaseTy.ScalarType, BaseTy.Generator, BaseTy.Storage,
|
||||
BaseTy.Layout, BaseTy.Device, BaseTy.MemoryFormat,
|
||||
BaseTy.Dimname, BaseTy.ConstQuantizerPtr]:
|
||||
# These C++ names line up with their schema names
|
||||
return t.name.name
|
||||
else:
|
||||
raise AssertionError(f"unsupported type: {t}")
|
||||
elif isinstance(t, OptionalType):
|
||||
elem = valuetype_type(t.elem)
|
||||
if elem is None:
|
||||
return None
|
||||
return f"c10::optional<{elem}>"
|
||||
elif isinstance(t, ListType):
|
||||
if str(t.elem) == 'bool':
|
||||
assert t.size is not None
|
||||
return f"std::array<bool,{t.size}>"
|
||||
else:
|
||||
return None
|
||||
else:
|
||||
raise AssertionError(f"unrecognized type {repr(t)}")
|
||||
|
||||
# Translation of types occuring in JIT arguments to a C++ argument type.
|
||||
def argumenttype_type(t: Type, *, mutable: bool) -> str:
|
||||
# If it's a value type, do the value type translation
|
||||
r = valuetype_type(t)
|
||||
if r is not None:
|
||||
return r
|
||||
|
||||
if str(t) == 'Tensor' and mutable and local.hack_const_mutable_self():
|
||||
return 'const Tensor &'
|
||||
|
||||
if isinstance(t, BaseType):
|
||||
if t.name == BaseTy.Tensor:
|
||||
if mutable:
|
||||
return 'Tensor &'
|
||||
else:
|
||||
return 'const Tensor &'
|
||||
else:
|
||||
raise AssertionError(f"base type should have been value type {t}")
|
||||
elif isinstance(t, OptionalType):
|
||||
if str(t.elem) == 'Tensor':
|
||||
if mutable:
|
||||
return 'Tensor &' # TODO: fix this discrepancy
|
||||
else:
|
||||
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
return 'const c10::optional<Tensor>&'
|
||||
else:
|
||||
return 'const Tensor &'
|
||||
elem = argumenttype_type(t.elem, mutable=mutable)
|
||||
return f"c10::optional<{elem}>"
|
||||
elif isinstance(t, ListType):
|
||||
# TODO: remove these special cases, ArrayRef fallthrough works fine
|
||||
if str(t.elem) == 'int':
|
||||
return "IntArrayRef"
|
||||
elif str(t.elem) == 'Tensor':
|
||||
return "TensorList"
|
||||
elif str(t.elem) == 'Dimname':
|
||||
return "DimnameList"
|
||||
# TODO: do something reasonable about lists of optional tensors
|
||||
elif not local.use_c10_dispatcher() is UseC10Dispatcher.full and str(t.elem) == 'Tensor?':
|
||||
return "TensorList"
|
||||
elem = argumenttype_type(t.elem, mutable=mutable)
|
||||
# TODO: explicitly qualify namespace here
|
||||
return f"ArrayRef<{elem}>"
|
||||
else:
|
||||
raise AssertionError(f"unrecognized type {repr(t)}")
|
||||
|
||||
# Translate a JIT argument into its C++ type
|
||||
def argument_type(a: Argument) -> str:
|
||||
return argumenttype_type(a.type, mutable=a.is_write)
|
||||
|
||||
# Translation of a (non-multi) return type from JIT to C++
|
||||
def returntype_type(t: Type, *, mutable: bool) -> str:
|
||||
r = valuetype_type(t)
|
||||
if r is not None:
|
||||
return r
|
||||
|
||||
if isinstance(t, BaseType):
|
||||
if t.name == BaseTy.Tensor:
|
||||
if mutable:
|
||||
return 'Tensor &'
|
||||
else:
|
||||
return 'Tensor'
|
||||
elif isinstance(t, ListType):
|
||||
elem = returntype_type(t.elem, mutable=mutable)
|
||||
assert t.size is None, f"fixed size list returns not supported: {t}"
|
||||
return f"std::vector<{elem}>"
|
||||
|
||||
raise AssertionError(f"unrecognized return type {t}")
|
||||
|
||||
# Translation of a single return to its C++ type
|
||||
def return_type(r: Return) -> str:
|
||||
return returntype_type(r.type, mutable=r.is_write)
|
||||
|
||||
# Translation of a full (possibly multi) return from JIT to its C++ type
|
||||
def returns_type(rs: Sequence[Return]) -> str:
|
||||
if len(rs) == 0:
|
||||
return 'void'
|
||||
elif len(rs) == 1:
|
||||
return return_type(rs[0])
|
||||
else:
|
||||
args = ','.join(map(return_type, rs))
|
||||
return f'std::tuple<{args}>'
|
||||
|
||||
JIT_TO_CPP_DEFAULT = {
|
||||
'False': 'false',
|
||||
'True': 'true',
|
||||
'None': 'c10::nullopt', # UGH this one is type directed
|
||||
'Mean': 'at::Reduction::Mean',
|
||||
'[]': '{}',
|
||||
'[0,1]': '{0,1}', # TODO: stop special casing
|
||||
'contiguous_format': 'MemoryFormat::Contiguous',
|
||||
}
|
||||
|
||||
# Convert a JIT default into C++ expression representing the default
|
||||
def default_expr(d: str, t: Type) -> str:
|
||||
if d == 'None' and str(t) == 'Tensor?':
|
||||
return '{}'
|
||||
return JIT_TO_CPP_DEFAULT.get(d, d)
|
||||
|
||||
# Convert an argument into its C++ API form
|
||||
def argument(a: Union[Argument, TensorOptionsArguments, ThisArgument]) -> CppArgument:
|
||||
if isinstance(a, Argument):
|
||||
return CppArgument(
|
||||
type=argument_type(a),
|
||||
name=a.name,
|
||||
default=default_expr(a.default, a.type) if a.default is not None else None,
|
||||
argument=a,
|
||||
)
|
||||
elif isinstance(a, ThisArgument):
|
||||
return CppArgument(
|
||||
type=argument_type(a.argument),
|
||||
name="const_cast<Tensor&>(*this)", # this is an abuse but it's convenient
|
||||
default=None,
|
||||
argument=a,
|
||||
)
|
||||
elif isinstance(a, TensorOptionsArguments):
|
||||
default = None
|
||||
if all(x.default == "None" for x in a.all()):
|
||||
default = '{}'
|
||||
elif a.dtype.default == "long":
|
||||
default = 'at::kLong' # TODO: this is wrong
|
||||
return CppArgument(
|
||||
type='const TensorOptions &',
|
||||
name='options',
|
||||
default=default,
|
||||
argument=a,
|
||||
)
|
||||
else:
|
||||
assert_never(a)
|
||||
|
||||
def group_arguments(
|
||||
func: FunctionSchema, *, method: bool = False
|
||||
) -> Sequence[Union[Argument, TensorOptionsArguments, ThisArgument]]:
|
||||
args: List[Union[Argument, ThisArgument, TensorOptionsArguments]] = []
|
||||
args.extend(func.out_arguments)
|
||||
|
||||
if method:
|
||||
args.extend(ThisArgument(a) if a.name == "self" else a for a in func.arguments)
|
||||
else:
|
||||
args.extend(func.arguments)
|
||||
|
||||
# group up arguments for tensor options
|
||||
|
||||
def pred(name: str, ty: Type) -> Callable[[Argument], bool]:
|
||||
return lambda a: a.name == name and a.type in [ty, OptionalType(ty)]
|
||||
predicates = [ # order matters
|
||||
pred('dtype', Type.parse('ScalarType')),
|
||||
pred('layout', Type.parse('Layout')),
|
||||
pred('device', Type.parse('Device')),
|
||||
pred('pin_memory', Type.parse('bool')),
|
||||
]
|
||||
|
||||
i = 0
|
||||
while i < len(func.kwarg_only_arguments):
|
||||
# If there is enough space...
|
||||
if i <= len(func.kwarg_only_arguments) - len(predicates):
|
||||
# And the next len(predicates) arguments look like TensorOptions arguments
|
||||
if all(p(a) for p, a in zip(predicates, func.kwarg_only_arguments[i : i + len(predicates)])):
|
||||
# Group them together as one argument
|
||||
args.append(TensorOptionsArguments(
|
||||
dtype=func.kwarg_only_arguments[i],
|
||||
layout=func.kwarg_only_arguments[i + 1],
|
||||
device=func.kwarg_only_arguments[i + 2],
|
||||
pin_memory=func.kwarg_only_arguments[i + 3],
|
||||
))
|
||||
i += len(predicates)
|
||||
continue
|
||||
args.append(func.kwarg_only_arguments[i])
|
||||
i += 1
|
||||
|
||||
return args
|
||||
|
||||
# Convert arguments to C++ API form
|
||||
def arguments(func: FunctionSchema, *, method: bool = False) -> Sequence[CppArgument]:
|
||||
return list(map(argument, group_arguments(func, method=method)))
|
||||
109
tools/codegen/api/dispatcher.py
Normal file
109
tools/codegen/api/dispatcher.py
Normal file
|
|
@ -0,0 +1,109 @@
|
|||
from tools.codegen.model import *
|
||||
|
||||
from tools.codegen.api.types import CppArgument, DispatcherExpr, TensorOptionsArguments, \
|
||||
DispatcherArgument, ThisArgument, LegacyDispatcherArgument
|
||||
import tools.codegen.api.cpp as cpp
|
||||
import tools.codegen.api.legacy_dispatcher as legacy_dispatcher
|
||||
import tools.codegen.local as local
|
||||
|
||||
import itertools
|
||||
from typing import Sequence, Optional
|
||||
|
||||
# This file describes the translation of JIT schema to the dispatcher
|
||||
# API, the *unboxed* calling convention by which invocations through
|
||||
# the dispatcher are made. Historically, the dispatcher API matched
|
||||
# the C++ API, but with the establishment of the boxed API, we've
|
||||
# made changes to the dispatcher API to so that the unboxed API
|
||||
# better aligns with the boxed API. The dispatcher API hooks heavily
|
||||
# into our template based boxing/unboxing machinery, so changes
|
||||
# to this convention will usually need template updates too.
|
||||
#
|
||||
# Prominent characteristics of the dispatcher API:
|
||||
#
|
||||
# - 'use_c10_dispatcher: full' controls whether or not we actually
|
||||
# use the modern calling convention or not. When use_c10_dispatcher
|
||||
# is not enabled, we don't use the template machinery.
|
||||
#
|
||||
# - dtype, layout, device and pin_memory are represented as separate
|
||||
# arguments.
|
||||
#
|
||||
|
||||
def argumenttype_type(t: Type, *, mutable: bool) -> str:
|
||||
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
# This is a faux amis. If it makes sense in the future to add
|
||||
# more special cases here, or invert things so cpp.argument_type
|
||||
# calls this, or just completely inline the function, please do
|
||||
# it.
|
||||
return cpp.argumenttype_type(t, mutable=mutable)
|
||||
else:
|
||||
# This is real sharing. If you're modifying this path, ask
|
||||
# yourself why you are changing the legacy dispatcher protocol
|
||||
# here and not in legacy_dispatcher.
|
||||
return legacy_dispatcher.argumenttype_type(t, mutable=mutable)
|
||||
|
||||
def argument_type(a: Argument) -> str:
|
||||
return argumenttype_type(a.type, mutable=a.is_write)
|
||||
|
||||
def returns_type(rs: Sequence[Return]) -> str:
|
||||
# At present, there is no difference. But there could be!
|
||||
return cpp.returns_type(rs)
|
||||
|
||||
def argument(a: Argument) -> DispatcherArgument:
|
||||
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
return DispatcherArgument(
|
||||
type=argument_type(a),
|
||||
name=a.name,
|
||||
argument=a,
|
||||
)
|
||||
else:
|
||||
la = legacy_dispatcher.argument(a)
|
||||
return DispatcherArgument(
|
||||
type=la.type,
|
||||
name=la.name,
|
||||
argument=la.argument,
|
||||
)
|
||||
|
||||
def arguments(func: FunctionSchema) -> Sequence[DispatcherArgument]:
|
||||
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
return list(map(argument, itertools.chain(func.out_arguments, func.arguments, func.kwarg_only_arguments)))
|
||||
else:
|
||||
return [
|
||||
DispatcherArgument(type=la.type, name=la.name, argument=la.argument)
|
||||
for la in legacy_dispatcher.arguments(func)
|
||||
]
|
||||
|
||||
# Given a set of CppArguments in scope, return a sequence of dispatcher
|
||||
# expressions that translate the cpp API into dispatcher API
|
||||
def cppargument_exprs(a: CppArgument, *, tensor_options: Optional[CppArgument]) -> Sequence[DispatcherExpr]:
|
||||
if isinstance(a.argument, TensorOptionsArguments):
|
||||
if local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
ta = a.argument
|
||||
return [
|
||||
DispatcherExpr(type=argument_type(ta.dtype), expr=f'optTypeMetaToScalarType({a.name}.dtype_opt())'),
|
||||
DispatcherExpr(type=argument_type(ta.layout), expr=f'{a.name}.layout_opt()'),
|
||||
DispatcherExpr(type=argument_type(ta.device), expr=f'{a.name}.device_opt()'),
|
||||
DispatcherExpr(type=argument_type(ta.pin_memory), expr=f'{a.name}.pinned_memory_opt()'), # weird discrep
|
||||
]
|
||||
else:
|
||||
return [DispatcherExpr(type='const TensorOptions &', expr=a.name)]
|
||||
elif isinstance(a.argument, Argument):
|
||||
if a.name == 'memory_format' and tensor_options is not None and local.use_c10_dispatcher() is UseC10Dispatcher.full:
|
||||
return [DispatcherExpr(
|
||||
type=argument_type(a.argument),
|
||||
expr=f'c10::impl::check_tensor_options_and_extract_memory_format({tensor_options.name}, {a.name})')
|
||||
]
|
||||
else:
|
||||
return [DispatcherExpr(type=argument_type(a.argument), expr=a.name)]
|
||||
elif isinstance(a.argument, ThisArgument):
|
||||
return [DispatcherExpr(type=argument_type(a.argument.argument), expr=a.name)]
|
||||
else:
|
||||
assert_never(a.argument)
|
||||
|
||||
def cpparguments_exprs(args: Sequence[CppArgument]) -> Sequence[DispatcherExpr]:
|
||||
tensor_options = next((a for a in args if isinstance(a.argument, TensorOptionsArguments)), None)
|
||||
return [r for a in args for r in cppargument_exprs(a, tensor_options=tensor_options)]
|
||||
|
||||
# I don't think this is entirely sound, but it should be reasonably
|
||||
# close
|
||||
def legacydispatcherarguments_exprs(args: Sequence[LegacyDispatcherArgument]) -> Sequence[DispatcherExpr]:
|
||||
return cpparguments_exprs([CppArgument(type=a.type, name=a.name, default=None, argument=a.argument) for a in args])
|
||||
74
tools/codegen/api/legacy_dispatcher.py
Normal file
74
tools/codegen/api/legacy_dispatcher.py
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
from tools.codegen.model import *
|
||||
|
||||
from tools.codegen.api.types import TensorOptionsArguments, LegacyDispatcherArgument, ThisArgument
|
||||
import tools.codegen.api.cpp as cpp
|
||||
|
||||
from typing import Union, Sequence
|
||||
|
||||
# This file describes the translation of JIT schema to the legacy
|
||||
# dispatcher API. This looks a lot like the C++ API (which
|
||||
# makes historical sense, because historically the dispatcher API
|
||||
# and the C++ API exactly matched), but over time we have
|
||||
# evolved the C++ API without actually changing our native::
|
||||
# kernels. To be deleted eventually. Dispatcher calls use
|
||||
# this when you are not use_c10_dispatcher: full.
|
||||
|
||||
def name(func: FunctionSchema) -> str:
|
||||
name = str(func.name.name)
|
||||
# TODO: delete this!
|
||||
if func.is_out_fn():
|
||||
name += '_out'
|
||||
if func.name.overload_name:
|
||||
name += f'_{func.name.overload_name}'
|
||||
return name
|
||||
|
||||
def argumenttype_type(t: Type, *, mutable: bool) -> str:
|
||||
if str(t) == 'Tensor?':
|
||||
if mutable:
|
||||
return 'Tensor &'
|
||||
else:
|
||||
return 'const Tensor &'
|
||||
elif str(t) == 'Tensor?[]':
|
||||
return 'TensorList'
|
||||
return cpp.argumenttype_type(t, mutable=mutable)
|
||||
|
||||
def returns_type(rs: Sequence[Return]) -> str:
|
||||
return cpp.returns_type(rs)
|
||||
|
||||
def argument_type(a: Argument) -> str:
|
||||
return argumenttype_type(a.type, mutable=a.is_write)
|
||||
|
||||
def argument(a: Union[Argument, ThisArgument, TensorOptionsArguments]) -> LegacyDispatcherArgument:
|
||||
if isinstance(a, Argument):
|
||||
return LegacyDispatcherArgument(
|
||||
type=argument_type(a),
|
||||
name=a.name,
|
||||
default=cpp.default_expr(a.default, a.type) if a.default is not None else None,
|
||||
argument=a,
|
||||
)
|
||||
elif isinstance(a, ThisArgument):
|
||||
# Erase ThisArgument from the distinction
|
||||
return LegacyDispatcherArgument(
|
||||
type=argument_type(a.argument),
|
||||
name=a.argument.name,
|
||||
default=None,
|
||||
argument=a.argument,
|
||||
)
|
||||
elif isinstance(a, TensorOptionsArguments):
|
||||
# TODO: expunge this logic entirely
|
||||
default = None
|
||||
if all(x.default == "None" for x in a.all()):
|
||||
default = '{}'
|
||||
elif a.dtype.default == "long":
|
||||
default = 'at::kLong' # TODO: this is wrong
|
||||
return LegacyDispatcherArgument(
|
||||
type='const TensorOptions &',
|
||||
name='options',
|
||||
default=default,
|
||||
argument=a,
|
||||
)
|
||||
else:
|
||||
assert_never(a)
|
||||
|
||||
def arguments(func: FunctionSchema) -> Sequence[LegacyDispatcherArgument]:
|
||||
return list(map(argument, cpp.group_arguments(func)))
|
||||
95
tools/codegen/api/types.py
Normal file
95
tools/codegen/api/types.py
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
from tools.codegen.model import *
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Union, Sequence
|
||||
|
||||
# Represents the implicit *this argument for method calls in C++ API
|
||||
@dataclass(frozen=True)
|
||||
class ThisArgument:
|
||||
argument: Argument
|
||||
|
||||
# Bundle of arguments that represent a TensorOptions in the C++ API.
|
||||
@dataclass(frozen=True)
|
||||
class TensorOptionsArguments:
|
||||
dtype: Argument
|
||||
layout: Argument
|
||||
device: Argument
|
||||
pin_memory: Argument
|
||||
|
||||
def all(self) -> Sequence[Argument]:
|
||||
return [self.dtype, self.layout, self.device, self.pin_memory]
|
||||
|
||||
# Describe a argument (e.g., the x in "f(int x)") in the C++ API
|
||||
@dataclass(frozen=True)
|
||||
class CppArgument:
|
||||
# C++ type, e.g., int
|
||||
type: str
|
||||
# C++ name, e.g., x
|
||||
name: str
|
||||
# Only used by the header, but we work it out in all cases anyway
|
||||
default: Optional[str]
|
||||
# The JIT argument(s) this formal was derived from. May
|
||||
# correspond to multiple arguments if this is TensorOptions!
|
||||
# May also correspond to the implicit *this argument!
|
||||
argument: Union[Argument, TensorOptionsArguments, ThisArgument]
|
||||
|
||||
# Default string representation prints the most elaborated form
|
||||
# of the formal
|
||||
def __str__(self) -> str:
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
return f"{self.type} {self.name}{mb_default}"
|
||||
|
||||
# However, you might also find the version with no default useful
|
||||
def str_no_default(self) -> str:
|
||||
return f"{self.type} {self.name}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CppExpr:
|
||||
type: str
|
||||
expr: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DispatcherExpr:
|
||||
type: str
|
||||
expr: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LegacyDispatcherExpr:
|
||||
type: str
|
||||
expr: str
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DispatcherArgument:
|
||||
type: str
|
||||
name: str
|
||||
# dispatcher NEVER has defaults
|
||||
argument: Union[Argument, TensorOptionsArguments]
|
||||
# TensorOptionsArguments can occur when not using full c10 dispatch
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type} {self.name}"
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class LegacyDispatcherArgument:
|
||||
type: str
|
||||
name: str
|
||||
# Legacy dispatcher arguments have defaults for some reasons (e.g.,
|
||||
# the function prototypes in CPUType.h are defaulted). There isn't
|
||||
# really any good reason to do this, as these functions are only
|
||||
# ever called from a context where all defaulted arguments are
|
||||
# guaranteed to be given explicitly.
|
||||
# TODO: Remove this
|
||||
default: Optional[str]
|
||||
argument: Union[Argument, TensorOptionsArguments]
|
||||
|
||||
# Convention here is swapped because arguably legacy
|
||||
# dispatcher shouldn't have defaults...
|
||||
def __str__(self) -> str:
|
||||
return f"{self.type} {self.name}"
|
||||
|
||||
def str_with_default(self) -> str:
|
||||
mb_default = ""
|
||||
if self.default is not None:
|
||||
mb_default = f"={self.default}"
|
||||
return f"{self.type} {self.name}{mb_default}"
|
||||
1111
tools/codegen/gen.py
Normal file
1111
tools/codegen/gen.py
Normal file
File diff suppressed because it is too large
Load Diff
49
tools/codegen/local.py
Normal file
49
tools/codegen/local.py
Normal file
|
|
@ -0,0 +1,49 @@
|
|||
import threading
|
||||
from contextlib import contextmanager
|
||||
from typing import Optional, Iterator
|
||||
|
||||
from tools.codegen.model import UseC10Dispatcher
|
||||
|
||||
# Simple dynamic scoping implementation. The name "parametrize" comes
|
||||
# from Racket.
|
||||
#
|
||||
# WARNING WARNING: LOOKING TO EDIT THIS FILE? Think carefully about
|
||||
# why you need to add a toggle to the global behavior of code
|
||||
# generation. The parameters here should really only be used
|
||||
# for "temporary" situations, where we need to temporarily change
|
||||
# the codegen in some cases because we cannot conveniently update
|
||||
# all call sites, and are slated to be eliminated once all call
|
||||
# sites are eliminated. If you don't have a plan for how to get there,
|
||||
# DON'T add a new entry here.
|
||||
|
||||
class Locals(threading.local):
|
||||
use_c10_dispatcher: Optional[UseC10Dispatcher] = None
|
||||
hack_const_mutable_self: bool = False
|
||||
_locals = Locals()
|
||||
|
||||
# The use_c10_dispatcher field in native_functions.yaml is used to
|
||||
# control codegen behavior, so that we can handle cases where
|
||||
# Dispatcher templating logic can't handle. In the terminal
|
||||
# state, use_c10_dispatcher should always be UseC10Dispatcher.full
|
||||
# and this flag can be eliminated.
|
||||
def use_c10_dispatcher() -> UseC10Dispatcher:
|
||||
assert _locals.use_c10_dispatcher is not None, \
|
||||
"need to initialize local.use_c10_dispatcher with local.parametrize"
|
||||
return _locals.use_c10_dispatcher
|
||||
|
||||
# This is used to maintain compat, see Note [Byte-for-byte compatibility]
|
||||
# It can be removed when we drop compat.
|
||||
def hack_const_mutable_self() -> bool:
|
||||
return _locals.hack_const_mutable_self
|
||||
|
||||
@contextmanager
|
||||
def parametrize(*, use_c10_dispatcher: UseC10Dispatcher, hack_const_mutable_self: bool) -> Iterator[None]:
|
||||
old_use_c10_dispatcher = _locals.use_c10_dispatcher
|
||||
old_hack_const_mutable_self = _locals.hack_const_mutable_self
|
||||
try:
|
||||
_locals.use_c10_dispatcher = use_c10_dispatcher
|
||||
_locals.hack_const_mutable_self = hack_const_mutable_self
|
||||
yield
|
||||
finally:
|
||||
_locals.use_c10_dispatcher = old_use_c10_dispatcher
|
||||
_locals.hack_const_mutable_self = old_hack_const_mutable_self
|
||||
766
tools/codegen/model.py
Normal file
766
tools/codegen/model.py
Normal file
|
|
@ -0,0 +1,766 @@
|
|||
import re
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Sequence, Dict, Optional, Iterator, Tuple, Set, NoReturn
|
||||
from enum import Enum
|
||||
import itertools
|
||||
|
||||
# A little trick from https://github.com/python/mypy/issues/6366
|
||||
# for getting mypy to do exhaustiveness checking
|
||||
# TODO: put this somewhere else, maybe
|
||||
def assert_never(x: NoReturn) -> NoReturn:
|
||||
raise AssertionError("Unhandled type: {}".format(type(x).__name__))
|
||||
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# DATA MODEL
|
||||
#
|
||||
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ #
|
||||
#
|
||||
# Some general principles for our data model.
|
||||
#
|
||||
# - Stop using C++ data types as the internal data representation
|
||||
# format. Instead, the internal data structures are centered
|
||||
# around JIT schema representation. This avoid a big problem
|
||||
# with the old codegen where we read in all the types from
|
||||
# native_functions.yaml and then immediately had to retranslate
|
||||
# them into C++ types.
|
||||
#
|
||||
# - More semantic data representation. Instead of representing
|
||||
# everything as dicts and strings, we define dataclasses for
|
||||
# every interesting entity the code generation has to deal with.
|
||||
# These dataclasses have strong semantic invariants: for example,
|
||||
# we generally require them to roundtrip losslessly into the
|
||||
# form they were parsed from. These structures are immutable
|
||||
# and you're expected to populate information once during
|
||||
# construction.
|
||||
|
||||
# Represent a source location; used for better error reporting
|
||||
@dataclass(frozen=True)
|
||||
class Location:
|
||||
file: str
|
||||
line: int
|
||||
|
||||
def __str__(self) -> str:
|
||||
return "{}:{}".format(self.file, self.line)
|
||||
|
||||
# Valid values of the 'variants' field in native_functions.yaml
|
||||
Variant = Enum('Variant', ('function', 'method'))
|
||||
|
||||
UseC10Dispatcher = Enum('UseC10Dispatcher', (
|
||||
'full',
|
||||
'with_codegenerated_unboxing_wrapper'
|
||||
))
|
||||
|
||||
# The basic input to the code generation is native_functions.yaml.
|
||||
# The name "native", BTW, comes from the distinction between native
|
||||
# functions and legacy TH functions. The legacy TH functions are gone,
|
||||
# but the "native" descriptor has stuck.
|
||||
#
|
||||
# NativeFunction models a single entry in native_functions.yaml. Its
|
||||
# fields roughly correspond to what you would see in the YAML itself,
|
||||
# but after canonicalization and parsing has occurred.
|
||||
#
|
||||
# You can see some of the overall design patterns for how we setup
|
||||
# dataclasses in this class, but we will defer a complete discussion
|
||||
# of this at FunctionSchema.
|
||||
@dataclass(frozen=True)
|
||||
class NativeFunction:
|
||||
# The function schema of the operator in question. This schema
|
||||
# has been parsed; see FunctionSchema for more about its structure.
|
||||
# (This type is quoted as we are forward referencing a type
|
||||
# defined later in the file. I opted for this ordering of the
|
||||
# classes for expository clarity.)
|
||||
func: 'FunctionSchema'
|
||||
|
||||
# Corresponds to the 'use_c10_dispatcher' field. The default
|
||||
# is 'with_codegenerated_unboxing_wrapper'
|
||||
use_c10_dispatcher: UseC10Dispatcher
|
||||
|
||||
# Whether or not to omit automatic generation of a DeviceGuard
|
||||
device_guard: bool
|
||||
|
||||
# What python module to put the function in
|
||||
python_module: Optional[str]
|
||||
|
||||
# TODO: figure out what this does
|
||||
category_override: Optional[str]
|
||||
|
||||
# If no variants are specified in native_functions.yaml, this is
|
||||
# assumed to be {'function'}.
|
||||
variants: Set[Variant]
|
||||
|
||||
# Whether or not we should skip generating registrations for
|
||||
# this kernel. This is a bit of a double-edged sword, as manual
|
||||
# registrations don't participate in codegen-based selective build!
|
||||
manual_kernel_registration: bool
|
||||
|
||||
# Distinguish between a missing dispatch dict (historically, this
|
||||
# means to register a catch-all kernel) and a present but empty
|
||||
# dispatch dict (this means register nothing; arguably, this should
|
||||
# subsume manual_kernel_registration).
|
||||
#
|
||||
# TODO: str key could be replaced with more explicit enum
|
||||
dispatch: Optional[Dict[str, str]]
|
||||
|
||||
# The location in the YAML file were this native function entry was
|
||||
# defined. This is for conveniently reporting error messages!
|
||||
loc: 'Location'
|
||||
|
||||
# NB: The benefit of defining a dataclass is that we automatically get
|
||||
# a constructor defined for all the fields we specify. No need
|
||||
# to explicitly write it out.
|
||||
|
||||
@staticmethod
|
||||
def from_yaml(ei: Dict[str, object], loc: 'Location') -> 'NativeFunction':
|
||||
"""
|
||||
Parse a NativeFunction from a dictionary as directly parsed
|
||||
from native_functions.yaml
|
||||
"""
|
||||
e = ei.copy()
|
||||
|
||||
funcs = e.pop('func')
|
||||
assert isinstance(funcs, str), f'not a str: {funcs}'
|
||||
func = FunctionSchema.parse(funcs)
|
||||
|
||||
use_c10_dispatcher_s = e.pop('use_c10_dispatcher', None)
|
||||
if use_c10_dispatcher_s is None:
|
||||
use_c10_dispatcher = UseC10Dispatcher.with_codegenerated_unboxing_wrapper
|
||||
elif use_c10_dispatcher_s == 'full':
|
||||
use_c10_dispatcher = UseC10Dispatcher.full
|
||||
else:
|
||||
raise AssertionError(
|
||||
f'use_c10_dispatcher must be unset or set to full, got {use_c10_dispatcher}')
|
||||
|
||||
variants_s = e.pop('variants', 'function')
|
||||
assert isinstance(variants_s, str)
|
||||
variants: Set[Variant] = set()
|
||||
for v in variants_s.split(', '):
|
||||
if v == 'function':
|
||||
variants.add(Variant.function)
|
||||
elif v == 'method':
|
||||
variants.add(Variant.method)
|
||||
else:
|
||||
raise AssertionError(f'illegal variant {v}')
|
||||
|
||||
manual_kernel_registration = e.pop('manual_kernel_registration', False)
|
||||
assert isinstance(manual_kernel_registration, bool), f'not a bool: {manual_kernel_registration}'
|
||||
|
||||
device_guard = e.pop('device_guard', True)
|
||||
assert isinstance(device_guard, bool), f'not a bool: {device_guard}'
|
||||
|
||||
python_module = e.pop('python_module', None)
|
||||
assert python_module is None or isinstance(python_module, str), f'not a str: {python_module}'
|
||||
|
||||
category_override = e.pop('category_override', None)
|
||||
assert category_override is None or isinstance(category_override, str), f'not a str: {category_override}'
|
||||
|
||||
raw_dispatch = e.pop('dispatch', None)
|
||||
assert raw_dispatch is None or isinstance(raw_dispatch, dict), e
|
||||
dispatch: Optional[Dict[str, str]] = None
|
||||
if raw_dispatch is not None:
|
||||
dispatch = {}
|
||||
for ks, v in raw_dispatch.items():
|
||||
if ks == '__line__':
|
||||
continue # not worth tracking line numbers for dispatch entries
|
||||
assert isinstance(ks, str), e
|
||||
assert isinstance(v, str), e
|
||||
for k in ks.split(","):
|
||||
dispatch[k.strip()] = v
|
||||
|
||||
e.pop('__line__')
|
||||
assert not e, f"leftover entries: {e}"
|
||||
|
||||
return NativeFunction(
|
||||
func=func,
|
||||
use_c10_dispatcher=use_c10_dispatcher,
|
||||
variants=variants,
|
||||
manual_kernel_registration=manual_kernel_registration,
|
||||
python_module=python_module,
|
||||
category_override=category_override,
|
||||
dispatch=dispatch,
|
||||
device_guard=device_guard,
|
||||
loc=loc,
|
||||
)
|
||||
|
||||
# __post_init__ functions in dataclasses can be used to do extra
|
||||
# validation after construction.
|
||||
#
|
||||
# Notice that we don't do any type validation here. In fact, we
|
||||
# rely exclusively on mypy to check if you've done types correctly!
|
||||
# Validation is for nontrivial invariants that cannot be (conveniently)
|
||||
# encoded in the type system.
|
||||
def __post_init__(self) -> None:
|
||||
if self.func.out_arguments:
|
||||
assert self.variants == {Variant.function}, "Native functions with out arguments MUST " \
|
||||
"be declared with only function variant; e.g., variants: function; " \
|
||||
"otherwise you will tickle a Python argument binding bug " \
|
||||
"(which usually manifests itself as the result variable being undefined.)"
|
||||
|
||||
# The function schema is undoubtedly the most important data structure
|
||||
# in all of the codegen, as it defines the type signature for operators,
|
||||
# and most of the code generation we do is type directed (e.g., look at
|
||||
# the types, decide what to do. Think about how we code generate
|
||||
# C++ function stubs!)
|
||||
#
|
||||
# We will also see in this class the general structure for how we model
|
||||
# data in this code generation. A few notable properties to point out
|
||||
# ahead of time:
|
||||
#
|
||||
# - These dataclasses are a *lossless* representation of the strings
|
||||
# they are parsed from. In fact, we assert that given the
|
||||
# information stored in the dataclass, we can exactly reconstruct
|
||||
# the string we parsed from (and assert this inside the parse
|
||||
# definition). There are a few reasons for this:
|
||||
#
|
||||
# - If you find that it is difficult to reconstruct the string
|
||||
# given a dataclass, that is a clue that you are data
|
||||
# representation is wrong.
|
||||
#
|
||||
# - It helps ensure that all relevant information is present
|
||||
# in the dataclass, so that downstream users aren't tempted
|
||||
# to reparse the original string to get some information
|
||||
# that was omitted.
|
||||
#
|
||||
# - It forces you to represent the data in-memory in the same way
|
||||
# it is recorded textually, which makes the dataclasses easier
|
||||
# to understand for someone who is familiar with the
|
||||
# textual format. (As a tradeoff, it means you have to model
|
||||
# the syntax, even when it is inconvenient. But maybe that means
|
||||
# the syntax is bad!) If you don't understand the internal
|
||||
# representation, go look at the printing code to see how
|
||||
# it maps onto the surface syntax!
|
||||
#
|
||||
# - It makes it easy to test the parsing code, as parsing code
|
||||
# that is inconsistent with the string code will fail early
|
||||
# and loudly. (As a tradeoff, it makes the parsing code a bit
|
||||
# brittle (in particular, with trivial whitespace changes you
|
||||
# are likely to trigger an assert error).
|
||||
#
|
||||
# In general, try to make the __str__ code as simple as possible
|
||||
# (even at the cost of more complex parsing logic.) Additionally,
|
||||
# try to minimize redundancy in data representation. (Precomputed
|
||||
# fields are OK though: they are defined as a simple function on
|
||||
# the canonical representation in question.)
|
||||
#
|
||||
# - These dataclasses are all frozen; once constructed their
|
||||
# values never change. This makes it easy to tell where any
|
||||
# given data came from: just look to the constructor. As a
|
||||
# tradeoff, you can't easily "decorate" a schema with extra
|
||||
# information from a post-facto analysis. We impose this
|
||||
# restriction to make these structures more understandable.
|
||||
#
|
||||
@dataclass(frozen=True)
|
||||
class FunctionSchema:
|
||||
# The name of the operator this function schema describes.
|
||||
name: 'OperatorName'
|
||||
|
||||
# NB: Sequence here is intentional, to make it read only
|
||||
arguments: Sequence['Argument']
|
||||
kwarg_only_arguments: Sequence['Argument'] # but not including out args
|
||||
# Unlike in the previous codegen, we have factored out 'out' arguments
|
||||
# in the canonical representation, removing them from kwarg
|
||||
# arguments. This choice is justified by numerous downstream
|
||||
# transformations which treat out arguments specially; additionally,
|
||||
# you can see that canonicity is not violated!
|
||||
out_arguments: Sequence['Argument'] # these are also kwarg-only
|
||||
|
||||
# TODO: Need to handle collisions with argument names at some point
|
||||
returns: Sequence['Return']
|
||||
|
||||
def schema_order_arguments(self) -> Iterator['Argument']:
|
||||
return itertools.chain(self.arguments, self.kwarg_only_arguments, self.out_arguments)
|
||||
|
||||
@staticmethod
|
||||
def parse(func: str) -> 'FunctionSchema':
|
||||
# We should probably get a proper parser here
|
||||
assert ' -> ' in func, "function schema missing return type (spaces are mandatory)"
|
||||
func_decl, return_decl = [x.strip() for x in func.split(' -> ')]
|
||||
ops, args = func_decl.split('(', 1)
|
||||
assert args[-1] == ")", "Expecting closing )"
|
||||
args = args[:-1]
|
||||
name = OperatorName.parse(ops)
|
||||
arguments, kwarg_only_arguments, out_arguments = parse_arguments(args)
|
||||
returns = parse_returns(return_decl)
|
||||
r = FunctionSchema(
|
||||
name=name,
|
||||
arguments=arguments,
|
||||
kwarg_only_arguments=kwarg_only_arguments,
|
||||
out_arguments=out_arguments,
|
||||
returns=returns
|
||||
)
|
||||
assert str(r) == func, f'{str(r)} != {func}'
|
||||
return r
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
for arg, ret in zip(self.out_arguments, self.returns):
|
||||
assert arg.annotation == ret.annotation, \
|
||||
"Out arguments must have matching return Tensor; furthermore, " \
|
||||
"the ith-argument needs to correspond to the ith return"
|
||||
if self.out_arguments:
|
||||
assert len(self.out_arguments) == len(self.returns), \
|
||||
"Must return as many arguments as there are out arguments"
|
||||
if self.name.name.inplace:
|
||||
# TODO: fixme
|
||||
if str(self.name) not in [
|
||||
'_amp_non_finite_check_and_unscale_',
|
||||
'_foreach_add_.Scalar']:
|
||||
assert len(self.returns) == 1
|
||||
|
||||
def is_out_fn(self) -> bool:
|
||||
# Note [is_out_fn]
|
||||
#
|
||||
# out functions are the variants which take an explicit out= argument
|
||||
# to populate into. We need to know if a schema corresponds to an
|
||||
# out function for several reasons:
|
||||
#
|
||||
# - They codegen differently in C++ API
|
||||
# - codegen to at::add_out rather than at::add
|
||||
# - out argument is moved to front of C++ argument list
|
||||
#
|
||||
# out functions are DEFINED to be any function with a keyword-only
|
||||
# argument that is mutable. In principle, this could lead to a
|
||||
# false positive if you define a function that mutates a
|
||||
# kwarg only argument, but this isn't the "true" output of this
|
||||
# function. A more robust definition that would work in this
|
||||
# case would also look at:
|
||||
#
|
||||
# - The output types. Out functions take in the arguments
|
||||
# they mutate and then return them again; this is sort
|
||||
# of "definitionally" what makes something an out function.
|
||||
# Historically, we DO check this for consistency.
|
||||
# - Correspondence with pure variant. An out function
|
||||
# should have a signature equivalent to its pure variant,
|
||||
# but just with extra kwargs for the output elements. This
|
||||
# is difficult to actually check for and historically
|
||||
# we only do this check in tools/
|
||||
return bool(self.out_arguments)
|
||||
|
||||
def __str__(self) -> str:
|
||||
all_arguments: List[str] = []
|
||||
all_arguments.extend(map(str, self.arguments))
|
||||
if self.kwarg_only_arguments or self.out_arguments:
|
||||
all_arguments.append('*')
|
||||
all_arguments.extend(map(str, self.kwarg_only_arguments))
|
||||
all_arguments.extend(map(str, self.out_arguments))
|
||||
all_arguments_str = ', '.join(all_arguments)
|
||||
if len(self.returns) == 1:
|
||||
returns = str(self.returns[0]) # omit parentheses
|
||||
else:
|
||||
returns = '(' + ', '.join(map(str, self.returns)) + ')'
|
||||
return f'{self.name}({all_arguments_str}) -> {returns}'
|
||||
|
||||
# Here is the rest of the data model, described more briefly.
|
||||
|
||||
# Simplified version for what actually shows up in built-ins.
|
||||
# Look at alias_info.h for expanded syntax. If you need the structure,
|
||||
# you also need to make this structure recursive so it can be lined
|
||||
# up with the type components too. For primitives this isn't really
|
||||
# necessary
|
||||
@dataclass(frozen=True)
|
||||
class Annotation:
|
||||
# Typically only has one element. Not actually a set so
|
||||
# we can conveniently assume it is canonically ordered
|
||||
alias_set: Sequence[str]
|
||||
is_write: bool
|
||||
|
||||
@staticmethod
|
||||
def parse(ann: str) -> 'Annotation':
|
||||
m = re.match(r'^([a-z])(!?)$', ann)
|
||||
assert m is not None, f'unrecognized alias annotation {ann}'
|
||||
alias_set = [m.group(1)]
|
||||
is_write = m.group(2) == '!'
|
||||
r = Annotation(alias_set=alias_set, is_write=is_write)
|
||||
assert str(r) == ann, f'{r} != {ann}'
|
||||
return r
|
||||
|
||||
def __str__(self) -> str:
|
||||
alias_set = '|'.join(self.alias_set)
|
||||
is_write = '!' if self.is_write else ''
|
||||
return f'{alias_set}{is_write}'
|
||||
|
||||
# The base class for the type system. This is also loosely modeled
|
||||
# off of jit_type.h, but we've simplified the hierarchy to focus
|
||||
# in on the aspects of the type system that matter for code generation
|
||||
# (for example, there's no SingleElementType subclass anymore).
|
||||
# You never actually construct a Type; usually it's going to be one
|
||||
# of the subclasses. If Python had ADTs this would be one!
|
||||
@dataclass(frozen=True)
|
||||
class Type:
|
||||
@staticmethod
|
||||
def parse(t: str) -> 'Type':
|
||||
r = Type._parse(t)
|
||||
assert str(r) == t, f'{r} != {t}'
|
||||
return r
|
||||
|
||||
@staticmethod
|
||||
def _parse(t: str) -> 'Type':
|
||||
m = re.match(r'^(.+)\?$', t)
|
||||
if m is not None:
|
||||
return OptionalType(Type.parse(m.group(1)))
|
||||
m = re.match(r'^(.+)\[([0-9]+)?\]$', t)
|
||||
if m is not None:
|
||||
size = int(m.group(2)) if m.group(2) is not None else None
|
||||
return ListType(elem=Type.parse(m.group(1)), size=size)
|
||||
try:
|
||||
return BaseType(BaseTy[t])
|
||||
except KeyError:
|
||||
raise RuntimeError(f"unrecognized type {t}")
|
||||
|
||||
def __str__(self) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
# WARNING: These concepts are not very well-defined. For example,
|
||||
# is "int?" nullable? How about "int?[]". They are defined
|
||||
# so we can conveniently generate legacy Declarations.yaml but
|
||||
# really we should probably just remove these at some point
|
||||
|
||||
def is_tensor_like(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_nullable(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
def is_list_like(self) -> Optional['ListType']:
|
||||
raise NotImplementedError
|
||||
|
||||
# Base types are simple, atomic types with no further structure
|
||||
BaseTy = Enum('BaseTy', (
|
||||
'Generator',
|
||||
'ScalarType',
|
||||
'Tensor',
|
||||
'int',
|
||||
'Dimname',
|
||||
'float',
|
||||
'str',
|
||||
'bool',
|
||||
'Layout',
|
||||
'Device',
|
||||
'Scalar',
|
||||
'MemoryFormat',
|
||||
'QScheme',
|
||||
'Storage',
|
||||
'ConstQuantizerPtr', # TODO: rename
|
||||
))
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class BaseType(Type):
|
||||
name: BaseTy
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.name.name}'
|
||||
|
||||
def is_tensor_like(self) -> bool:
|
||||
return self.name == BaseTy.Tensor
|
||||
|
||||
def is_nullable(self) -> bool:
|
||||
return False
|
||||
|
||||
def is_list_like(self) -> Optional['ListType']:
|
||||
return None
|
||||
|
||||
# Optional types may be specified, or may also be validly given None
|
||||
@dataclass(frozen=True)
|
||||
class OptionalType(Type):
|
||||
elem: Type
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f'{self.elem}?'
|
||||
|
||||
def is_tensor_like(self) -> bool:
|
||||
return self.elem.is_tensor_like()
|
||||
|
||||
def is_nullable(self) -> bool:
|
||||
return True
|
||||
|
||||
def is_list_like(self) -> Optional['ListType']:
|
||||
return self.elem.is_list_like()
|
||||
|
||||
# List types specify that we may have multiples of an element. We
|
||||
# also support explicit sizes on list types, but these have
|
||||
# some nontrivial semantics! (However, for C++ API purposes, explicit
|
||||
# sizes are mostly erased from the type system.)
|
||||
#
|
||||
# DANGER WILL ROBINSON: C++ elaboration depends on elem type; e.g.,
|
||||
# int[] elaborates differently than bool[3]!
|
||||
@dataclass(frozen=True)
|
||||
class ListType(Type):
|
||||
elem: Type
|
||||
size: Optional[int]
|
||||
|
||||
def __str__(self) -> str:
|
||||
size = f'{self.size}' if self.size else ''
|
||||
return f'{self.elem}[{size}]'
|
||||
|
||||
def is_tensor_like(self) -> bool:
|
||||
return self.elem.is_tensor_like()
|
||||
|
||||
def is_nullable(self) -> bool:
|
||||
return self.elem.is_nullable()
|
||||
|
||||
def is_list_like(self) -> Optional['ListType']:
|
||||
return self
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Argument:
|
||||
# NB: I didn't put kwarg_only as a boolean field here, unlike
|
||||
# c10::Argument, so that printing works correctly
|
||||
|
||||
name: str
|
||||
type: Type
|
||||
default: Optional[str]
|
||||
|
||||
# The semantics of the annotation field are a little strange.
|
||||
#
|
||||
# Alias annotations parametrize Tensors (since Tensors are the only things
|
||||
# that can alias.) This motivates why I write Tensor(a!)? (and not, for
|
||||
# example, Tensor?(a!)), because the (a!) describes aliasing on the tensor,
|
||||
# which may be optional (i.e., the alias annotation should bind first to
|
||||
# Tensor, before the optional postfix annotation).
|
||||
#
|
||||
# However, despite being a property of Tensor, we (and c10::Argument)
|
||||
# store the annotation at the top level of the Argument, rather than
|
||||
# inside the embedded Tensor type. In the C++ version of this
|
||||
# class, we then go through great lengths to mimic the type
|
||||
# structure in the annotation structure so we can correlate
|
||||
# annotations with types.
|
||||
#
|
||||
# Now, it turns out, in all applications in code generation, the
|
||||
# structure of annotated types is very simple. So we just hard
|
||||
# code it here. But if we ever do get anything more complex, this
|
||||
# model will have to change!
|
||||
annotation: Optional[Annotation]
|
||||
|
||||
@staticmethod
|
||||
def parse(arg: str) -> 'Argument':
|
||||
name: str
|
||||
default: Optional[str]
|
||||
type_and_annot, name_and_default = arg.rsplit(' ', 1)
|
||||
if '=' in name_and_default:
|
||||
name, default = name_and_default.split('=')
|
||||
else:
|
||||
name = name_and_default
|
||||
default = None
|
||||
# TODO: deduplicate annotation matching with Return
|
||||
match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
|
||||
annotation: Optional[Annotation]
|
||||
if match:
|
||||
# If you update this, make sure the __str__ still works too
|
||||
assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
|
||||
type_s = 'Tensor' + match.group(2)
|
||||
annotation = Annotation.parse(match.group(1))
|
||||
else:
|
||||
type_s = type_and_annot
|
||||
annotation = None
|
||||
type = Type.parse(type_s)
|
||||
r = Argument(
|
||||
name=name,
|
||||
type=type,
|
||||
default=default,
|
||||
annotation=annotation,
|
||||
)
|
||||
assert str(r) == arg, f'{str(r)} != {arg}'
|
||||
return r
|
||||
|
||||
@property
|
||||
def is_write(self) -> bool:
|
||||
return self.annotation is not None and self.annotation.is_write
|
||||
|
||||
def __str__(self) -> str:
|
||||
type = f'{self.type}'
|
||||
if self.annotation:
|
||||
assert type in ['Tensor', 'Tensor?', 'Tensor[]']
|
||||
type = type.replace('Tensor', f'Tensor({self.annotation})')
|
||||
if self.name is None:
|
||||
return type
|
||||
else:
|
||||
mb_default = ''
|
||||
if self.default:
|
||||
mb_default = f'={self.default}'
|
||||
return f"{type} {self.name}{mb_default}"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Return:
|
||||
name: Optional[str]
|
||||
type: Type
|
||||
annotation: Optional[Annotation]
|
||||
|
||||
@staticmethod
|
||||
def parse(arg: str) -> 'Return':
|
||||
name: Optional[str]
|
||||
if ' ' in arg:
|
||||
type_and_annot, name = arg.rsplit(' ', 1)
|
||||
else:
|
||||
type_and_annot = arg
|
||||
name = None
|
||||
match = re.match(r'Tensor\((.+)\)(.*)', type_and_annot)
|
||||
annotation: Optional[Annotation]
|
||||
if match:
|
||||
# If you update this, make sure the __str__ still works too
|
||||
assert match.group(2) in ['', '?', '[]'], 'unrecognized alias analysis form with Tensor'
|
||||
type_s = 'Tensor' + match.group(2)
|
||||
annotation = Annotation.parse(match.group(1))
|
||||
else:
|
||||
type_s = type_and_annot
|
||||
annotation = None
|
||||
type = Type.parse(type_s)
|
||||
r = Return(
|
||||
name=name,
|
||||
type=type,
|
||||
annotation=annotation,
|
||||
)
|
||||
assert str(r) == arg, f'{str(r)} != {arg}'
|
||||
return r
|
||||
|
||||
@property
|
||||
def is_write(self) -> bool:
|
||||
return self.annotation is not None and self.annotation.is_write
|
||||
|
||||
def __str__(self) -> str:
|
||||
type = f'{self.type}'
|
||||
if self.annotation:
|
||||
assert type in ['Tensor', 'Tensor?', 'Tensor[]']
|
||||
type = type.replace('Tensor', f'Tensor({self.annotation})')
|
||||
if self.name is None:
|
||||
return type
|
||||
else:
|
||||
return f"{type} {self.name}"
|
||||
|
||||
|
||||
# Names that validly are __iXXX__ indicating inplace operations.
|
||||
# Taken from https://www.python.org/dev/peps/pep-0203/#new-methods
|
||||
# NB: PyTorch hasn't actually implemented all of these
|
||||
AUGMENTED_ASSIGNMENT_NAMES = ['add', 'sub', 'mul', 'div', 'mod', 'pow', 'lshift', 'rshift', 'and', 'xor', 'or']
|
||||
|
||||
# A BaseOperatorName is what we think of the operator name, without
|
||||
# the overload name. Unusually, we don't represent this as just a
|
||||
# string; instead, we directly represent a few important semantic
|
||||
# bits of information we derive from the string: namely whether
|
||||
# or not it's inplace (add_) and whether or not it's a double-underscore
|
||||
# method (__add__)
|
||||
@dataclass(frozen=True)
|
||||
class BaseOperatorName:
|
||||
base: str
|
||||
inplace: bool
|
||||
dunder_method: bool
|
||||
|
||||
@staticmethod
|
||||
def parse(op: str) -> 'BaseOperatorName':
|
||||
assert op != ''
|
||||
assert not op.endswith('_out'), \
|
||||
"_out suffix is reserved and not permitted for operator names; " \
|
||||
"did you mean to specify an out overload name instead?"
|
||||
m = re.match(r'^__([^_]+)__$', op)
|
||||
if m is not None:
|
||||
dunder_method = True
|
||||
base = m.group(1)
|
||||
if any(base == f'i{n}' for n in AUGMENTED_ASSIGNMENT_NAMES):
|
||||
inplace = True
|
||||
base = base[1:]
|
||||
else:
|
||||
inplace = False
|
||||
# temporary, this is not intrinsically true but
|
||||
# has been historically true for dunder methods
|
||||
# we support (but, if we ever got, say, __int__, this would
|
||||
# be wrong!)
|
||||
assert base[0] != 'i'
|
||||
else:
|
||||
dunder_method = False
|
||||
base = op
|
||||
if base[-1] == '_':
|
||||
inplace = True
|
||||
base = base[:-1]
|
||||
else:
|
||||
inplace = False
|
||||
r = BaseOperatorName(base=base, inplace=inplace, dunder_method=dunder_method)
|
||||
assert str(r) == op, f'{str(r)} != {op}'
|
||||
return r
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.dunder_method:
|
||||
i = 'i' if self.inplace else ''
|
||||
return f'__{i}{self.base}__'
|
||||
else:
|
||||
i = '_' if self.inplace else ''
|
||||
return f'{self.base}{i}'
|
||||
|
||||
# Operator name is the base operator name along with the (typically not
|
||||
# user visible) overload string.
|
||||
@dataclass(frozen=True)
|
||||
class OperatorName:
|
||||
name: BaseOperatorName
|
||||
overload_name: str
|
||||
|
||||
@staticmethod
|
||||
def parse(op_name: str) -> 'OperatorName':
|
||||
if '.' in op_name:
|
||||
name, overload_name = op_name.split('.', 1)
|
||||
else:
|
||||
name = op_name
|
||||
overload_name = ''
|
||||
r = OperatorName(
|
||||
name=BaseOperatorName.parse(name),
|
||||
overload_name=overload_name
|
||||
)
|
||||
assert str(r) == op_name, f'{str(r)} != {op_name}'
|
||||
return r
|
||||
|
||||
def __str__(self) -> str:
|
||||
if self.overload_name:
|
||||
return f"{self.name}.{self.overload_name}"
|
||||
else:
|
||||
return f"{self.name}"
|
||||
|
||||
# Helper functions for parsing argument lists (both inputs and returns)
|
||||
|
||||
def parse_returns(return_decl: str) -> Sequence[Return]:
|
||||
"""
|
||||
Input: '()'
|
||||
Output: []
|
||||
"""
|
||||
if return_decl == '()':
|
||||
return []
|
||||
if return_decl[0] == '(' and return_decl[-1] == ')':
|
||||
return_decl = return_decl[1:-1]
|
||||
returns = []
|
||||
for arg in return_decl.split(', '):
|
||||
returns.append(Return.parse(arg))
|
||||
return returns
|
||||
|
||||
def parse_arguments(args: str) -> Tuple[Sequence[Argument], Sequence[Argument], Sequence[Argument]]:
|
||||
"""
|
||||
Input: 'int x, int y, int z'
|
||||
Output: positional args, kwarg only args
|
||||
"""
|
||||
arguments: List[Argument] = []
|
||||
kwarg_only_arguments: List[Argument] = []
|
||||
out_arguments: List[Argument] = []
|
||||
arguments_acc = arguments
|
||||
|
||||
# TODO: Use a real parser here; this will get bamboozled
|
||||
# by signatures that contain things like std::array<bool, 2> (note the space)
|
||||
for arg in args.split(', '):
|
||||
if not arg:
|
||||
continue
|
||||
if arg == '*':
|
||||
assert arguments_acc is arguments, "invalid syntax: kwarg-only specifier * can only occur once"
|
||||
arguments_acc = kwarg_only_arguments
|
||||
continue
|
||||
parg = Argument.parse(arg)
|
||||
# Currently, we rely directly on the invariant that there are NO
|
||||
# kwarg-only mutating arguments. If you want to relax this,
|
||||
# we will need a more semantic way of matching that takes
|
||||
# into account return arguments. In that case, you will have
|
||||
# to manage out_arguments computation a level up, in
|
||||
# FunctionSchema. See Note [is_out_fn]
|
||||
if parg.annotation is not None and parg.annotation.is_write:
|
||||
if arguments_acc is arguments:
|
||||
pass # do nothing
|
||||
elif arguments_acc is kwarg_only_arguments:
|
||||
arguments_acc = out_arguments
|
||||
else:
|
||||
assert arguments_acc is not out_arguments
|
||||
arguments_acc.append(parg)
|
||||
|
||||
return arguments, kwarg_only_arguments, out_arguments
|
||||
11
tools/setup_helpers/gen.py
Normal file
11
tools/setup_helpers/gen.py
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
# Little stub file to get BUILD.bazel to play along
|
||||
|
||||
import os.path
|
||||
import sys
|
||||
|
||||
root = os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
sys.path.insert(0, root)
|
||||
|
||||
import tools.codegen.gen
|
||||
|
||||
tools.codegen.gen.main()
|
||||
Loading…
Reference in New Issue
Block a user