A user wants to use the flop counter with meta devices. This previously caused problems for SDPA+NJT:
1. autocast check: `torch.is_autocast_enabled("meta")` fails because `meta` is not a valid device type for autocasting. If we skip this check, we run into the next error.
2. math backend: conversion to NST requires getting concrete offsets as a list of Python integers, which doesn't work on a meta tensor (see b2eb0e8c6a/torch/nested/_internal/sdpa.py, L809-L815).
3. flash attention backend (fixed in the previous PR, #134288): if we force the flash attention backend for flop counting, `_flash_attention_forward` previously didn't support meta tensors.
In this PR, we check specifically for FlopCounterMode and, if it's enabled and combined with meta tensors, (a) skip autocasting and (b) force SDPA down the flash attention path. This isn't generally safe for tracing (e.g. if you actually care which kernels you are running), but in the absence of actual device information, we have to make some assumptions. Checking specifically for FlopCounterMode should reduce the chance of unintended side effects for other meta tensor users.
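As a rough illustration of the decision described above, here is a minimal pure-Python sketch. It is not the code in this PR: `Backend`, `DispatchDecision`, and `decide_sdpa_dispatch` are hypothetical names invented for this example, and the real check lives inside the NJT SDPA implementation and inspects the active FlopCounterMode rather than taking a boolean flag.

```python
from dataclasses import dataclass
from enum import Enum


class Backend(Enum):
    # Hypothetical labels for the two backends discussed above.
    FLASH = "flash_attention"
    MATH = "math"


@dataclass(frozen=True)
class DispatchDecision:
    run_autocast: bool   # whether the autocast step should run
    backend: Backend     # which SDPA backend to use


def decide_sdpa_dispatch(
    device_type: str,
    flop_counter_active: bool,
    default_backend: Backend,
) -> DispatchDecision:
    """Hypothetical helper mirroring the rule in this PR description."""
    if device_type == "meta" and flop_counter_active:
        # No real device information is available, so (a) skip autocasting,
        # which is not valid for "meta", and (b) assume the flash attention
        # path, which supports meta tensors as of #134288.
        return DispatchDecision(run_autocast=False, backend=Backend.FLASH)
    # Every other caller, including meta tensor users that are not counting
    # flops, keeps the existing behavior.
    return DispatchDecision(run_autocast=True, backend=default_backend)
```

Under these assumptions, only the combination of a meta device and an active flop counter changes behavior; a meta tensor without the flop counter, or a real device with it, takes the existing path.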
Note: fake tensor would solve a bunch of these issues, but it's not a viable solution right now for the user.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134289
Approved by: https://github.com/soulitzer
ghstack dependencies: #134288