mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Pull Request resolved: https://github.com/pytorch/pytorch/pull/155812 Approved by: https://github.com/Skylion007 ghstack dependencies: #155782, #155783
28 lines
662 B
Python
28 lines
662 B
Python
# mypy: allow-untyped-defs
|
|
from torch._C import (
|
|
_get_backcompat_broadcast_warn,
|
|
_get_backcompat_keepdim_warn,
|
|
_set_backcompat_broadcast_warn,
|
|
_set_backcompat_keepdim_warn,
|
|
)
|
|
|
|
|
|
class Warning:
    """Expose a backward-compatibility warning toggle as an ``enabled`` attribute.

    Wraps a pair of callables -- ``setter(value)`` and ``getter()`` -- so the
    flag can be read or written either through the ``enabled`` property or via
    the explicit ``get_enabled`` / ``set_enabled`` methods. This class stores
    no state of its own beyond the two callables.

    NOTE(review): this class name shadows the builtin ``Warning`` exception
    inside this module; renaming it could break external references to it,
    so it is left unchanged.

    Args:
        setter: callable invoked with the new value whenever the flag is
            written (presumably a bool -- confirm against the torch._C hook).
        getter: zero-argument callable returning the current value.
    """

    def __init__(self, setter, getter):
        self.setter, self.getter = setter, getter

    def set_enabled(self, value):
        """Forward ``value`` to the wrapped setter; returns ``None``."""
        apply = self.setter
        apply(value)

    def get_enabled(self):
        """Return whatever the wrapped getter currently reports."""
        query = self.getter
        return query()

    # Attribute-style access: ``obj.enabled`` reads, ``obj.enabled = flag`` writes.
    enabled = property(fget=get_enabled, fset=set_enabled)
|
|
|
|
|
|
# Toggle wrapping torch._C's _set/_get_backcompat_broadcast_warn hooks.
broadcast_warning = Warning(
    setter=_set_backcompat_broadcast_warn,
    getter=_get_backcompat_broadcast_warn,
)

# Toggle wrapping torch._C's _set/_get_backcompat_keepdim_warn hooks.
keepdim_warning = Warning(
    setter=_set_backcompat_keepdim_warn,
    getter=_get_backcompat_keepdim_warn,
)
|