mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Making Numpy depedency in Local Tensor optional to fix broken Torchao CI (#165938)
In recent change LocalTensor introduced dependency on Numpy and has broken Torchao CI. This dependency cna be made optional and required only when Local Tensor is used. Pull Request resolved: https://github.com/pytorch/pytorch/pull/165938 Approved by: https://github.com/atalman
This commit is contained in:
parent
4a6cf0a93e
commit
0b1c462979
|
|
@ -51,8 +51,15 @@ from collections.abc import Sequence
|
|||
from types import TracebackType
|
||||
from typing import Any, Callable, Generator, Optional, Union
|
||||
|
||||
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
HAS_NUMPY = True
|
||||
except ModuleNotFoundError:
|
||||
HAS_NUMPY = False
|
||||
np = None # type: ignore[assignment]
|
||||
|
||||
import torch
|
||||
from torch import Size, SymBool, SymInt, Tensor
|
||||
from torch._C import DispatchKey, DispatchKeySet, ScriptObject
|
||||
|
|
@ -524,8 +531,13 @@ class LocalTensor(torch.Tensor):
|
|||
with LocalTensorMode(local_tensor._ranks):
|
||||
return func(*args, **kwargs)
|
||||
|
||||
def numpy(self, *, force: bool = False) -> np.ndarray:
|
||||
def numpy(
|
||||
self, *, force: bool = False
|
||||
) -> np.ndarray: # pyrefly: ignore # missing-attribute
|
||||
if HAS_NUMPY:
|
||||
return self.reconcile().numpy(force=force)
|
||||
else:
|
||||
raise RuntimeError("Numpy is not available")
|
||||
|
||||
def __lt__(
|
||||
self, other: torch.Tensor | bool | complex | float | int
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user