# Owner(s): ["module: inductor"]
import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu

log = logging.getLogger(__name__)


class TargetCPModule(torch.nn.Module):
    # Small module whose relu/tanh activations feed a matmul; used to exercise
    # the activation quantization pass.
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class FeedforwardNN(torch.nn.Module):
    # Small feed-forward network with relu/tanh activations.
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x


class LayernormNN(torch.nn.Module):
    # Thin wrapper around torch.nn.functional.layer_norm.
    def __init__(self):
        super().__init__()

    def forward(self, input, normalized_shape, weight, bias):
        x = torch.nn.functional.layer_norm(
            input=input,
            normalized_shape=normalized_shape,
            weight=weight,
            bias=bias,
            eps=1e-5,
        )
        return x


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        # Returns True if the two dicts hold allclose tensors; keys in the
        # traced module's dict carry the "_orig_mod." prefix.
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    "None found with key1 and value1: %s, %s, key2 and value2: %s, %s",
                    key1,
                    ref_dict[key1],
                    key2,
                    res_dict[key2],
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    "gradient mismatch for eager and compiled modules, with eager: %s and compiled: %s",
                    ref_dict[key1],
                    res_dict[key2],
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "use_scaling": True,
                "size_in_mb": 0.0,
                "exclude_primals": True,
                "allowed_dtypes": "torch.bfloat16;torch.float32",
            },
        },
    )
    def test_activation_quantization_aten_with_scaling(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "use_scaling": False,
                "size_in_mb": 0.0,
                "exclude_primals": True,
                "allowed_dtypes": "torch.bfloat16;torch.float32",
            },
        },
    )
    def test_activation_quantization_aten_without_scaling(self):
        counters.clear()
        module = LayernormNN().to(GPU_TYPE)
        normalized_shape = [256]
        input = [
            torch.randn(
                (1, 3, 256), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            normalized_shape,
            torch.randn(
                *normalized_shape,
                requires_grad=True,
                device=GPU_TYPE,
                dtype=torch.bfloat16,
            ),
            torch.randn(
                *normalized_shape,
                requires_grad=True,
                device=GPU_TYPE,
                dtype=torch.bfloat16,
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()