Revert "[dynamo][guards] 1/N Guard selectively for DTensor (#165824)"
This reverts commit ee7434be82.
Reverted https://github.com/pytorch/pytorch/pull/165824 on behalf of https://github.com/anijain2305 due to internal job failed ([comment](https://github.com/pytorch/pytorch/pull/165824#issuecomment-3462667536))
This commit is contained in:
parent 35f3572fa4
commit d7040e6d75
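The reverted change made Dynamo guard a DTensor graph input only on its DTensorSpec and its requires_grad flag, via a lambda guard that calls `_check_equals(..., skip_shapes=True)` on the spec, instead of the generic Python-level tensor-subclass metadata guard. Below is a minimal, runnable sketch of that guard shape; `make_dtensor_spec_guard` and `_FakeSpec` are hypothetical names invented for this illustration, not PyTorch APIs.

from copy import deepcopy


class _FakeSpec:
    # Hypothetical stand-in for DTensorSpec so the sketch runs on its own;
    # the real spec carries mesh/placement info, here we compare one field.
    def __init__(self, placements):
        self.placements = placements

    def _check_equals(self, other, skip_shapes=False):
        # Mirrors the call used by the reverted DTENSOR_SPEC_MATCH guard,
        # which compared specs while ignoring tensor shapes.
        return self.placements == other.placements


def make_dtensor_spec_guard(reference_spec):
    # Capture the spec observed at compile time; the guard keeps passing only
    # while the runtime value's spec still matches it (shapes excluded).
    captured = deepcopy(reference_spec)

    def guard_fn(current_spec) -> bool:
        return current_spec._check_equals(captured, skip_shapes=True)

    return guard_fn


guard = make_dtensor_spec_guard(_FakeSpec(placements=("Shard(0)",)))
print(guard(_FakeSpec(placements=("Shard(0)",))))     # True  -> guard holds
print(guard(_FakeSpec(placements=("Replicate()",))))  # False -> recompile

requires_grad was guarded separately with an equality check, which is why the reverted test below expects exactly one recompile when that flag flips.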
@@ -464,25 +464,6 @@ def forward(self, b_parametrizations_buffer_original0, x):
         run(g, 64, 8)
         self.assertEqual(cnt.frame_count, 2)
 
-    def test_dtensor_requires_grad_recompile(self):
-        cnt = torch._dynamo.testing.CompileCounterWithBackend("aot_eager")
-        mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
-
-        @torch.compile(backend=cnt, fullgraph=True)
-        def f(x):
-            y = x * x
-            return y.to_local()
-
-        full_x = torch.randn(8, 8, requires_grad=False)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        full_x = torch.randn(8, 8, requires_grad=True)
-        x = distribute_tensor(full_x, mesh, [Shard(0)])
-        f(x)
-
-        self.assertEqual(cnt.frame_count, 2)
-
     def test_dtensor_attribute_access_on_intermediate(self):
         mesh = DeviceMesh(self.device_type, torch.arange(self.world_size))
 
@@ -2150,19 +2150,6 @@ class GuardBuilder(GuardBuilderBase):
             metadata_checker, get_verbose_code_parts(global_name, guard)
         )
 
-    def DTENSOR_SPEC_MATCH(self, guard: Guard) -> None:
-        # Copied from DTensor __metadata_guard__
-        # TODO - Consider moving this to C++ if stable
-        value = deepcopy(self.get(guard.name))
-
-        def guard_fn(x: Any) -> bool:
-            return x._check_equals(value, skip_shapes=True)
-
-        code = f"__dtensor_spec_{id(guard_fn)}"
-        self.get_guard_manager(guard).add_lambda_guard(
-            guard_fn, get_verbose_code_parts(code, guard)
-        )
-
     def EQUALS_MATCH(self, guard: Guard, recompile_hint: Optional[str] = None) -> None:
         ref = self.arg_ref(guard)
         val = self.get(guard.name)
@@ -2229,12 +2229,6 @@ class VariableBuilder:
         if isinstance(source, GradSource) and is_from_optimizer_source(source):
             guard_type = GuardBuilder.NOT_NONE_MATCH
 
-        is_dtensor = torch.distributed.is_available() and isinstance(
-            value, torch.distributed.tensor.DTensor
-        )
-        if not is_dtensor:
-            # We guard on the _local_tensor and the _spec, and therefore we dont
-            # have to guard on the outer DTensor.
         self.install_guards(
             functools.partial(
                 guard_type,
@@ -2249,45 +2243,6 @@ class VariableBuilder:
         # We install TYPE_MATCH guards for traceable wrapper subclass object,
         # and recursively install corresponding guard for each inner attribute.
         if is_traceable_wrapper_subclass(value):
-            # Tensor subclass guards are very expensive because they are
-            # implemented in Python. Since DTensor is PyTorch-maintained class,
-            # we can skip a lot of these guards.
-            if is_dtensor:
-                self.install_guards(GuardBuilder.TYPE_MATCH)
-
-                # The inner tensor name is always _local_tensor. If its not, we
-                # raise assertion to update the check accordingly.
-                inner_tensor_name = value.__tensor_flatten__()[0][0]
-                if inner_tensor_name != "_local_tensor":
-                    raise RuntimeError(
-                        "Expecting Dtensor inner tensor name to be _local_tensor"
-                    )
-
-                # Now selectively guard on the flattening context
-                flattening_ctx = value.__tensor_flatten__()[1]
-                # This is supposed to be (self._spec, self.requires_grad)
-                if not (
-                    len(flattening_ctx) == 2
-                    and flattening_ctx[0] == value._spec
-                    and flattening_ctx[1] == value.requires_grad
-                ):
-                    # If not, raise an assertion to update to the new guards
-                    raise RuntimeError(
-                        "Expecting Dtensor flattening ctx to be _spec, requires_grad"
-                    )
-                # Guard on the dtensor spec
-                install_guard(
-                    AttrSource(self.source, "_spec").make_guard(
-                        GuardBuilder.DTENSOR_SPEC_MATCH
-                    )
-                )
-                # Move this to C++
-                install_guard(
-                    AttrSource(self.source, "requires_grad").make_guard(
-                        GuardBuilder.EQUALS_MATCH
-                    )
-                )
-            else:
             self.install_guards(GuardBuilder.TENSOR_SUBCLASS_METADATA_MATCH)
             self.install_guards(GuardBuilder.TYPE_MATCH)
             install_guard(
@@ -671,8 +671,6 @@ class DTensor(torch.Tensor):
     def __metadata_guard__(
         cls, orig: tuple[DTensorSpec, bool], other: tuple[DTensorSpec, bool]
    ) -> bool:
-        # TODO - delete this - This is now unused after the PR -
-        # https://github.com/pytorch/pytorch/pull/165824
        orig_spec, orig_requires_grad = orig
        other_spec, other_requires_grad = other
        return (