pytorch/torch/cuda/jiterator.py
zabboud 01478f1afa Fix pydocstyle errors listed in issue 112589 (#113227)
Fixes #112589

Fixed pydocstyle errors in the following files. The remaining errors relate to docstrings at the module level and on methods within each module (see details below).

pydocstyle torch/cuda/_utils.py --count
before: 3
after: 0

pydocstyle torch/cuda/jiterator.py --count
before: 3
after: 1

**remaining errors:**
```
torch/cuda/jiterator.py:1 at module level:
        D100: Missing docstring in public module
```

pydocstyle torch/cuda/graphs.py --count
before: 25
after: 7

**remaining errors:**
```
torch/cuda/graphs.py:1 at module level:
        D100: Missing docstring in public module
torch/cuda/graphs.py:54 in public method `__new__`:
        D102: Missing docstring in public method
torch/cuda/graphs.py:108 in public method `debug_dump`:
        D205: 1 blank line required between summary line and description (found 0)
torch/cuda/graphs.py:108 in public method `debug_dump`:
        D400: First line should end with a period (not ':')
torch/cuda/graphs.py:150 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/graphs.py:172 in public method `__enter__`:
        D105: Missing docstring in magic method
torch/cuda/graphs.py:186 in public method `__exit__`:
        D105: Missing docstring in magic method
```
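Two of the remaining `graphs.py` errors, D205 and D400, concern the summary line of `debug_dump`'s docstring. The sketch below is a simplified approximation of those two rules, written for illustration and not taken from pydocstyle. It reads a docstring with `inspect.getdoc` and flags a first line that does not end in a period or is not followed by a blank line. `FakeGraph` and its methods are hypothetical stand-ins, not code from `torch/cuda/graphs.py`.

```python
import inspect


def summary_problems(obj) -> list:
    """Return simplified D400/D205-style problems found in obj's docstring."""
    doc = inspect.getdoc(obj)  # cleaned docstring, or None if there is none
    if not doc:
        return ["missing docstring"]
    lines = doc.splitlines()
    problems = []
    if not lines[0].rstrip().endswith("."):
        problems.append("D400: first line should end with a period")
    if len(lines) > 1 and lines[1].strip():
        problems.append("D205: 1 blank line required between summary line and description")
    return problems


# Hypothetical stand-ins, not code from torch/cuda/graphs.py.
class FakeGraph:
    def dump_before(self, path):
        """Arguments:
            path (str): File to write the debug dump to.
        """

    def dump_after(self, path):
        """Write a debug dump of the captured graph.

        Arguments:
            path (str): File to write the debug dump to.
        """
```

`summary_problems(FakeGraph.dump_before)` reports both codes, matching the `debug_dump` entries above (the summary ends with `:` and runs straight into the description), while `dump_after` reports none.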

pydocstyle torch/cuda/_sanitizer.py --count
before: 35
after: 31

**remaining errors:**
```
torch/cuda/_sanitizer.py:43 in public class `AccessType`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:47 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:84 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:96 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:139 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:142 in public method `__str__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:218 in public class `StreamSynchronizations`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:219 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:256 in public method `create_stream`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:268 in public method `create_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:272 in public method `delete_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:276 in public method `update_seq_num`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:280 in public method `record_state`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:291 in public method `stream_wait_for_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:298 in public method `all_streams_wait_for_event`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:307 in public method `all_streams_wait_for_stream`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:316 in public method `sync_all_streams`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:323 in public method `is_ordered_after`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:339 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:460 in public function `zip_by_key`:
        D103: Missing docstring in public function
torch/cuda/_sanitizer.py:466 in public function `zip_arguments`:
        D103: Missing docstring in public function
torch/cuda/_sanitizer.py:478 in public class `ArgumentHandler`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:479 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:505 in public method `parse_inputs`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:520 in public method `parse_outputs`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:527 in public class `CUDASanitizerDispatchMode`:
        D101: Missing docstring in public class
torch/cuda/_sanitizer.py:528 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:562 in public method `__torch_dispatch__`:
        D105: Missing docstring in magic method
torch/cuda/_sanitizer.py:597 in public method `__init__`:
        D107: Missing docstring in __init__
torch/cuda/_sanitizer.py:601 in public method `enable`:
        D102: Missing docstring in public method
torch/cuda/_sanitizer.py:605 in public method `__del__`:
        D105: Missing docstring in magic method
```
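Most of the remaining entries are "missing docstring" codes, and which code is reported depends on the kind of definition: D101 for a class, D102 for a method, D103 for a module-level function, D105 for a magic method and D107 for `__init__`. The `graphs.py` log above also shows `__new__` reported as D102 rather than D105. The helper below is a hypothetical, simplified re-creation of that mapping, not pydocstyle itself: it only looks at plain functions, static methods and class methods defined directly on one class, and it treats any single-underscore name as private. `StreamTracker` is a made-up class loosely shaped like `StreamSynchronizations`.

```python
import inspect


def expected_code(name: str, kind: str) -> str:
    """Return the missing-docstring code for a public class, function, or method."""
    if kind == "class":
        return "D101"
    if kind == "function":
        return "D103"
    if name == "__init__":
        return "D107"
    if name == "__new__":
        # The logs report `__new__` as a plain public method, not a magic method.
        return "D102"
    if name.startswith("__") and name.endswith("__"):
        return "D105"
    return "D102"


def is_public(name: str) -> bool:
    """Treat dunder names as public and single-underscore names as private."""
    return not name.startswith("_") or (name.startswith("__") and name.endswith("__"))


def missing_docstrings(cls) -> dict:
    """Map cls and each undocumented public method on it to the expected code."""
    found = {}
    if cls.__dict__.get("__doc__") is None:
        found[cls.__name__] = expected_code(cls.__name__, "class")
    for name, member in vars(cls).items():
        if isinstance(member, (staticmethod, classmethod)):
            member = member.__func__  # `__new__` is stored as a staticmethod
        if not inspect.isfunction(member) or not is_public(name):
            continue
        if member.__doc__ is None:
            found[name] = expected_code(name, "method")
    return found


class StreamTracker:  # hypothetical example class
    def __new__(cls):
        return object.__new__(cls)

    def __init__(self):
        self.seq_nums = {}

    def __str__(self):
        return f"StreamTracker({len(self.seq_nums)} streams)"

    def create_stream(self, stream):
        self.seq_nums[stream] = 0

    def _reset(self):  # private: never reported
        self.seq_nums.clear()

    def sync_all_streams(self):
        """Mark every tracked stream as synchronized."""
```

Running `missing_docstrings(StreamTracker)` yields D101 for the class, D102 for `__new__` and `create_stream`, D105 for `__str__` and D107 for `__init__`, the same pattern as the `_sanitizer.py` entries.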

pydocstyle torch/storage.py --count
before: 90
after: 37

**remaining errors:**
```
torch/storage.py:1 at module level:
        D100: Missing docstring in public module
torch/storage.py:310 in public class `UntypedStorage`:
        D101: Missing docstring in public class
torch/storage.py:311 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch/storage.py:317 in public method `is_cuda`:
        D102: Missing docstring in public method
torch/storage.py:321 in public method `is_hpu`:
        D102: Missing docstring in public method
torch/storage.py:325 in public method `share_memory_`:
        D102: Missing docstring in public method
torch/storage.py:444 in public class `TypedStorage`:
        D101: Missing docstring in public class
torch/storage.py:453 in public method `fill_`:
        D102: Missing docstring in public method
torch/storage.py:458 in public method `__new__`:
        D102: Missing docstring in public method
torch/storage.py:530 in public method `__init__`:
        D107: Missing docstring in __init__
torch/storage.py:599 in public method `is_cuda`:
        D102: Missing docstring in public method
torch/storage.py:604 in public method `is_hpu`:
        D102: Missing docstring in public method
torch/storage.py:624 in public method `__len__`:
        D105: Missing docstring in magic method
torch/storage.py:653 in public method `__setitem__`:
        D105: Missing docstring in magic method
torch/storage.py:681 in public method `__getitem__`:
        D105: Missing docstring in magic method
torch/storage.py:715 in public method `copy_`:
        D102: Missing docstring in public method
torch/storage.py:723 in public method `nbytes`:
        D102: Missing docstring in public method
torch/storage.py:731 in public method `type`:
        D102: Missing docstring in public method
torch/storage.py:744 in public method `cuda`:
        D102: Missing docstring in public method
torch/storage.py:751 in public method `hpu`:
        D102: Missing docstring in public method
torch/storage.py:758 in public method `element_size`:
        D102: Missing docstring in public method
torch/storage.py:766 in public method `get_device`:
        D102: Missing docstring in public method
torch/storage.py:770 in public method `__str__`:
        D105: Missing docstring in magic method
torch/storage.py:781 in public method `__repr__`:
        D105: Missing docstring in magic method
torch/storage.py:785 in public method `__iter__`:
        D105: Missing docstring in magic method
torch/storage.py:789 in public method `__copy__`:
        D105: Missing docstring in magic method
torch/storage.py:793 in public method `__deepcopy__`:
        D105: Missing docstring in magic method
torch/storage.py:801 in public method `__sizeof__`:
        D105: Missing docstring in magic method
torch/storage.py:877 in public method `device`:
        D102: Missing docstring in public method
torch/storage.py:881 in public method `size`:
        D102: Missing docstring in public method
torch/storage.py:891 in public method `pickle_storage_type`:
        D102: Missing docstring in public method
torch/storage.py:902 in public method `__reduce__`:
        D105: Missing docstring in magic method
torch/storage.py:907 in public method `data_ptr`:
        D102: Missing docstring in public method
torch/storage.py:915 in public method `resize_`:
        D102: Missing docstring in public method
torch/storage.py:931 in public method `from_buffer`:
        D102: Missing docstring in public method
torch/storage.py:1032 in public method `from_file`:
        D402: First line should not be the function's "signature"
torch/storage.py:1075 in public method `is_shared`:
        D102: Missing docstring in public method

```
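D402, the only remaining `storage.py` entry that is not a "missing docstring" code, means the first line of `from_file`'s docstring restates its call signature. The check below is a simplified approximation written for illustration, not pydocstyle's code: it removes spaces from the docstring's first line and looks for `name(`. Both functions are hypothetical stand-ins, not the real `TypedStorage.from_file`.

```python
import inspect


def first_line_is_signature(func) -> bool:
    """Return True if the docstring's first line looks like ``name(...)``."""
    doc = inspect.getdoc(func)
    if not doc:
        return False
    first_line = doc.splitlines()[0].replace(" ", "")
    return func.__name__ + "(" in first_line


def from_file(filename, shared=False, size=0):
    """from_file(filename, shared=False, size=0) -> Storage

    Hypothetical stand-in whose first line repeats its signature.
    """


def from_file_fixed(filename, shared=False, size=0):
    """Create a storage backed by the file at ``filename``.

    Hypothetical stand-in that opens with a summary sentence instead.
    """
```

The usual fix is the one shown in `from_file_fixed`: start with a one-sentence summary and leave the signature to introspection.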

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113227
Approved by: https://github.com/kit1980
2023-11-13 22:05:45 +00:00


import re
from typing import Callable, List

import torch
from torch import Tensor

__all__: List[str] = []


class _CodeParser:
    def __init__(self, code_string: str):
        optional_ws = r"\s*"
        required_ws = r"\s+"
        template_params = r"(?P<template_params>\<.+\>)"
        return_type = r"(?P<return_type>\w+)"
        function_name = r"(?P<function_name>\w+)"
        function_params = r"(?P<function_params>\(.+\))"
        function_body = r"(?P<function_body>\{.+\})"

        pattern = (
            optional_ws
            + "template"
            + optional_ws
            + template_params
            + optional_ws
            + return_type
            + required_ws
            + function_name
            + optional_ws
            + function_params
            + optional_ws
            + function_body
            + optional_ws
        )

        result = re.match(
            pattern, code_string, re.DOTALL
        )  # DOTALL for matching multiline

        if result is None:
            raise Exception(
                f"Couldn't parse code, please check correctness:\n {code_string}"
            )

        self.template_params = result["template_params"]
        self.return_type = result["return_type"]
        self.function_name = result["function_name"]
        self.function_params = result["function_params"]
        self.function_body = result["function_body"]


class _JittedFunction:
    def __init__(
        self, code_string: str, return_by_ref: bool, num_outputs: int, **kwargs
    ):
        self.code_string = code_string

        assert (
            return_by_ref or num_outputs == 1
        ), "Return by value only works for a single output."
        self.return_by_ref = return_by_ref
        self.num_outputs = num_outputs

        parsed_code = _CodeParser(code_string)
        self.kernel_name = parsed_code.function_name

        self.kwargs_dict = kwargs
        self.is_cuda_available = torch.cuda.is_available()

    def __call__(self, *tensors: Tensor, **kwargs):
        # Jiterator follows torch.cuda's lazy initialization behavior:
        # defer the CUDA availability check until the function is invoked.
        assert (
            self.is_cuda_available
        ), "Jiterator is only supported on CUDA and ROCm GPUs, none are available."

        assert len(tensors) <= 8, "jiterator only supports up to 8 tensor inputs."

        expanded_kwargs = self.kwargs_dict.copy()
        for key, value in kwargs.items():
            if key in self.kwargs_dict:
                expanded_kwargs[key] = value
            else:
                raise KeyError(f"{key} is not declared in function definition")

        return torch._C._cuda_jiterator_compile_and_launch_kernel(
            self.code_string,
            self.kernel_name,
            self.return_by_ref,
            self.num_outputs,
            tensors,
            expanded_kwargs,
        )


def _create_jit_fn(code_string: str, **kwargs) -> Callable:
    """
    Create a jiterator-generated CUDA kernel for an elementwise op.

    The code string must be a valid CUDA function that describes the computation for a single element, and it
    must follow the C++ template pattern shown in the examples below. This function is inlined into an
    elementwise kernel template and compiled on the fly. The compiled kernel is cached in memory as well as in a
    local temp directory.

    Jiterator-generated kernels accept noncontiguous tensors and support broadcasting and type promotion.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return by value.
        kwargs (Dict, optional): Keyword arguments for the generated function, with their default values.

    Example::

        code_string = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
        jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    ``code_string`` also allows multiple function definitions; the last function is treated as the entry function.

    Example::

        code_string = "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
        code_string += "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
        jitted_fn = torch.cuda.jiterator._create_jit_fn(code_string, val=0.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b)  # using default val=0.0

    Jiterator can be used together with Python registration to override an operator's CUDA kernel.
    The following example overrides gelu's CUDA kernel with relu.

    Example::

        code_string = "template <typename T> T my_gelu(T a) { return a > 0 ? a : 0; }"
        my_gelu = torch.cuda.jiterator._create_jit_fn(code_string)
        my_lib = torch.library.Library("aten", "IMPL")
        my_lib.impl('aten::gelu', my_gelu, "CUDA")
        # torch.nn.GELU and torch.nn.functional.gelu are now overridden
        a = torch.rand(3, device='cuda')
        torch.allclose(torch.nn.functional.gelu(a), torch.nn.functional.relu(a))

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 1 output.

    .. warning::
        All input tensors must reside on a CUDA device.
    """
    return _JittedFunction(code_string, return_by_ref=False, num_outputs=1, **kwargs)


def _create_multi_output_jit_fn(
    code_string: str, num_outputs: int, **kwargs
) -> Callable:
    """
    Create a jiterator-generated CUDA kernel for an elementwise op that supports returning one or more outputs.

    Args:
        code_string (str): CUDA code string to be compiled by jiterator. The entry functor must return its
            outputs by reference.
        num_outputs (int): Number of outputs returned by the kernel.
        kwargs (Dict, optional): Keyword arguments for the generated function, with their default values.

    Example::

        code_string = "template <typename T> void my_kernel(T x, T y, T alpha, T& out) { out = -x + alpha * y; }"
        jitted_fn = torch.cuda.jiterator._create_multi_output_jit_fn(code_string, num_outputs=1, alpha=1.0)
        a = torch.rand(3, device='cuda')
        b = torch.rand(3, device='cuda')
        # invoke jitted function like a regular python function
        result = jitted_fn(a, b, alpha=3.14)

    .. warning::
        This API is in beta and may change in future releases.

    .. warning::
        This API only supports up to 8 inputs and 8 outputs.
    """
    return _JittedFunction(
        code_string, return_by_ref=True, num_outputs=num_outputs, **kwargs
    )
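The docstring's rule that "the last function will be treated as the entry function" follows from the greedy `.+` in `template_params`: the regex tries the rightmost `>` in the string first and keeps the first one from which the rest still parses as `return_type name(params) { body }`. A `>` inside a body, such as the one in `a > 0 ? a : 0`, is tried and rejected. The sketch below copies the `_CodeParser` pattern and the keyword-override loop from `_JittedFunction.__call__` into two standalone helpers so the behavior can be checked without torch or a GPU. `entry_function_name` and `merge_kwargs` are hypothetical names, and unlike the file above the sketch raises `ValueError` rather than a bare `Exception` on unparseable input.

```python
import re

# The same pattern _CodeParser assembles, copied so this runs without torch:
# template <params> return_type name(params) { body }
_PATTERN = (
    r"\s*template\s*(?P<template_params>\<.+\>)\s*(?P<return_type>\w+)\s+"
    r"(?P<function_name>\w+)\s*(?P<function_params>\(.+\))\s*"
    r"(?P<function_body>\{.+\})\s*"
)


def entry_function_name(code_string: str) -> str:
    """Return the function name the parser would pick as the kernel entry point."""
    result = re.match(_PATTERN, code_string, re.DOTALL)
    if result is None:
        # The file above raises a bare Exception here; ValueError is this sketch's choice.
        raise ValueError(f"Couldn't parse code, please check correctness:\n {code_string}")
    return result["function_name"]


def merge_kwargs(declared: dict, overrides: dict) -> dict:
    """Apply call-time overrides on top of the defaults given at creation time."""
    merged = declared.copy()  # leave the stored defaults untouched
    for key, value in overrides.items():
        if key not in declared:
            raise KeyError(f"{key} is not declared in function definition")
        merged[key] = value
    return merged


single = "template <typename T> T my_kernel(T x, T y, T alpha) { return -x + alpha * y; }"
multi = (
    "template <typename T> T util_fn(T x, T y) { return ::sin(x) + ::cos(y); }"
    "template <typename T> T my_kernel(T x, T y, T val) { return ::min(val, util_fn(x, y)); }"
)
print(entry_function_name(single))  # my_kernel
print(entry_function_name(multi))   # my_kernel: the last definition wins
print(merge_kwargs({"alpha": 1.0}, {"alpha": 3.14}))  # {'alpha': 3.14}
```

Keeping the declared keyword arguments as a closed set means a misspelled keyword at call time fails loudly with `KeyError` instead of being silently ignored by the compiled kernel.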