diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py
index f83bd906c23..73ac6eb0da7 100644
--- a/test/distributed/test_dynamo_distributed.py
+++ b/test/distributed/test_dynamo_distributed.py
@@ -678,88 +678,6 @@ class TestMultiProc(DynamoDistributedMultiProcTestCase):
             outputs = fsdp_m(inputs)
             self.assertTrue(same(correct_outputs, outputs))
 
-    @config.patch(enable_compiler_collectives=True)
-    @skip_if_lt_x_gpu(1)
-    def test_fsdp_dynamism_on_int_attr(self):
-        global GUARDS_FILE
-        GUARDS_FILE = StringIO()
-
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-
-            class ToyModelWithIntAttr(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.attr = 2
-
-                def forward(self, x):
-                    out = x + self.attr
-
-                    @comptime
-                    def _(ctx):
-                        ctx.print_guards(file=GUARDS_FILE)
-
-                    return out
-
-            def get_model_with_int_attr(device):
-                m = ToyModelWithIntAttr().to(device)
-                inputs = torch.rand(10).to(device)
-                outputs = m(inputs)
-                return m, inputs, outputs
-
-            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
-            fsdp_m = FSDP(m, use_orig_params=True)
-            compiled_fsdp_m = torch.compile(
-                fsdp_m, backend="eager", dynamic=True, fullgraph=True
-            )
-            outputs = compiled_fsdp_m(inputs)
-            self.assertTrue(same(correct_outputs, outputs))
-
-            FileCheck().check(
-                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" EQUALS_MATCH"""
-            ).run(GUARDS_FILE.getvalue())
-
-    @config.patch(enable_compiler_collectives=True)
-    @config.patch(allow_unspec_int_on_fsdp_module=True)
-    @skip_if_lt_x_gpu(1)
-    def test_fsdp_dynamism_on_int_attr_unspec(self):
-        global GUARDS_FILE
-        GUARDS_FILE = StringIO()
-
-        with _dynamo_dist_per_rank_init(self.rank, self.world_size):
-
-            class ToyModelWithIntAttr(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.attr = 2
-
-                def forward(self, x):
-                    out = x + self.attr
-
-                    @comptime
-                    def _(ctx):
-                        ctx.print_guards(file=GUARDS_FILE)
-
-                    return out
-
-            def get_model_with_int_attr(device):
-                m = ToyModelWithIntAttr().to(device)
-                inputs = torch.rand(10).to(device)
-                outputs = m(inputs)
-                return m, inputs, outputs
-
-            m, inputs, correct_outputs = get_model_with_int_attr(f"cuda:{self.rank}")
-            fsdp_m = FSDP(m, use_orig_params=True)
-            compiled_fsdp_m = torch.compile(
-                fsdp_m, backend="eager", dynamic=True, fullgraph=True
-            )
-            outputs = compiled_fsdp_m(inputs)
-            self.assertTrue(same(correct_outputs, outputs))
-
-            # No presence of EQUALS_MATCH because the guard will be dynamic
-            FileCheck().check(
-                """local_fsdp_module "L['fn']._modules['_fsdp_wrapped_module'].attr" TYPE_MATCH"""
-            ).run(GUARDS_FILE.getvalue())
-
     @skip_if_lt_x_gpu(2)
     @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
     def test_ddp_optimizer_cudagraph(self):
diff --git a/torch/_dynamo/config.py b/torch/_dynamo/config.py
index caf9521021a..b3708ff8493 100644
--- a/torch/_dynamo/config.py
+++ b/torch/_dynamo/config.py
@@ -284,13 +284,6 @@ force_unspec_int_unbacked_size_like_on_torchrec_kjt = False
 # Defaults to False for BC.
 allow_unspec_int_on_nn_module = False
 
-# Mirrors `allow_unspec_int_on_nn_module`, but for FSDP: for <=2.8 versions,
-# integer attributes on FSDP modules were treated as dynamic, while the same
-# attributes on plain nn.Modules were static. We unified the behaviour by making
-# FSDP ints static too. Set this flag to True to restore the legacy dynamic
-# handling if needed.
-allow_unspec_int_on_fsdp_module = False
-
 # Specify how to optimize a compiled DDP module. The flag accepts a boolean
 # value or a string. There are 3 modes.
 # 1. "ddp_optimizer" (or True): with "ddp_optimizer", Dynamo will automatically
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index 4f33258d12a..e11888b9dc6 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -2398,15 +2398,6 @@ def is_int_specialization_case(value, source):
             source.guard_source().is_specialized_nn_module()
             and not config.allow_unspec_int_on_nn_module
         )
-        # integers coming from FSDP modules are considered static. This is
-        # purely empirical and perhaps we should have a better heuristic.
-        or (
-            source.guard_source().is_fsdp_module()
-            and not (
-                config.allow_unspec_int_on_nn_module
-                or config.allow_unspec_int_on_fsdp_module
-            )
-        )
         or (
             source.guard_source().is_unspecialized_builtin_nn_module()
             and not config.allow_unspec_int_on_nn_module
diff --git a/torch/_guards.py b/torch/_guards.py
index 5619db212bf..28becfac586 100644
--- a/torch/_guards.py
+++ b/torch/_guards.py
@@ -155,6 +155,17 @@ class GuardSource(enum.Enum):
         return self in (GuardSource.GLOBAL_FSDP_MODULE, GuardSource.LOCAL_FSDP_MODULE)
 
     def is_specialized_nn_module(self) -> bool:
+        import torch._dynamo.config as config
+
+        if config._unsafe_skip_fsdp_module_guards:
+            return (
+                self
+                in (
+                    GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
+                    GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
+                )
+                or self.is_fsdp_module()
+            )
         return self in (
             GuardSource.GLOBAL_SPECIALIZED_NN_MODULE,
             GuardSource.LOCAL_SPECIALIZED_NN_MODULE,
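For reference, a minimal sketch (not part of the patch; module and attribute names are illustrative, adapted from the removed test) of the guard behaviour this change unifies: with the default allow_unspec_int_on_nn_module = False, a plain Python int attribute on a module is specialized and guarded by value, so mutating it forces a recompile. FSDP-wrapped modules now take this same path instead of treating the int as dynamic.

import torch
import torch.nn as nn


class ToyModelWithIntAttr(nn.Module):
    def __init__(self):
        super().__init__()
        self.attr = 2  # plain Python int attribute

    def forward(self, x):
        return x + self.attr


m = ToyModelWithIntAttr()
compiled = torch.compile(m, backend="eager", dynamic=True)
x = torch.rand(10)

compiled(x)  # first call: `attr` is baked into the graph and guarded by value
m.attr = 3
compiled(x)  # the value guard on `attr` fails, so Dynamo recompiles with attr == 3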