Revert "[dynamo][guards] 1/N Guard selectively for DTensor (#165824)"
This reverts commit ee7434be82.
Reverted https://github.com/pytorch/pytorch/pull/165824 on behalf of https://github.com/anijain2305 because an internal job failed ([comment](https://github.com/pytorch/pytorch/pull/165824#issuecomment-3462667536))
This commit is contained in:
parent 35f3572fa4
commit d7040e6d75
@@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
        run(g, 64, 8)
        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_requires_grad_recompile(self):
        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))

        @torch.compile(backend=cnt, fullgraph=True)
        def f(x):
            y = x * x
            return y.to_local()

        full_x = torch.randn(8, 8, requires_grad=False)
        x = distribute_tensor(full_x, mesh, [Shard(0)])
        f(x)

        full_x = torch.randn(8, 8, requires_grad=True)
        x = distribute_tensor(full_x, mesh, [Shard(0)])
        f(x)

        self.assertEqual(cnt.frame_count, 2)

    def test_dtensor_attribute_access_on_intermediate(self):
        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
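The removed test above hinges on dynamo recompiling when the `requires_grad` of an input flips. Below is a minimal sketch of that behaviour with plain tensors, so it needs no mesh or distributed setup; the use of `CompileCounter` and the expected frame count of 2 reflect default dynamo input guarding and are assumptions, not something stated in this diff.

```python
# Minimal sketch: dynamo is expected to recompile when requires_grad of an
# input changes, which is what the removed DTensor test asserted via
# frame_count == 2. Plain tensors are used here instead of DTensor.
import torch
import torch._dynamo.testing

cnt = torch._dynamo.testing.CompileCounter()

@torch.compile(backend=cnt, fullgraph=True)
def f(x):
    return x * x

f(torch.randn(8, 8, requires_grad=False))
f(torch.randn(8, 8, requires_grad=True))  # guard on requires_grad misses -> recompile
print(cnt.frame_count)  # expected: 2
```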
@@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
            metadata_checker, get_verbose_code_parts(global_name, guard)
        )

    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
        # Copied from DTensor __metadata_guard__
        # TODO - Consider moving this to C++ if stable
        value = deepcopy(self.get(guard.name))

        def guard_fn(x: Any) -> bool:
            return x._check_equals(value, skip_shapes=True)

        code = f"__dtensor_spec_{id(guard_fn)}"
        self.get_guard_manager(guard).add_lambda_guard(
            guard_fn, get_verbose_code_parts(code, guard)
        )

    def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
        ref = self.arg_ref(guard)
        val = self.get(guard.name)
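The removed `DTENSOR_SPEC_MATCH` builder captures a deep copy of the guarded spec and installs a lambda guard that re-checks later values against it with `_check_equals(..., skip_shapes=True)`. The sketch below illustrates that closure pattern in isolation; `Spec` and `make_spec_guard` are hypothetical stand-ins, not dynamo or DTensor APIs.

```python
# Standalone illustration of the closure-based guard pattern used by the
# removed DTENSOR_SPEC_MATCH: snapshot the value seen at compile time, then
# compare later values against the snapshot (optionally ignoring shapes).
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, Callable

@dataclass
class Spec:  # hypothetical stand-in for DTensorSpec
    placements: tuple
    shape: tuple

def make_spec_guard(value: Spec, skip_shapes: bool = True) -> Callable[[Any], bool]:
    snapshot = deepcopy(value)  # freeze the spec observed at compile time

    def guard_fn(x: Any) -> bool:
        # Mirrors x._check_equals(value, skip_shapes=True) in spirit only.
        same_placements = x.placements == snapshot.placements
        same_shape = skip_shapes or x.shape == snapshot.shape
        return same_placements and same_shape

    return guard_fn

guard = make_spec_guard(Spec(placements=("Shard(0)",), shape=(8, 8)))
print(guard(Spec(placements=("Shard(0)",), shape=(16, 8))))    # True: shapes ignored
print(guard(Spec(placements=("Replicate()",), shape=(8, 8))))  # False: placements differ
```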
@@ -2229,12 +2229,6 @@ class VariableBuilder:
        if isinstance(source, GradSource) and is_from_optimizer_source(source):
            guard_type = GuardBuilder.NOT_NONE_MATCH

        is_dtensor = torch.distributed.is_available() and isinstance(
            value, torch.distributed.tensor.DTensor
        )
        if not is_dtensor:
            # We guard on the _local_tensor and the _spec, and therefore we dont
            # have to guard on the outer DTensor.
            self.install_guards(
                functools.partial(
                    guard_type,
@@ -2249,45 +2243,6 @@ class VariableBuilder:
        # We install TYPE_MATCH guards for traceable wrapper subclass object,
        # and recursively install corresponding guard for each inner attribute.
        if is_traceable_wrapper_subclass(value):
            # Tensor subclass guards are very expensive because they are
            # implemented in Python. Since DTensor is PyTorch-maintained class,
            # we can skip a lot of these guards.
            if is_dtensor:
                self.install_guards(GuardBuilder.TYPE_MATCH)

                # The inner tensor name is always _local_tensor. If its not, we
                # raise assertion to update the check accordingly.
                inner_tensor_name = value.__tensor_flatten__()[0][0]
                if inner_tensor_name != "_local_tensor":
                    raise RuntimeError(
                        "Expecting Dtensor inner tensor name to be _local_tensor"
                    )

                # Now selectively guard on the flattening context
                flattening_ctx = value.__tensor_flatten__()[1]
                # This is supposed to be (self._spec, self.requires_grad)
                if not (
                    len(flattening_ctx) == 2
                    and flattening_ctx[0] == value._spec
                    and flattening_ctx[1] == value.requires_grad
                ):
                    # If not, raise an assertion to update to the new guards
                    raise RuntimeError(
                        "Expecting Dtensor flattening ctx to be _spec, requires_grad"
                    )
                # Guard on the dtensor spec
                install_guard(
                    AttrSource(self.source, "_spec").make_guard(
                        GuardBuilder.DTENSOR_SPEC_MATCH
                    )
                )
                # Move this to C++
                install_guard(
                    AttrSource(self.source, "requires_grad").make_guard(
                        GuardBuilder.EQUALS_MATCH
                    )
                )
            else:
                self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
                self.install_guards(GuardBuilder.TYPE_MATCH)
            install_guard(
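The removed branch above reads the traceable-wrapper-subclass flatten protocol: `__tensor_flatten__()` returns `(inner_tensor_names, ctx)`, and for DTensor the code assumes exactly `["_local_tensor"]` with `ctx == (self._spec, self.requires_grad)`. The sketch below exercises those same checks against a hypothetical stand-in class, since running them against a real DTensor needs a process group and mesh.

```python
# Sketch of the flatten-protocol checks performed by the removed guard code.
# FakeDTensor is a hypothetical stand-in that mimics the (names, ctx) contract
# the reverted code asserted for the real DTensor.
from typing import Any

class FakeDTensor:
    def __init__(self, spec: Any, requires_grad: bool) -> None:
        self._spec = spec
        self.requires_grad = requires_grad
        self._local_tensor = None  # placeholder for the wrapped local shard

    def __tensor_flatten__(self):
        # (names of inner tensor attributes, opaque context used to re-wrap)
        return ["_local_tensor"], (self._spec, self.requires_grad)

def validate_dtensor_flatten(value: Any) -> None:
    inner_tensor_name = value.__tensor_flatten__()[0][0]
    if inner_tensor_name != "_local_tensor":
        raise RuntimeError("Expecting DTensor inner tensor name to be _local_tensor")

    flattening_ctx = value.__tensor_flatten__()[1]
    if not (
        len(flattening_ctx) == 2
        and flattening_ctx[0] == value._spec
        and flattening_ctx[1] == value.requires_grad
    ):
        raise RuntimeError("Expecting DTensor flattening ctx to be _spec, requires_grad")

validate_dtensor_flatten(FakeDTensor(spec="spec-placeholder", requires_grad=False))
print("flatten contract holds")
```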
@@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
    def __metadata_guard__(
        cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
    ) -> bool:
        # TODO - delete this - This is now unused after the PR -
        # https://github.com/pytorch/pytorch/pull/165824
        orig_spec, orig_requires_grad = orig
        other_spec, other_requires_grad = other
        return (
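The hunk above cuts off before the return expression of `__metadata_guard__`. Based on the removed `DTENSOR_SPEC_MATCH` guard, which is commented as copied from this method and uses `_check_equals(..., skip_shapes=True)`, the comparison plausibly combines spec equality with a `requires_grad` check. The sketch below states that assumption explicitly; `SpecStub` and `metadata_guard_sketch` are hypothetical stand-ins, not the actual implementation.

```python
# Hedged sketch of the kind of comparison __metadata_guard__ performs on the
# (spec, requires_grad) tuples it unpacks. The real return expression is
# truncated in the diff, so this body is an assumption.
from dataclasses import dataclass

@dataclass
class SpecStub:  # hypothetical stand-in for DTensorSpec
    placements: tuple

    def _check_equals(self, other: "SpecStub", skip_shapes: bool = True) -> bool:
        # Stand-in for DTensorSpec._check_equals; the real comparison is richer.
        return self.placements == other.placements

def metadata_guard_sketch(orig: tuple, other: tuple) -> bool:
    orig_spec, orig_requires_grad = orig
    other_spec, other_requires_grad = other
    return (
        orig_spec._check_equals(other_spec, skip_shapes=True)  # assumed spec check
        and orig_requires_grad == other_requires_grad
    )

print(metadata_guard_sketch(
    (SpecStub(placements=("Shard(0)",)), False),
    (SpecStub(placements=("Shard(0)",)), False),
))  # True when both spec and requires_grad match
```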