diff --git a/aten/src/ATen/autocast_mode.cpp b/aten/src/ATen/autocast_mode.cpp
index b15fb9910af..e3424cc4cb8 100644
--- a/aten/src/ATen/autocast_mode.cpp
+++ b/aten/src/ATen/autocast_mode.cpp
@@ -2,7 +2,6 @@
 
 #include
 #include
-#include
 #include
 
 namespace at::autocast {
@@ -37,29 +36,10 @@ namespace {
 using weakref_type = c10::weak_intrusive_ptr;
 using val_type = std::tuple;
 
-// We maintain separate caches for gradient-enabled and gradient-disabled modes.
-// This ensures that tensors cached in torch.no_grad() (with requires_grad=False)
-// are not incorrectly reused in gradient-enabled contexts.
-// This fixes issue #158232 while maintaining optimal performance for both modes.
-static ska::flat_hash_map& get_cached_casts_grad_enabled() {
-  static ska::flat_hash_map cached_casts_grad_enabled;
-  return cached_casts_grad_enabled;
+ska::flat_hash_map& get_cached_casts() {
+  static ska::flat_hash_map cached_casts;
+  return cached_casts;
 }
-
-static ska::flat_hash_map& get_cached_casts_grad_disabled() {
-  static ska::flat_hash_map cached_casts_grad_disabled;
-  return cached_casts_grad_disabled;
-}
-
-// Helper function to get the appropriate cache based on current gradient mode.
-// This allows us to cache tensors separately for grad-enabled and grad-disabled contexts,
-// preventing incorrect cache hits when gradient mode changes.
-static ska::flat_hash_map& get_cached_casts() {
-  return at::GradMode::is_enabled() ?
-    get_cached_casts_grad_enabled() :
-    get_cached_casts_grad_disabled();
-}
-
 std::mutex cached_casts_mutex;
@@ -106,9 +86,7 @@ thread_local bool cache_enabled = true;
 
 void clear_cache() {
   const std::lock_guard lock(cached_casts_mutex);
-  // Clear both caches to ensure consistent behavior regardless of current gradient mode
-  get_cached_casts_grad_enabled().clear();
-  get_cached_casts_grad_disabled().clear();
+  get_cached_casts().clear();
 }
 
 int increment_nesting() {
@@ -143,11 +121,6 @@ Tensor cached_cast(at::ScalarType to_type, const Tensor& arg, DeviceType device_
   if (is_eligible(arg, device_type) && (arg.scalar_type() != to_type)) {
     // Heuristic: Do what Apex does, and cache lower_precision_fp casts of fp32 model weights (leaves).
     // See cached_casts declaration above for detailed strategy.
-    //
-    // We maintain separate caches for gradient-enabled and gradient-disabled modes
-    // (see get_cached_casts() above). This ensures correctness when mixing torch.no_grad()
-    // with torch.autocast(), while maintaining optimal performance for both training and inference.
-    // This fixes issue #158232 without any performance regression.
     bool can_try_cache = (to_type == get_lower_precision_fp_from_device_type(device_type) &&
                           arg.scalar_type() == at::kFloat && arg.requires_grad() &&
                           arg.is_leaf() && !arg.is_view() && cache_enabled &&
diff --git a/test/test_autocast.py b/test/test_autocast.py
index d1c5f525b8d..19e05dd0a9d 100644
--- a/test/test_autocast.py
+++ b/test/test_autocast.py
@@ -384,143 +384,6 @@ class TestTorchAutocast(TestCase):
         with self.assertRaisesRegex(expected_exception=ValueError, expected_regex=msg):
             torch.autocast(device_type=dev)
 
-    @skipIfTorchDynamo()
-    def test_autocast_nograd_caching_issue_158232(self):
-        """
-        Regression test for issue #158232: autocast + no_grad incompatibility
-
-        When torch.no_grad() is nested inside torch.autocast(), the autocast cache
-        must not cache tensors created in the no_grad context, because they lack
-        gradient tracking. If cached, subsequent operations in gradient-enabled mode
-        would incorrectly use the no-gradient cached version.
-
-        Before fix: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
-        After fix: Should work correctly
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # First forward pass in no_grad context (e.g., shape inference)
-            with torch.no_grad():
-                out1 = model(inp)
-                self.assertFalse(
-                    out1.requires_grad, "Output in no_grad should not require grad"
-                )
-
-            # Second forward pass with gradients enabled (e.g., training)
-            out2 = model(inp)
-            self.assertTrue(
-                out2.requires_grad,
-                "Output should require gradients after exiting no_grad",
-            )
-            self.assertIsNotNone(
-                out2.grad_fn, "Output should have grad_fn after exiting no_grad"
-            )
-
-            # Backward pass should work
-            loss = out2.mean()
-            loss.backward()
-
-            # Verify gradients were computed
-            self.assertIsNotNone(model.weight.grad)
-            self.assertIsNotNone(model.bias.grad)
-
-    @skipIfTorchDynamo()
-    def test_autocast_inference_mode_interaction(self):
-        """
-        Test that autocast works correctly with torch.inference_mode()
-
-        InferenceMode is a stricter version of no_grad that provides additional
-        performance optimizations. Verify it doesn't break with autocast.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        # Test 1: inference_mode inside autocast
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            with torch.inference_mode():
-                out1 = model(inp)
-                self.assertFalse(out1.requires_grad)
-                self.assertEqual(out1.dtype, torch.bfloat16)
-
-            # After exiting inference_mode, gradients should work
-            out2 = model(inp)
-            self.assertTrue(out2.requires_grad)
-            out2.mean().backward()
-
-        # Test 2: autocast inside inference_mode
-        with torch.inference_mode():
-            with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-                out = model(inp)
-                self.assertFalse(out.requires_grad)
-                self.assertEqual(out.dtype, torch.bfloat16)
-
-    def test_autocast_caching_still_works_with_gradients(self):
-        """
-        Verify that autocast caching still functions correctly when gradients ARE enabled.
-
-        This test ensures the fix for #158232 didn't break normal caching behavior.
-        We can't directly observe cache hits, but we verify that repeated operations
-        with gradients enabled work correctly.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # Multiple forward passes with gradients enabled
-            out1 = model(inp)
-            out2 = model(inp)
-            out3 = model(inp)
-
-            # All should have gradients
-            self.assertTrue(out1.requires_grad)
-            self.assertTrue(out2.requires_grad)
-            self.assertTrue(out3.requires_grad)
-
-            # All should have grad_fn
-            self.assertIsNotNone(out1.grad_fn)
-            self.assertIsNotNone(out2.grad_fn)
-            self.assertIsNotNone(out3.grad_fn)
-
-            # Backward should work on all
-            out1.mean().backward(retain_graph=True)
-            out2.mean().backward(retain_graph=True)
-            out3.mean().backward()
-
-    @skipIfTorchDynamo()
-    def test_autocast_mixed_grad_contexts(self):
-        """
-        Test complex nesting of gradient contexts within autocast.
-
-        This ensures the gradient mode check works correctly across
-        multiple transitions between gradient-enabled and disabled states.
-        """
-        model = torch.nn.Linear(2, 2)
-        inp = torch.randn(8, 2)
-
-        with torch.autocast("cpu", dtype=torch.bfloat16, enabled=True):
-            # Pass 1: no_grad
-            with torch.no_grad():
-                out1 = model(inp)
-                self.assertFalse(out1.requires_grad)
-
-            # Pass 2: gradients enabled
-            out2 = model(inp)
-            self.assertTrue(out2.requires_grad)
-
-            # Pass 3: no_grad again
-            with torch.no_grad():
-                out3 = model(inp)
-                self.assertFalse(out3.requires_grad)
-
-            # Pass 4: gradients enabled again
-            out4 = model(inp)
-            self.assertTrue(out4.requires_grad)
-
-            # Backward on gradient-enabled outputs
-            (out2.mean() + out4.mean()).backward()
-
 
 if __name__ == "__main__":
     run_tests()