mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
More robust check of whether a class is defined in torch (#64083)
Summary: This would prevent bugs for classes that 1) Is defined in a module that happens to start with `torch`, say `torchvision` 2) Is defined in torch but with an import alias like `import torch as th` Pull Request resolved: https://github.com/pytorch/pytorch/pull/64083 Reviewed By: soulitzer Differential Revision: D30598369 Pulled By: gmagogsfm fbshipit-source-id: 9d3a7135737b2339c9bd32195e4e69a9c07549d4
This commit is contained in:
parent
f2c47cf4db
commit
ad8eddbd80
|
|
@ -1,6 +1,10 @@
|
|||
|
||||
import torch
|
||||
|
||||
import inspect
|
||||
import typing
|
||||
import pathlib
|
||||
import sys
|
||||
from typing import Optional, Iterable, List, Dict
|
||||
from collections import defaultdict
|
||||
from types import CodeType
|
||||
|
|
@ -15,6 +19,18 @@ try:
|
|||
except ImportError:
|
||||
_IS_MONKEYTYPE_INSTALLED = False
|
||||
|
||||
# Checks whether a class is defind in `torch.*` modules
|
||||
def is_torch_native_class(cls):
|
||||
if not hasattr(cls, '__module__'):
|
||||
return False
|
||||
|
||||
parent_modules = cls.__module__.split('.')
|
||||
if not parent_modules:
|
||||
return False
|
||||
|
||||
root_module = sys.modules.get(parent_modules[0])
|
||||
return root_module is torch
|
||||
|
||||
def get_type(type):
|
||||
"""
|
||||
Helper function which converts the given type to a torchScript acceptable format.
|
||||
|
|
@ -28,7 +44,7 @@ def get_type(type):
|
|||
# typing.List is not accepted by TorchScript.
|
||||
type_to_string = str(type)
|
||||
return type_to_string.replace(type.__module__ + '.', '')
|
||||
elif type.__module__.startswith('torch'):
|
||||
elif is_torch_native_class(type):
|
||||
# If the type is a subtype of torch module, then TorchScript expects a fully qualified name
|
||||
# for the type which is obtained by combining the module name and type name.
|
||||
return type.__module__ + '.' + type.__name__
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user