mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Revert "Rename propagate_tensor_meta to make private again (#161744)"
This reverts commit 734ce8eba9.
Reverted https://github.com/pytorch/pytorch/pull/161744 on behalf of https://github.com/jeanschmidt due to seems to break internal tests, see D81657000 for more details ([comment](https://github.com/pytorch/pytorch/pull/161744#issuecomment-3258934519))
This commit is contained in:
parent
06da7c0730
commit
f3cebec39e
|
|
@ -196,27 +196,27 @@ class ShardingPropagator:
|
|||
return None
|
||||
|
||||
@lru_cache # noqa: B019
|
||||
def _propagate_tensor_meta_cached(
|
||||
self, op_schema: OpSchema
|
||||
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
|
||||
"""
|
||||
Cached version of _propagate_tensor_meta_non_cached
|
||||
Use _propagate_tensor_meta instead to make compile-safe.
|
||||
"""
|
||||
return self._propagate_tensor_meta_non_cached(op_schema)
|
||||
|
||||
def _propagate_tensor_meta(
|
||||
self, op_schema: OpSchema
|
||||
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
|
||||
"""
|
||||
Cached version of _propagate_tensor_meta_non_cached
|
||||
This is a private API. Use propagate_tensor_meta instead.
|
||||
"""
|
||||
return self._propagate_tensor_meta_non_cached(op_schema)
|
||||
|
||||
def propagate_tensor_meta(
|
||||
self, op_schema: OpSchema
|
||||
) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
|
||||
"""
|
||||
Propagate the tensor metadata, it could either return a TensorMeta
|
||||
or a list/tuple of TensorMetas. Uses the cached version if not
|
||||
actively tracing. Use this method if you need caching.
|
||||
or a list/tuple of TensorMetas. This is a public API that should be
|
||||
used if cache should be used.
|
||||
"""
|
||||
if _are_we_tracing():
|
||||
return self._propagate_tensor_meta_non_cached(op_schema)
|
||||
else:
|
||||
return self._propagate_tensor_meta_cached(op_schema)
|
||||
return self._propagate_tensor_meta(op_schema)
|
||||
|
||||
def _wrap_output_spec_tensor_meta(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -112,7 +112,7 @@ def _propagate_tensor_meta(
|
|||
kwargs: dict[str, object],
|
||||
) -> TensorMeta:
|
||||
op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
|
||||
tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
|
||||
tensor_meta = DTensor._op_dispatcher.sharding_propagator.propagate_tensor_meta(
|
||||
op_info.schema
|
||||
)
|
||||
if isinstance(tensor_meta, TensorMeta):
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user