mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/54915 TorchScript and torch.package have different mangling schemes. To avoid them interfering with each other, we should undo the torch.package mangling before processing anything with TorchScript (since TS independently makes sure that no names collide). Test Plan: Imported from OSS Reviewed By: SplitInfinity Differential Revision: D27410472 Pulled By: suo fbshipit-source-id: d1cc013c532d9abb7fb9615122bc465ded4785bb
120 lines
3.6 KiB
Python
120 lines
3.6 KiB
Python
"""JIT-related state
|
|
|
|
This module stores various pieces of Python-global state relating to the JIT.
|
|
|
|
This is not intended to be imported directly; please use the exposed
|
|
functionalities in `torch.jit`.
|
|
"""
|
|
import torch
|
|
import os
|
|
import weakref
|
|
|
|
class EnabledProxy:
    """Mutable holder for the "JIT enabled" flag.

    The flag lives on an object rather than in a bare bool so that every
    holder of the proxy sees later changes to ``enabled``.
    """

    def __init__(self):
        # The initial state comes from the PYTORCH_JIT environment variable
        # and defaults to enabled when the variable is not set.
        self.enabled = self.parse_env(
            "PYTORCH_JIT", True, "> Using PyTorch JIT", "> PyTorch JIT DISABLED"
        )

    def parse_env(self, name, default, true_message, false_message):
        """Interpret the environment variable ``name`` as a boolean.

        Args:
            name: environment variable to read.
            default: value returned when the variable is not set.
            true_message: text printed when the variable is exactly ``"1v"``.
            false_message: text printed when the variable is exactly ``"0v"``.

        Returns:
            ``default`` if the variable is unset; ``True`` for ``1``/``true``/
            ``yes`` and ``False`` for ``0``/``false``/``no`` (compared
            case-insensitively); ``True``/``False`` for ``1v``/``0v`` after
            printing the matching message.

        Raises:
            ValueError: if the variable holds any other value.
        """
        raw = os.environ.get(name)
        if raw is None:
            return default

        lowered = raw.lower()
        if lowered in ("1", "true", "yes"):
            return True
        if lowered in ("0", "false", "no"):
            return False

        # The verbose variants are matched against the raw value, so they
        # are case-sensitive, unlike the plain settings above.
        verbose_settings = {
            "1v": (true_message, True),
            "0v": (false_message, False),
        }
        if raw in verbose_settings:
            message, flag = verbose_settings[raw]
            print(message)
            return flag

        raise ValueError(f"Unknown setting of {name}. Try using 0 or 1.")

    def __bool__(self):
        """Truthiness of the proxy is the current value of ``enabled``."""
        return self.enabled
|
|
|
|
|
|
# Module-level flag object, built at import time (so PYTORCH_JIT is read
# then); disable() and enable() flip its `enabled` attribute.
_enabled = EnabledProxy()
|
|
|
|
|
|
def disable():
    """Set the module-level ``_enabled`` flag to ``False``."""
    _enabled.enabled = False
|
|
|
|
|
|
def enable():
    """Set the module-level ``_enabled`` flag to ``True``."""
    _enabled.enabled = True
|
|
|
|
|
|
# The Python CompilationUnit. All functions and modules defined in Python will
|
|
# live in here. It's defined in Python because doing it in cpp creates static
|
|
# destruction order issues.
|
|
# Single module-level instance; the cache helpers in this module resolve
# qualified names through its find_function().
_python_cu = torch._C.CompilationUnit()
|
|
|
|
|
|
# python class => ScriptClass mapping
|
|
# Python class -> script class, filled in by _add_script_class().
_script_classes = {}
|
|
# Script class qualified name -> Python class, filled in by _add_script_class().
_name_to_pyclass = {}
|
|
|
|
|
|
def _add_script_class(python_class, script_class):
    """Register ``script_class`` as the script class for ``python_class``.

    Two mappings are recorded: Python class -> script class, and the script
    class's qualified name -> Python class.
    """
    _script_classes[python_class] = script_class
    qualified_name = script_class.qualified_name()
    _name_to_pyclass[qualified_name] = python_class
|
|
|
|
|
|
def _get_script_class(python_class):
    """Return the script class registered for ``python_class``, or ``None``.

    If ``python_class`` has a non-``None`` ``_jit_override_qualname``
    attribute, the lookup uses the Python class registered under that
    qualified name instead of ``python_class`` itself.
    """
    qualname_override = getattr(python_class, "_jit_override_qualname", None)
    if qualname_override is None:
        lookup_key = python_class
    else:
        lookup_key = _get_python_class(qualname_override)
    return _script_classes.get(lookup_key)
|
|
|
|
|
|
def _get_python_class(qualified_name):
    """Return the Python class registered under ``qualified_name``, or ``None``."""
    return _name_to_pyclass.get(qualified_name)
|
|
|
|
|
|
def _clear_class_state():
    """Empty both script-class registries in place."""
    for registry in (_script_classes, _name_to_pyclass):
        registry.clear()
|
|
|
|
|
|
# Caching: we currently cache compilation of free functions and overloaded functions.
|
|
# To cache free functions we hold a weak ref to the function object and
|
|
# map to the compiled fn's qualified name.
|
|
# To cache overloaded functions we hold a weak ref to the function obj and
|
|
# map to all of its overloaded compiled fns.
|
|
# In the future we could consider caching more types of objects so that
|
|
# aliasing is preserved across separate compilations of the same object.
|
|
|
|
# Weakly-keyed map from a cache key to the qualified name stored by
# _set_jit_function_cache().
_jit_caching_layer: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
|
# Weakly-keyed map from a cache key to the list of qualified names stored by
# _set_jit_overload_cache().
_jit_function_overload_caching: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
|
|
|
|
def _try_get_jit_cached_overloads(key):
    """Return the cached compiled overloads for ``key``, or ``None``.

    The overload cache stores qualified names; each one is resolved through
    ``_python_cu.find_function``. A missing or empty entry yields ``None``.
    """
    cached_names = _jit_function_overload_caching.get(key)
    if not cached_names:
        return None
    return [_python_cu.find_function(name) for name in cached_names]
|
|
|
|
def _set_jit_overload_cache(key, compiled_fns):
    """Cache the qualified names of ``compiled_fns`` under ``key``."""
    qualified_names = [compiled.qualified_name for compiled in compiled_fns]
    _jit_function_overload_caching[key] = qualified_names
|
|
|
|
def _try_get_jit_cached_function(key):
    """Return the cached compiled function for ``key``, or ``None``.

    The cache is bypassed when ``key.__disable_jit_function_caching__`` is
    exactly ``True``. Otherwise the stored qualified name, if any, is
    resolved through ``_python_cu.find_function``.
    """
    caching_disabled = getattr(key, "__disable_jit_function_caching__", False)
    if caching_disabled is True:
        return None

    cached_name = _jit_caching_layer.get(key)
    if not cached_name:
        return None
    return _python_cu.find_function(cached_name)
|
|
|
|
def _set_jit_function_cache(key, value):
    """Cache the qualified name of the compiled function ``value`` under ``key``.

    ``value`` must be a ``torch.jit.ScriptFunction``.
    """
    # only free functions currently supported
    assert isinstance(value, torch.jit.ScriptFunction)
    qualified_name = value.qualified_name
    _jit_caching_layer[key] = qualified_name
|