pytorch/torch/_dynamo/variables/sdpa.py
PyTorch MergeBot 7d39401fa0 Revert "[BE][Typing][Dynamo] Type misc files in torch/_dynamo/variables/ (#166569)"
This reverts commit f1e4c42b6e.

Reverted https://github.com/pytorch/pytorch/pull/166569 on behalf of https://github.com/pytorch-auto-revert due to Reverted automatically by pytorch's autorevert, to avoid this behaviour add the tag autorevert: disable ([comment](https://github.com/pytorch/pytorch/pull/166569#issuecomment-3471180280))
2025-10-31 03:31:01 +00:00

87 lines
2.5 KiB
Python

# mypy: ignore-errors
from inspect import getattr_static
from typing import TYPE_CHECKING
from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from ..source import AttrSource
from .base import VariableTracker
if TYPE_CHECKING:
from torch._dynamo.codegen import PyCodegen
from torch._dynamo.symbolic_convert import InstructionTranslator
# Attribute names that SDPAParamsVariable.create reads from a concrete params
# object. The resulting values are passed, in exactly this order, as positional
# arguments to torch.backends.cuda.SDPAParams, so the order is significant.
PARAM_NAMES = "query key value attn_mask dropout is_causal enable_gqa".split()
class SDPAParamsVariable(VariableTracker):
    """Tracks the C++ params struct used by scaled dot product attention.

    An instance holds an FX proxy for a ``torch._C._SDPAParams`` value plus the
    variable trackers for the arguments it was constructed from. This is a
    read-only container: the only traced access is attribute reads.
    """

    @staticmethod
    def create(tx: "InstructionTranslator", value, source):
        """Trace construction of an ``SDPAParams`` object mirroring ``value``.

        For every name in ``PARAM_NAMES``, the attribute is read from ``value``
        and wrapped via ``VariableTracker.build`` under an ``AttrSource`` rooted
        at ``source``. The wrapped fields are then passed positionally to an
        in-graph call of ``torch.backends.cuda.SDPAParams``, whose result is
        returned.
        """
        from torch.backends.cuda import SDPAParams

        from .torch import TorchInGraphFunctionVariable

        field_vars = []
        for field in PARAM_NAMES:
            # Read the concrete value before building its source, matching the
            # argument evaluation order of the build call.
            field_value = getattr(value, field)
            field_vars.append(
                VariableTracker.build(tx, field_value, AttrSource(source, field))
            )

        sdpa_ctor = TorchInGraphFunctionVariable(SDPAParams)
        return sdpa_ctor.call_function(tx, field_vars, {})

    def __init__(self, proxy, param_vars, **kwargs) -> None:
        # FX proxy for the traced params object; handed out by as_proxy().
        self.proxy = proxy
        # Trackers for the constructor arguments; replayed by reconstruct().
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen: "PyCodegen"):
        """Emit bytecode that rebuilds the params object at runtime.

        Requires that this variable has no ``source`` and has recorded
        constructor arguments. The emitted code loads
        ``torch._C._SDPAParams`` and calls it with each element of
        ``param_vars`` as a positional argument.
        """
        assert self.source is None
        assert self.param_vars is not None

        def load_ctor():
            return codegen.load_import_from("torch._C", "_SDPAParams")

        codegen.add_push_null(load_ctor)
        codegen.foreach(self.param_vars)
        arg_count = len(self.param_vars)
        codegen.extend_output(create_call_function(arg_count, False))

    def as_proxy(self):
        """Return the FX proxy supplied at construction time."""
        return self.proxy

    def var_getattr(self, tx: "InstructionTranslator", name: str) -> VariableTracker:
        """Trace a read of attribute ``name`` on the params object.

        The attribute is first looked up statically on ``torch._C._SDPAParams``
        using ``inspect.getattr_static``. If that lookup fails, ``Unsupported``
        is raised. Otherwise a getattr proxy is created and wrapped with
        ``wrap_fx_proxy``, attaching an ``AttrSource`` only when this variable
        itself has a source.
        """
        import torch._C

        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # ``from None`` drops the AttributeError context; chaining it would
            # only make the reported error more verbose.
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is None:
            return wrap_fx_proxy(tx=tx, proxy=proxy)
        return wrap_fx_proxy(
            tx=tx, proxy=proxy, source=AttrSource(self.source, name)
        )

    @staticmethod
    def is_sdpa_params(value):
        """Return True when ``value`` is the ``torch.backends.cuda.SDPAParams`` class itself."""
        from torch.backends.cuda import SDPAParams

        return SDPAParams is value