pytorch/torch/fx/experimental/_constant_symnode.py
Dzmitry Huba 86f9f1d0ab Enable local tensor model for DTensor redistribute tests (#166081)
Redistribute tests extensively exercise various sharding schemes and
redistribution between them. These tests uncovered more edge cases
that were not supported by the local tensor, primarily different flavors
of uneven sharding. To handle these cases, this change implements
missing functional collectives and adds support for the uneven sharding
case where the sharding group (ranks) is larger than the size of the dimension
being sharded. In the latter case, the "missing" shards are represented
by zero-sized tensors so that the rest of the local tensor machinery
can stay oblivious to this special case.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166081
Approved by: https://github.com/ezyang
2025-10-26 22:21:43 +00:00

79 lines
1.7 KiB
Python

from typing import * # noqa: F403
# Python version of c10/core/ConstantSymNodeImpl.cpp
# This needs to exist because the Python version of nested int is not compatible
# with the C++ version of constant symnode.
class ConstantIntNode:
    """A symbolic-node wrapper around a known, concrete ``int``.

    Every query reports a constant, non-symbolic integer. Binary operations
    do no arithmetic here; they are delegated to the *other* operand,
    passing this node as the argument. Comparisons are mirrored so the
    delegated call keeps the same meaning, e.g. ``self > other`` becomes
    ``other < self``.
    """

    def __init__(self, val: int):
        # The wrapped concrete integer value.
        self.val = val

    # ------------------------------------------------------------------
    # Type / kind queries
    # ------------------------------------------------------------------

    def is_constant(self) -> bool:
        return True

    def maybe_as_int(self) -> int:
        """Return the wrapped value (always available for a constant)."""
        return self.val

    def is_int(self) -> bool:
        return True

    def is_float(self) -> bool:
        return False

    def is_bool(self) -> bool:
        return False

    def is_nested_int(self) -> bool:
        return False

    def clone(self) -> "ConstantIntNode":
        """Return this same instance; no copy is made."""
        return self

    # ------------------------------------------------------------------
    # String representations (all render the bare integer value)
    # ------------------------------------------------------------------

    def _str(self) -> str:
        return str(self.val)

    def __str__(self) -> str:
        return self._str()

    def __repr__(self) -> str:
        return self._str()

    def _graph_repr(self) -> str:
        return self._str()

    # ------------------------------------------------------------------
    # Arithmetic, delegated to the other operand
    # ------------------------------------------------------------------

    def add(self, other: Any) -> Any:
        # self + other == other + self
        return other.add(self)

    def sub(self, other: Any) -> Any:
        # self - other == (-other) + self.
        # Pass the node itself, not the raw int, so that ``other.add``
        # receives the same argument type that ``add`` above already hands
        # it. No extra requirement is placed on ``other``.
        return other.neg().add(self)

    def mul(self, other: Any) -> Any:
        # self * other == other * self
        return other.mul(self)

    # ------------------------------------------------------------------
    # Comparisons, delegated with the operator mirrored
    # ------------------------------------------------------------------

    def eq(self, other: Any) -> Any:
        # Equality is symmetric.
        return other.eq(self)

    def ne(self, other: Any) -> Any:
        # Inequality is symmetric.
        return other.ne(self)

    def gt(self, other: Any) -> Any:
        # self > other  <=>  other < self
        return other.lt(self)

    def lt(self, other: Any) -> Any:
        # self < other  <=>  other > self
        return other.gt(self)

    def le(self, other: Any) -> Any:
        # self <= other  <=>  other >= self
        return other.ge(self)

    def ge(self, other: Any) -> Any:
        # self >= other  <=>  other <= self
        return other.le(self)

    # ------------------------------------------------------------------
    # Concrete-value access
    # ------------------------------------------------------------------

    def is_symbolic(self) -> bool:
        return False

    def constant_int(self) -> int:
        return self.val

    def guard_int(self, file: str, line: int) -> int:
        """Return the wrapped value.

        ``file`` and ``line`` are accepted for interface compatibility but
        are not used by this method.
        """
        return self.val