Revert "Rename propagate_tensor_meta to make private again (#161744)"

This reverts commit 734ce8eba9.

Reverted https://github.com/pytorch/pytorch/pull/161744 on behalf of https://github.com/jeanschmidt because it seems to break internal tests; see D81657000 for more details ([comment](https://github.com/pytorch/pytorch/pull/161744#issuecomment-3258934519))
PyTorch MergeBot 2025-09-05 16:20:29 +00:00
parent 06da7c0730
commit f3cebec39e
2 changed files with 13 additions and 13 deletions


@@ -196,27 +196,27 @@ class ShardingPropagator:
         return None

     @lru_cache  # noqa: B019
-    def _propagate_tensor_meta_cached(
-        self, op_schema: OpSchema
-    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
-        """
-        Cached version of _propagate_tensor_meta_non_cached
-        Use _propagate_tensor_meta instead to make compile-safe.
-        """
-        return self._propagate_tensor_meta_non_cached(op_schema)
+    def _propagate_tensor_meta(
+        self, op_schema: OpSchema
+    ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
+        """
+        Cached version of _propagate_tensor_meta_non_cached
+        This is a private API. Use propagate_tensor_meta instead.
+        """
+        return self._propagate_tensor_meta_non_cached(op_schema)

-    def _propagate_tensor_meta(
+    def propagate_tensor_meta(
         self, op_schema: OpSchema
     ) -> Union[None, TensorMeta, Sequence[Optional[TensorMeta]]]:
         """
         Propagate the tensor metadata, it could either return a TensorMeta
-        or a list/tuple of TensorMetas. Uses the cached version if not
-        actively tracing. Use this method if you need caching.
+        or a list/tuple of TensorMetas. This is a public API that should be
+        used if cache should be used.
         """
         if _are_we_tracing():
             return self._propagate_tensor_meta_non_cached(op_schema)
         else:
-            return self._propagate_tensor_meta_cached(op_schema)
+            return self._propagate_tensor_meta(op_schema)

     def _wrap_output_spec_tensor_meta(
         self,
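
For reference, a minimal usage sketch of the restored public entry point. The TensorMeta import path and the propagator/op_schema arguments are assumptions; as the signature above states, the result may be None, a single TensorMeta, or a sequence of optional TensorMetas.

# A minimal sketch of consuming propagate_tensor_meta's result.
# The import path and the `propagator`/`op_schema` arguments are assumptions.
from torch.distributed.tensor._dtensor_spec import TensorMeta  # assumed path


def summarize_output_meta(propagator, op_schema) -> str:
    # Cache-aware public API; falls back to the non-cached path when tracing.
    meta = propagator.propagate_tensor_meta(op_schema)
    if meta is None:
        return "no tensor outputs"
    if isinstance(meta, TensorMeta):
        return f"single output: shape={tuple(meta.shape)}, dtype={meta.dtype}"
    # Otherwise a list/tuple with one (possibly None) entry per op output.
    return f"{len(meta)} outputs, {sum(m is not None for m in meta)} with metadata"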


@@ -112,7 +112,7 @@ def _propagate_tensor_meta(
     kwargs: dict[str, object],
 ) -> TensorMeta:
     op_info = DTensor._op_dispatcher.unwrap_to_op_info(op_call, args, kwargs)
-    tensor_meta = DTensor._op_dispatcher.sharding_propagator._propagate_tensor_meta(
+    tensor_meta = DTensor._op_dispatcher.sharding_propagator.propagate_tensor_meta(
         op_info.schema
     )
     if isinstance(tensor_meta, TensorMeta):
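
The call-site change above keeps this helper on the cache-aware path. A minimal sketch of that caching contract, with propagator and op_schema assumed to come from an existing DTensor program:

# Sketch of the caching contract behind propagate_tensor_meta: outside of
# tracing it delegates to the lru_cache-wrapped _propagate_tensor_meta, so
# repeated queries for the same OpSchema reuse the first result; under
# tracing it recomputes via the non-cached path to stay compile-safe.
def query_twice(propagator, op_schema):
    first = propagator.propagate_tensor_meta(op_schema)
    second = propagator.propagate_tensor_meta(op_schema)
    # When not tracing, both calls hit the cache and return the same object.
    return first is second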