mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
The redistribute test extensively exercises various sharding schemes and the redistribution between them. These tests uncovered more edge cases that were not supported by the local tensor, primarily different flavors of uneven sharding. To handle these cases, this change implements the missing functional collectives and adds support for the uneven sharding case where the sharding group (ranks) is larger than the size of the dimension being sharded. In the latter case, the "missing" shards are represented by zero-sized tensors so that the rest of the local tensor machinery can stay oblivious to this special case. Pull Request resolved: https://github.com/pytorch/pytorch/pull/166081 Approved by: https://github.com/ezyang
79 lines
1.7 KiB
Python
79 lines
1.7 KiB
Python
from typing import * # noqa: F403
|
|
|
|
|
|
# Python version of c10/core/ConstantSymNodeImpl.cpp
|
|
# This needs to exist because the Python version of nested int is not compatible
|
|
# with the C++ version of constant symnode.
|
|
class ConstantIntNode:
    """A plain, non-symbolic integer exposed through the SymNode-style interface.

    The concrete value lives in ``self.val``. All ``is_*`` queries describe a
    constant int, and ``maybe_as_int`` / ``constant_int`` / ``guard_int``
    simply return the stored value.

    Binary operations compute nothing themselves. Each one delegates to
    ``other`` with the operands swapped and passes this node as the argument:

    * ``add``, ``mul``, ``eq`` and ``ne`` forward to the same method on
      ``other``.
    * Ordered comparisons are mirrored. For example, ``self > other`` is
      evaluated as ``other < self``.
    * ``sub`` is rewritten as ``(-other) + self``.
    """

    def __init__(self, val: int):
        # The wrapped concrete integer.
        self.val = val

    # -- constness / type queries -----------------------------------------

    def is_constant(self) -> bool:
        """Always ``True``: the value is known without any symbolic reasoning."""
        return True

    def maybe_as_int(self) -> int:
        """Return the wrapped integer (always available for a constant)."""
        return self.val

    def is_int(self) -> bool:
        """Always ``True``: this node only ever wraps an integer."""
        return True

    def is_float(self) -> bool:
        """Always ``False``."""
        return False

    def is_bool(self) -> bool:
        """Always ``False``."""
        return False

    def is_nested_int(self) -> bool:
        """Always ``False``."""
        return False

    def clone(self) -> "ConstantIntNode":
        """Return ``self``; no copy is made, since nothing here mutates ``val``."""
        return self

    # -- printing -----------------------------------------------------------

    def _str(self) -> str:
        """Shared formatter used by ``__str__``, ``__repr__`` and ``_graph_repr``."""
        return str(self.val)

    def __str__(self) -> str:
        return self._str()

    def __repr__(self) -> str:
        return self._str()

    def _graph_repr(self) -> str:
        return self._str()

    # -- arithmetic (delegated to ``other`` with operands swapped) ----------

    def add(self, other: Any) -> Any:
        """Return ``self + other`` by calling ``other.add(self)``.

        Swapping the operands is valid because addition is commutative.
        """
        return other.add(self)

    def sub(self, other: Any) -> Any:
        """Return ``self - other``, computed as ``(-other) + self``.

        ``other`` must provide ``neg()``. The negated node's ``add`` receives
        this node rather than the raw ``int``. That matches the node-typed
        argument passed by every other delegated operation in this class.
        """
        return other.neg().add(self)

    def mul(self, other: Any) -> Any:
        """Return ``self * other`` by calling ``other.mul(self)`` (commutative)."""
        return other.mul(self)

    # -- comparisons (operands swapped; ordered ones mirrored) --------------

    def eq(self, other: Any) -> Any:
        """``self == other`` evaluated as ``other == self``."""
        return other.eq(self)

    def ne(self, other: Any) -> Any:
        """``self != other`` evaluated as ``other != self``."""
        return other.ne(self)

    def gt(self, other: Any) -> Any:
        """``self > other`` evaluated as ``other < self``."""
        return other.lt(self)

    def lt(self, other: Any) -> Any:
        """``self < other`` evaluated as ``other > self``."""
        return other.gt(self)

    def le(self, other: Any) -> Any:
        """``self <= other`` evaluated as ``other >= self``."""
        return other.ge(self)

    def ge(self, other: Any) -> Any:
        """``self >= other`` evaluated as ``other <= self``."""
        return other.le(self)

    # -- constant extraction ------------------------------------------------

    def is_symbolic(self) -> bool:
        """Always ``False``."""
        return False

    def constant_int(self) -> int:
        """Return the wrapped integer."""
        return self.val

    def guard_int(self, file: str, line: int) -> int:
        """Return the wrapped integer.

        ``file`` and ``line`` are accepted for interface compatibility but are
        not used. A constant needs no guard.
        """
        return self.val
|