pytorch/torch/jit/_state.py
Meghan Lele 75bf5f2b59 [JIT] Improve class type annotation inference (#45940)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/45940

**Summary**
In `try_ann_to_type`, if an annotation has an attribute named
`__torch_script_class__`, it is assumed to be a TorchScript class that
has already been scripted. However, if it is a class that extends
another class, this code path causes a crash because it looks up the
JIT type for the class by name in the compilation unit. This JIT type
obviously cannot exist because inheritance is not supported.

This commit fixes this by looking up the qualified name of a class
in torch.jit._state._script_classes in order to ascertain whether it has
already been scripted (instead of looking for a `__torch_script_class__`
attribute on the class object).

**Test Plan**
This commit adds a unit test consisting of the code sample from the
issue that reported this problem.

**Fixes**
This commit fixes #45860.

Test Plan: Imported from OSS

Reviewed By: anjali411

Differential Revision: D24310027

Pulled By: SplitInfinity

fbshipit-source-id: 9f8225f3316fd50738d98e3544bf5562b16425b6
2020-10-14 23:28:47 -07:00

109 lines
3.2 KiB
Python

"""JIT-related state
This module stores various pieces of Python-global state relating to the JIT.
This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""
import torch
import os
import weakref
class EnabledProxy:
    """Stores whether the JIT is enabled or not.

    This is just a wrapper for a bool, so that we get reference semantics:
    the flag lives in the ``enabled`` attribute and can be flipped in place.
    """

    def __init__(self):
        # The initial state comes from the PYTORCH_JIT environment variable
        # and falls back to True (enabled) when the variable is unset.
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Interpret the environment variable ``name`` as an on/off switch.

        Args:
            name: Name of the environment variable to read.
            default: Value returned unchanged when the variable is not set.
            true_message: Text printed when the variable is exactly ``"1v"``.
            false_message: Text printed when the variable is exactly ``"0v"``.

        Returns:
            ``default`` if the variable is unset; ``True`` for ``1``/``true``/
            ``yes`` (any letter case) or ``1v``; ``False`` for ``0``/``false``/
            ``no`` (any letter case) or ``0v``.

        Raises:
            ValueError: If the variable is set to any other value.
        """
        value = os.environ.get(name)
        if value is None:
            return default

        # Lower-case once; the plain spellings are matched case-insensitively.
        lowered = value.lower()
        if lowered in {"1", "true", "yes"}:
            return True
        if lowered in {"0", "false", "no"}:
            return False

        # The verbose spellings are compared against the raw value, so they
        # are case-sensitive ("1V" is rejected below).
        if value == "1v":
            print(true_message)
            return True
        if value == "0v":
            print(false_message)
            return False
        raise ValueError("Unknown setting of {}. Try using 0 or 1.".format(name))

    def __bool__(self):
        # Lets the proxy itself be used in a boolean context.
        return self.enabled
# Module-wide JIT switch. It is created at import time, so its initial state
# comes from the PYTORCH_JIT environment variable (see EnabledProxy.__init__).
_enabled = EnabledProxy()
def disable():
    """Turn the JIT off by clearing the flag on the module-level ``_enabled`` proxy."""
    _enabled.enabled = False
def enable():
    """Turn the JIT on by setting the flag on the module-level ``_enabled`` proxy."""
    _enabled.enabled = True
# The Python CompilationUnit. All functions and modules defined in Python will
# live in here. It's defined in Python because defining it in C++ creates
# static destruction order issues.
_python_cu = torch._C.CompilationUnit()
# qualified_name => ScriptClass mapping.
# Written by _add_script_class and read by _get_script_class.
_script_classes = {}
def _add_script_class(cls, name):
    """Record ``cls`` in ``_script_classes`` under the key ``name``.

    Any entry already stored under ``name`` is overwritten.
    """
    # No ``global`` statement is needed: the dict is mutated in place and the
    # module-level name is never rebound.
    _script_classes[name] = cls
def _get_script_class(name):
    """Return the entry stored in ``_script_classes`` under ``name``.

    Returns ``None`` when nothing has been registered under ``name``.
    """
    # ``dict.get`` does the membership test and the lookup in one step and
    # already yields None for a missing key.
    return _script_classes.get(name)
# Caching: we currently cache compilation of free functions and overloaded functions.
# To cache free functions we hold a weak ref to the function object and
# map it to the compiled fn's qualified name.
# To cache overloaded functions we hold a weak ref to the function object and
# map it to the qualified names of all of its overloaded compiled fns.
# In the future we could consider caching more types of objects so that
# aliasing is preserved across separate compilations of the same object.
# Both caches are WeakKeyDictionary instances, so an entry does not keep its
# key alive.
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
def _try_get_jit_cached_overloads(key):
    """Look up the compiled overloads cached for ``key``.

    Returns:
        A list with the result of ``_python_cu.find_function`` for every
        qualified name cached under ``key``, or ``None`` when ``key`` has no
        cache entry or its cached list of names is empty.
    """
    qual_names = _jit_function_overload_caching.get(key)
    if not qual_names:
        # Covers both "never cached" and "cached with no overloads".
        return None
    return [_python_cu.find_function(qual_name) for qual_name in qual_names]
def _set_jit_overload_cache(key, compiled_fns):
    """Cache the overloads compiled for ``key``.

    Only the ``qualified_name`` of each item in ``compiled_fns`` is stored,
    not the compiled function objects themselves.
    """
    _jit_function_overload_caching[key] = [fn.qualified_name for fn in compiled_fns]
def _try_get_jit_cached_function(key):
    """Look up the compiled function cached for ``key``.

    Returns:
        The result of ``_python_cu.find_function`` for the qualified name
        cached under ``key``, or ``None`` when caching is disabled for
        ``key`` or no (non-empty) name is cached for it.
    """
    # A key can opt out of caching by carrying this attribute set to True
    # (identity check, so merely truthy values do not opt out).
    if getattr(key, "__disable_jit_function_caching__", False) is True:
        return None
    qual_name = _jit_caching_layer.get(key)
    if not qual_name:
        return None
    return _python_cu.find_function(qual_name)
def _set_jit_function_cache(key, value):
    """Cache the compiled function ``value`` for ``key``.

    Only ``value.qualified_name`` is stored in ``_jit_caching_layer``.

    Raises:
        AssertionError: If ``value`` is not a ``torch.jit.ScriptFunction``.
    """
    # only free functions currently supported
    assert isinstance(value, torch.jit.ScriptFunction)
    _jit_caching_layer[key] = value.qualified_name