pytorch/torch/torch_version.py
hauntsaninja e9c64168d9 Import packaging.version in torch_version, if available (#71902)
Summary:
Resolves https://github.com/pytorch/pytorch/issues/71280

We used to use `from pkg_resources import packaging`. To recap, this has
three potential problems:
1) `pkg_resources` is a really slow import
2) We have an undeclared runtime dependency on `setuptools`
3) We're relying on `pkg_resources`'s secret vendored copy of
   `packaging`. This is obviously not part of the public API of
   `pkg_resources`.

In https://github.com/pytorch/pytorch/issues/71345 this was made a lazy import, which is great! It means we don't
run into these problems as long as users don't use `torch.__version__`.
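
A lazy import can be pictured as a small proxy that delays the import until the wrapped class is first used. Below is a minimal sketch in the spirit of `_LazyImport` in the file below, not its actual implementation. `LazyAttr` is a hypothetical name, and the stdlib `fractions` module stands in for `packaging.version` so the sketch runs without extra packages.

```python
import importlib
from typing import Any


class LazyAttr:
    """Hypothetical proxy: imports `module_name` only when the attribute is first used."""

    def __init__(self, module_name: str, attr_name: str) -> None:
        self._module_name = module_name
        self._attr_name = attr_name

    def get(self) -> Any:
        # The import happens here, not at construction time; after the first
        # call, importlib returns the cached module from sys.modules.
        module = importlib.import_module(self._module_name)
        return getattr(module, self._attr_name)

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        return self.get()(*args, **kwargs)

    def __instancecheck__(self, obj: Any) -> bool:
        # isinstance(x, proxy) looks up __instancecheck__ on type(proxy),
        # so the proxy can stand in for the real class in isinstance checks.
        return isinstance(obj, self.get())


Fraction = LazyAttr("fractions", "Fraction")  # no import is triggered here
half = Fraction(1, 2)  # the import of "fractions" is triggered by this call
print(half, isinstance(half, Fraction))
```

Only code paths that actually call or `isinstance`-check the proxy pay for the import, which is why users who never touch `torch.__version__` avoid the cost entirely.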

This change further addresses problems 1 and 3 by importing `packaging`
directly when it is installed, and falling back to the vendored copy in
`pkg_resources` only when it is not.
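
The "prefer one import, fall back to another" strategy can be sketched generically. This is an approximation: the real code imports an *attribute* (`packaging`) from `pkg_resources` rather than trying a second module name. `import_first_available` is a hypothetical helper, and stdlib module names stand in for the real ones.

```python
import importlib
from types import ModuleType


def import_first_available(*module_names: str) -> ModuleType:
    """Hypothetical helper: return the first module in `module_names` that imports.

    Same idea as "prefer packaging, fall back to the copy in pkg_resources",
    generalized to module names so the sketch runs without either package.
    """
    last_error = None
    for name in module_names:
        try:
            return importlib.import_module(name)
        except ImportError as exc:
            last_error = exc  # remember the failure and try the next candidate
    raise ImportError(f"none of {module_names!r} could be imported") from last_error


# Preferred module first, fallback second; stdlib names stand in for the real ones.
mod = import_first_available("no_such_module_for_demo", "json")
print(mod.__name__)
```

Catching `ImportError` also covers `ModuleNotFoundError`, its subclass, which is what a missing package raises.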

Benchmark of the import-time difference in a virtual environment with a
couple hundred packages installed:
```
λ hyperfine -w 2 'python -c "from pkg_resources import packaging"' 'python -c "import packaging.version"'
Benchmark 1: python -c "from pkg_resources import packaging"
  Time (mean ± σ):     706.7 ms ±  77.1 ms    [User: 266.5 ms, System: 156.8 ms]
  Range (min … max):   627.9 ms … 853.2 ms    10 runs

Benchmark 2: python -c "import packaging.version"
  Time (mean ± σ):      53.8 ms ±   8.5 ms    [User: 34.8 ms, System: 14.4 ms]
  Range (min … max):    46.3 ms …  72.3 ms    53 runs
  'python -c "import packaging.version"' ran
   13.14 ± 2.52 times faster than 'python -c "from pkg_resources import packaging"'
```
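
For readers without `hyperfine`, a rough Python sketch of the same measurement follows. Each run starts a fresh interpreter so that `sys.modules` caching does not hide the import cost. `time_statement` is a hypothetical helper. It mimics `-w 2` with untimed warmup runs but has no outlier detection, so its numbers are only indicative.

```python
import statistics
import subprocess
import sys
import time


def time_statement(statement: str, runs: int = 5, warmup: int = 2) -> float:
    """Mean wall-clock seconds for `python -c statement`, each run in a fresh interpreter.

    A rough stand-in for hyperfine: `warmup` untimed runs (like `-w 2`), then
    `runs` timed ones; no outlier detection.
    """
    samples = []
    for i in range(warmup + runs):
        start = time.perf_counter()
        subprocess.run([sys.executable, "-c", statement], check=True)
        elapsed = time.perf_counter() - start
        if i >= warmup:  # discard warmup runs
            samples.append(elapsed)
    return statistics.mean(samples)


# Stdlib statements keep the sketch runnable anywhere; substitute
# "from pkg_resources import packaging" and "import packaging.version" to compare those.
baseline = time_statement("pass")
with_json = time_statement("import json")
print(f"pass: {baseline * 1000:.1f} ms, import json: {with_json * 1000:.1f} ms")
```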

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71902

Reviewed By: mikaylagawarecki

Differential Revision: D34343145

Pulled By: malfet

fbshipit-source-id: a6bd7ecf0cbb6b5c20ab18a22576aa2df9eb3324
(cherry picked from commit 0a249044c8)
2022-02-22 21:30:14 +00:00


from typing import Any, Iterable
from .version import __version__ as internal_version


class _LazyImport:
    """Wraps around classes lazily imported from packaging.version

    Output of the function v in the following snippets is identical:
       from packaging.version import Version
       def v():
           return Version('1.2.3')
    and
       Version = _LazyImport('Version')
       def v():
           return Version('1.2.3')
    The difference is that in the latter example the import
    does not happen until v is called
    """
    def __init__(self, cls_name: str) -> None:
        self._cls_name = cls_name

    def get_cls(self):
        try:
            import packaging.version  # type: ignore[import]
        except ImportError:
            # If packaging isn't installed, try and use the vendored copy
            # in pkg_resources
            from pkg_resources import packaging  # type: ignore[attr-defined]
        return getattr(packaging.version, self._cls_name)

    def __call__(self, *args, **kwargs):
        return self.get_cls()(*args, **kwargs)

    def __instancecheck__(self, obj):
        return isinstance(obj, self.get_cls())


Version = _LazyImport("Version")
InvalidVersion = _LazyImport("InvalidVersion")


class TorchVersion(str):
    """A string with magic powers to compare to both Version and iterables!
    Prior to 1.10.0 torch.__version__ was stored as a str, so many users
    compared against torch.__version__ as if it were a str. In order not to
    break them we have TorchVersion, which masquerades as a str while also
    being able to compare against both packaging.version.Version objects
    and tuples of values, e.g. (1, 2, 1)
    Examples:
        Comparing a TorchVersion object to a Version object
            TorchVersion('1.10.0a') > Version('1.10.0a')
        Comparing a TorchVersion object to a Tuple object
            TorchVersion('1.10.0a') > (1, 2)    # 1.2
            TorchVersion('1.10.0a') > (1, 2, 1) # 1.2.1
        Comparing a TorchVersion object against a string
            TorchVersion('1.10.0a') > '1.2'
            TorchVersion('1.10.0a') > '1.2.1'
    """
    # fully qualified type names here to appease mypy
    def _convert_to_version(self, inp: Any) -> Any:
        if isinstance(inp, Version.get_cls()):
            return inp
        elif isinstance(inp, str):
            return Version(inp)
        elif isinstance(inp, Iterable):
            # Ideally this should work for most cases by joining the items of
            # the version tuple with dots, assuming the tuple looks like
            # (MAJOR, MINOR, ?PATCH)
            # Examples:
            #   * (1,)        -> Version("1")
            #   * (1, 20)     -> Version("1.20")
            #   * (1, 20, 1)  -> Version("1.20.1")
            return Version('.'.join((str(item) for item in inp)))
        else:
            raise InvalidVersion(inp)

    def _cmp_wrapper(self, cmp: Any, method: str) -> bool:
        try:
            return getattr(Version(self), method)(self._convert_to_version(cmp))
        except BaseException as e:
            if not isinstance(e, InvalidVersion.get_cls()):
                raise
            # Fall back to regular string comparison if dealing with an invalid
            # version like 'parrot'
            return getattr(super(), method)(cmp)


for cmp_method in ["__gt__", "__lt__", "__eq__", "__ge__", "__le__"]:
    setattr(TorchVersion, cmp_method, lambda x, y, method=cmp_method: x._cmp_wrapper(y, method))

__version__ = TorchVersion(internal_version)
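
To see the comparison behaviour of `TorchVersion` without a PyTorch build or the `packaging` package, here is a simplified stand-in. It swaps `packaging.version.Version` for a naive comparison of the leading dotted integers, so it ignores pre-release and local suffixes (e.g. the `a` in `1.10.0a`) that PEP 440 ordering would take into account. It also converts tuples straight to ints instead of joining them into a string. `SimpleVersionStr` and `_release_key` are hypothetical names.

```python
import re
from collections.abc import Iterable
from typing import Any, Optional, Tuple

# Leading dotted-integer run, e.g. "1.10.0" out of "1.10.0a" or "1.10.0+cu113"
_RELEASE = re.compile(r"\d+(?:\.\d+)*")


def _release_key(value: Any) -> Optional[Tuple[int, ...]]:
    """Turn a str or iterable into an int tuple, or None if it is not version-like."""
    if isinstance(value, str):  # checked before Iterable, since str is iterable
        match = _RELEASE.match(value)
        if match is None:
            return None  # e.g. 'parrot'
        parts = [int(p) for p in match.group().split(".")]
    elif isinstance(value, Iterable):
        try:
            parts = [int(item) for item in value]
        except (TypeError, ValueError):
            return None
    else:
        return None
    while parts and parts[-1] == 0:
        parts.pop()  # drop trailing zeros so (1, 2) == (1, 2, 0)
    return tuple(parts)


class SimpleVersionStr(str):
    """Hypothetical, simplified TorchVersion stand-in: release numbers only, not PEP 440."""

    def _cmp(self, other: Any, op: str) -> Any:
        mine, theirs = _release_key(str(self)), _release_key(other)
        if mine is None or theirs is None:
            # Same fallback idea as TorchVersion: plain string comparison
            return getattr(str, op)(self, other)
        return getattr(tuple, op)(mine, theirs)


# Attach the operators after class creation, as the loop above does for TorchVersion
for _op in ("__gt__", "__lt__", "__eq__", "__ge__", "__le__"):
    setattr(SimpleVersionStr, _op, lambda self, other, op=_op: self._cmp(other, op))


v = SimpleVersionStr("1.10.0")
print("1.10.0" > "1.2")  # False: plain str comparison is lexicographic
print(v > "1.2", v > (1, 2, 1), v == (1, 10))  # True True True
```

The first `print` shows why a plain `str` is a poor version type. Like the loop above, the sketch does not override `__ne__`, so `!=` still falls through to `str.__ne__` and compares the raw strings.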