Based on the [conversation](https://github.com/pytorch/pytorch/issues/121791), we plan to drop "highest, high, medium" as the way to represent fp32 internal computation data types. Instead, we will directly use the algorithm name to represent it.

### Design Choice: Directly use algorithm names like "TF32", "BF16"

#### Pros
- The names are more informative: "tf32" says more than a generic "high".
- It is easier to extend to new algorithms such as `tf32x3`.

#### Cons
- "HIGHEST, HIGH, MEDIUM" indicated the relative precision between the algorithms. However, that relationship can be covered by additional documentation.

### We provide a layered structure for backends/operators ('f32' is short for 'fp32_precision')

### The following fp32 compute precisions can be set
- **"ieee"**: Not allowed to use any other internal computation data type.
- **"tf32"**: Allowed to use tf32 as the internal computation data type.
- **"bf16"**: Allowed to use bf16 as the internal computation data type.
- **"none"**: Precision is not set; the value can be overridden by the parent node.

### Overriding Precision Settings
A child node is overridden by its parent node if the child is set to the default ("none"). The current default settings are:

```
backend = generic, op = all, precision setting = none
backend = cuda, op = all, precision setting = none
backend = cuda, op = conv, precision setting = tf32
backend = cuda, op = rnn, precision setting = tf32
backend = cuda, op = matmul, precision setting = none
backend = mkldnn, op = all, precision setting = none
backend = mkldnn, op = conv, precision setting = none
backend = mkldnn, op = rnn, precision setting = none
backend = mkldnn, op = matmul, precision setting = none
```

- If the user sets `torch.backends.mkldnn.fp32_precision="bf16"`, its child nodes `torch.backends.mkldnn.matmul.fp32_precision` / `torch.backends.mkldnn.conv.fp32_precision` / `torch.backends.mkldnn.rnn.fp32_precision` will also be overridden to "bf16".
- If the user sets `torch.backends.fp32_precision="bf16"`, `torch.backends.mkldnn.fp32_precision` and its child nodes will also be overridden to "bf16".

### Backward Compatibility
Since the new API allows more fine-grained control, there will be some conflicts. For example, the previous `torch.backends.cudnn.allow_tf32` is not enough to represent a state with `torch.backends.cudnn.rnn.fp32_precision="ieee"` and `torch.backends.cudnn.conv.fp32_precision="tf32"`. Therefore, our goals for backward compatibility are:
- If the user only uses the previous APIs, they will work as before.
- If the user uses the **new** API to move the state to one that is **not representable** by the old API and then reads that state through the **old** API, we raise a RuntimeError and point the user to the documentation.

### Test Plan
```
python test/test_cuda.py -k test_fp32_precision_with_tf32
python test/test_cuda.py -k test_fp32_precision_with_float32_matmul_precision
python test/test_cuda.py -k test_invalid_status_for_legacy_api
python test/test_mkldnn.py -k test_mlkdnn_get_set
python test/test_mkldnn.py -k test_generic_precision
python test/test_mkldnn.py -k test_invalid
python test/test_mkldnn.py -k test_default_use_parent
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125888
Approved by: https://github.com/jgong5, https://github.com/albanD

Co-authored-by: Jiang, Yanbing <yanbing.jiang@intel.com>
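As an illustration of the layered design and backward-compatibility behavior described above, here is a minimal, hedged usage sketch. It assumes a PyTorch build that includes this PR; the attribute names are taken from the description, and the printed values are only indicative:

```
import torch

# Setting a backend node overrides its child nodes that are still "none".
torch.backends.mkldnn.fp32_precision = "bf16"
print(torch.backends.mkldnn.matmul.fp32_precision)  # expected "bf16" (inherited)
print(torch.backends.mkldnn.conv.fp32_precision)    # expected "bf16" (inherited)

# Individual operators can still be set directly.
torch.backends.mkldnn.matmul.fp32_precision = "ieee"

# A state that the legacy flag cannot represent: reading the old API is
# expected to raise a RuntimeError that points to the documentation.
torch.backends.cudnn.conv.fp32_precision = "tf32"
torch.backends.cudnn.rnn.fp32_precision = "ieee"
try:
    print(torch.backends.cudnn.allow_tf32)
except RuntimeError as err:
    print("legacy flag cannot represent this state:", err)
```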
# mypy: allow-untyped-defs
import sys
import types
from contextlib import contextmanager

import torch


# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change
# later.
__allow_nonbracketed_mutation_flag = True


def disable_global_flags():
    global __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = False


def flags_frozen():
    return not __allow_nonbracketed_mutation_flag


@contextmanager
def __allow_nonbracketed_mutation():
    global __allow_nonbracketed_mutation_flag
    old = __allow_nonbracketed_mutation_flag
    __allow_nonbracketed_mutation_flag = True
    try:
        yield
    finally:
        __allow_nonbracketed_mutation_flag = old


# Descriptor that exposes a backend flag as a module-level property backed by
# getter/setter functions; writes are rejected while global flags are frozen.
class ContextProp:
    def __init__(self, getter, setter):
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype):
        return self.getter()

    def __set__(self, obj, val):
        if not flags_frozen():
            self.setter(val)
        else:
            raise RuntimeError(
                f"not allowed to set {obj.__name__} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )


# ModuleType subclass that forwards missing attribute lookups to the wrapped
# original module.
class PropModule(types.ModuleType):
    def __init__(self, m, name):
        super().__init__(name)
        self.m = m

    def __getattr__(self, attr):
        return self.m.__getattribute__(attr)


# Proxy object exposing the fp32_precision setting of a (backend, op) pair.
class _FP32Precision:
    def __init__(self, backend, op):
        self.backend = backend
        self.op = op

    def __setattr__(self, name, value):
        if name == "fp32_precision":
            torch._C._set_fp32_precision_setter(self.backend, self.op, value)
        elif name in ("backend", "op"):
            super().__setattr__(name, value)
        else:
            raise AttributeError("Unknown attribute " + name)

    def __getattr__(self, name):
        if name == "fp32_precision":
            return torch._C._get_fp32_precision_getter(self.backend, self.op)
        else:
            raise AttributeError("Unknown attribute " + name)


# set_flags/flags operate on the generic ("generic", "all") precision node.
def set_flags(_fp32_precision="none"):
    orig_flags = (torch._C._get_fp32_precision_getter("generic", "all"),)
    if _fp32_precision is not None:
        torch._C._set_fp32_precision_setter("generic", "all", _fp32_precision)
    return orig_flags


@contextmanager
def flags(fp32_precision="none"):
    with __allow_nonbracketed_mutation():
        orig_flags = set_flags(fp32_precision)
    try:
        yield
    finally:
        with __allow_nonbracketed_mutation():
            set_flags(*orig_flags)


def _get_fp32_precision_getter(backend, op):
    def inner():
        return torch._C._get_fp32_precision_getter(backend, op)

    return inner


def _set_fp32_precision_setter(backend, op):
    def inner(precision):
        return torch._C._set_fp32_precision_setter(backend, op, precision)

    return inner


class GenericModule(PropModule):
    def __init__(self, m, name):
        super().__init__(m, name)

    fp32_precision = ContextProp(
        _get_fp32_precision_getter("generic", "all"),
        _set_fp32_precision_setter("generic", "all"),
    )


# Replace this module in sys.modules with a GenericModule instance so that
# fp32_precision behaves like a module-level property.
sys.modules[__name__] = GenericModule(sys.modules[__name__], __name__)

from torch.backends import (
    cpu as cpu,
    cuda as cuda,
    cudnn as cudnn,
    cusparselt as cusparselt,
    kleidiai as kleidiai,
    mha as mha,
    mkl as mkl,
    mkldnn as mkldnn,
    mps as mps,
    nnpack as nnpack,
    openmp as openmp,
    quantized as quantized,
)
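For reference, a minimal sketch (assuming a build that includes this PR) of how the `flags()` context manager and the module-level `fp32_precision` property defined in this file can be used; the printed values are illustrative:

```
import torch

# Temporarily set the generic ("generic", "all") precision node; the previous
# value is restored when the block exits.
with torch.backends.flags(fp32_precision="tf32"):
    print(torch.backends.fp32_precision)  # expected "tf32"
print(torch.backends.fp32_precision)      # restored, e.g. back to "none"
```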