pytorch/torch/backends/__init__.py
Nikhil Gupta 41b38f755c Revert "Reverting the PR adding Kleidiai-based int4 kernels (#145392)" (#145505)
https://github.com/pytorch/pytorch/pull/134124 was reverted by https://github.com/pytorch/pytorch/pull/145392 due to KleidiAI clone issue.

1. This reverts commit 0940eb6d44 (https://github.com/pytorch/pytorch/pull/145392) and fixes the KleidiAI mirror issue.
2. KleidiAI is now cloned from the GitHub mirror instead of Arm GitLab.

Change-Id: I7d6eee7214cd117d3057d615936fcc3ee6052fa2

Fixes https://github.com/pytorch/pytorch/issues/145273

Pull Request resolved: https://github.com/pytorch/pytorch/pull/145505
Approved by: https://github.com/malfet
2025-01-23 18:50:59 +00:00

74 lines
1.7 KiB
Python

# mypy: allow-untyped-defs
import types
from contextlib import contextmanager
# The idea for this parameter is that we forbid bare assignment
# to torch.backends.<cudnn|mkldnn>.enabled and friends when running our
# test suite, where it's very easy to forget to undo the change
# later.
#
# True means bare assignment through ContextProp descriptors is allowed.
# disable_global_flags() sets it to False, flags_frozen() reports its
# negation, and __allow_nonbracketed_mutation() temporarily forces it back
# to True, restoring the previous value on exit.
__allow_nonbracketed_mutation_flag: bool = True
def disable_global_flags():
    """Forbid bare assignment to backend flags from now on.

    After this call ``flags_frozen()`` returns True, so writes through a
    ``ContextProp`` raise ``RuntimeError`` unless they happen inside
    ``__allow_nonbracketed_mutation()``.
    """
    global __allow_nonbracketed_mutation_flag
    # Clearing the "mutation allowed" switch is what freezes the flags.
    __allow_nonbracketed_mutation_flag = False
def flags_frozen():
    """Return True when bare assignment to backend flags is currently forbidden."""
    return False if __allow_nonbracketed_mutation_flag else True
@contextmanager
def __allow_nonbracketed_mutation():
    """Temporarily permit bare assignment to backend flags.

    Forces the mutation flag to True for the duration of the ``with`` block
    and restores whatever value it held on entry, even if the block raises.
    """
    global __allow_nonbracketed_mutation_flag
    # Capture the current value and switch mutation on in one step.
    saved_state, __allow_nonbracketed_mutation_flag = (
        __allow_nonbracketed_mutation_flag,
        True,
    )
    try:
        yield
    finally:
        # Restore (not reset) so nesting and a prior frozen state both survive.
        __allow_nonbracketed_mutation_flag = saved_state
class ContextProp:
    """Data descriptor that routes attribute reads and writes to callables.

    Reads return ``getter()``. Writes call ``setter(val)`` unless
    ``flags_frozen()`` is True, in which case a ``RuntimeError`` is raised
    directing the caller to the ``flags()`` context manager.
    """

    def __init__(self, getter, setter):
        # getter: zero-argument callable producing the current value.
        # setter: one-argument callable applying a new value.
        self.getter = getter
        self.setter = setter

    def __get__(self, obj, objtype=None):
        # ``objtype`` is optional per the descriptor protocol, so the
        # one-argument form ``prop.__get__(obj)`` works too. The value is the
        # same whether accessed on an instance or on the owning class.
        return self.getter()

    def __set__(self, obj, val):
        if flags_frozen():
            # Owners are presumably module objects (PropModule subclasses),
            # which carry ``__name__``; fall back to the owner's type name so
            # any other owner still gets the intended RuntimeError rather
            # than an AttributeError while building the message.
            owner_name = getattr(obj, "__name__", type(obj).__name__)
            raise RuntimeError(
                f"not allowed to set {owner_name} flags "
                "after disable_global_flags; please use flags() context manager instead"
            )
        self.setter(val)
class PropModule(types.ModuleType):
    """Module object that forwards unresolved attribute lookups to ``m``.

    Attributes found by normal lookup on the instance or its class (for
    example descriptors declared on a subclass) are used directly; anything
    else is fetched from the wrapped object ``m``.
    """

    def __init__(self, m, name):
        super().__init__(name)
        # The wrapped object (presumably the original module) to delegate to.
        self.m = m

    def __getattr__(self, attr):
        # __getattr__ runs only after normal lookup has failed. If ``m`` itself
        # is missing (e.g. an instance built via __new__ without __init__),
        # evaluating ``self.m`` below would re-enter this method endlessly and
        # end in RecursionError; report a plain AttributeError instead.
        if attr == "m":
            raise AttributeError(
                f"{type(self).__name__!r} object has no attribute 'm'"
            )
        return self.m.__getattribute__(attr)
from torch.backends import (
cpu as cpu,
cuda as cuda,
cudnn as cudnn,
cusparselt as cusparselt,
kleidiai as kleidiai,
mha as mha,
mkl as mkl,
mkldnn as mkldnn,
mps as mps,
nnpack as nnpack,
openmp as openmp,
quantized as quantized,
)