mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/45940 **Summary** In `try_ann_to_type`, if an annotation has an attribute named `__torch_script_class__`, it is assumed to be a TorchScript class that has already been scripted. However, if it is a class that extends another class, this code path causes a crash because it looks up the JIT type for the class by name in the compilation unit. This JIT type obviously cannot exist because inheritance is not supported. This commit fixes this by looking up the qualified name of a class in torch.jit._state._script_class in order to ascertain whether it has already been scripted (instead of looking for a `__torch_script_class__` attribute on the class object. **Test Plan** This commit adds a unit test consisting of the code sample from the issue that reported this problem. **Fixes** This commit fixes #45860. Test Plan: Imported from OSS Reviewed By: anjali411 Differential Revision: D24310027 Pulled By: SplitInfinity fbshipit-source-id: 9f8225f3316fd50738d98e3544bf5562b16425b6
109 lines
3.2 KiB
Python
109 lines
3.2 KiB
Python
"""JIT-related state
|
|
|
|
This module stores various pieces of Python-global state relating to the JIT.
|
|
|
|
This is not intended to be imported directly; please the exposed
|
|
functionalities in `torch.jit`.
|
|
"""
|
|
import torch
|
|
import os
|
|
import weakref
|
|
|
|
class EnabledProxy:
|
|
"""Stores whether the JIT is enabled or not.
|
|
|
|
This is just a wrapper for a bool, so that we get reference semantics
|
|
"""
|
|
|
|
def __init__(self):
|
|
self.enabled = self.parse_env(
|
|
"PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
|
|
)
|
|
|
|
def parse_env(self, name, default, true_message, false_message):
|
|
value = os.environ.get(name)
|
|
if value is None:
|
|
return default
|
|
if value.lower() in {"1", "true", "yes"}:
|
|
return True
|
|
elif value.lower() in {"0", "false", "no"}:
|
|
return False
|
|
if value == "1v":
|
|
print(true_message)
|
|
return True
|
|
elif value == "0v":
|
|
print(false_message)
|
|
return False
|
|
raise ValueError("Unknown setting of {}. Try using 0 or 1.".format(name))
|
|
|
|
def __bool__(self):
|
|
return self.enabled
|
|
|
|
|
|
_enabled = EnabledProxy()
|
|
|
|
|
|
def disable():
|
|
_enabled.enabled = False
|
|
|
|
|
|
def enable():
|
|
_enabled.enabled = True
|
|
|
|
|
|
# The Python CompilationUnit. All functions and modules defined in Python will
|
|
# live in here. It's defined in Python because doing in cpp creates static
|
|
# destruction order issues.
|
|
_python_cu = torch._C.CompilationUnit()
|
|
|
|
|
|
# qualified_name => ScriptClass mapping
|
|
_script_classes = {}
|
|
|
|
def _add_script_class(cls, name):
|
|
global _script_classes
|
|
_script_classes[name] = cls
|
|
|
|
|
|
def _get_script_class(name):
|
|
global _script_classes
|
|
if name not in _script_classes:
|
|
return None
|
|
return _script_classes[name]
|
|
|
|
|
|
# Caching: we currently cache compilation of free functions and overloaded functions.
|
|
# To cache free functions we hold a weak ref to the function object and
|
|
# map to the compiled fn's qualified name.
|
|
# To cache overloaded functions we hold a weak ref to the function obj and
|
|
# map to all of its overloaded compiled fns.
|
|
# In the future we could consider caching more types of objects so that
|
|
# aliasing is preserved across separate compilations of the same object.
|
|
|
|
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
|
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
|
|
|
def _try_get_jit_cached_overloads(key):
|
|
qual_names = _jit_function_overload_caching.get(key, None)
|
|
if qual_names:
|
|
return [_python_cu.find_function(qual_name) for qual_name in qual_names]
|
|
else:
|
|
return None
|
|
|
|
def _set_jit_overload_cache(key, compiled_fns):
|
|
_jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
|
|
|
|
def _try_get_jit_cached_function(key):
|
|
if getattr(key, "__disable_jit_function_caching__", False) is True:
|
|
return None
|
|
qual_name = _jit_caching_layer.get(key, None)
|
|
if qual_name:
|
|
return _python_cu.find_function(qual_name)
|
|
else:
|
|
return None
|
|
|
|
def _set_jit_function_cache(key, value):
|
|
# only free functions currently supported
|
|
assert isinstance(value, torch.jit.ScriptFunction)
|
|
_jit_caching_layer[key] = value.qualified_name
|