diff --git a/torch/_dynamo/eval_frame.py b/torch/_dynamo/eval_frame.py
index 1188bfd74fc..6b500a87bc3 100644
--- a/torch/_dynamo/eval_frame.py
+++ b/torch/_dynamo/eval_frame.py
@@ -350,6 +350,18 @@ def get_compiler_fn(compiler_fn):
 def lookup_backend(compiler_fn):
     """Expand backend strings to functions"""
     if compiler_fn == "inductor":
+        # Warn when the current CUDA device reports compute capability >= 8.0
+        # but TF32 is disabled for float32 matmuls.
+        if torch.cuda.is_available():
+            if (
+                torch.backends.cuda.matmul.allow_tf32 is False
+                and torch.cuda.get_device_capability() >= (8, 0)
+            ):
+                warnings.warn(
+                    "TensorFloat32 tensor cores for float32 matrix multiplication available but not enabled. "
+                    "Consider setting `torch.set_float32_matmul_precision('high')`"
+                )
+
         compiler_fn = import_module(f"{config.inductor_import}.compile_fx").compile_fx
     elif isinstance(compiler_fn, str):
         from .optimizations import BACKENDS