pytorch/torch/_dynamo/variables/sdpa.py
William Wen 79aabaf626 [3.13, dynamo] codegen PUSH_NULL when callable is codegen'd (#129172)
Significant bytecode generation API change!

The suggested convention for generating bytecode that calls a function is now to wrap the instructions that push the callable onto the stack with `add_push_null`, then call that callable with `create_call_function` with `push_null=False` (see the diff for examples).
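
As a rough sketch of the change (the names mirror the Dynamo codegen helpers used in the file below, e.g. `add_push_null`, `load_import_from`, `foreach`, and `create_call_function`; the "old" half is an assumption for illustration, the exact call sites are in the diff):

```python
# Old convention (illustrative): push the callable directly and let
# create_call_function insert the NULL itself (push_null=True).
codegen.load_import_from("torch._C", "_SDPAParams")
codegen.foreach(self.param_vars)
codegen.extend_output(create_call_function(len(self.param_vars), True))

# New convention: wrap the instructions that push the callable in add_push_null,
# which places the NULL correctly for the running Python version, then emit the
# call with push_null=False.
codegen.add_push_null(
    lambda: codegen.load_import_from("torch._C", "_SDPAParams")
)
codegen.foreach(self.param_vars)
codegen.extend_output(create_call_function(len(self.param_vars), False))
```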

In Python 3.13, the NULL is expected to be pushed after the callable; in <= 3.12, the NULL was pushed before the callable. This change abstracts away the exact placement of the NULL, but developers must still be aware that a NULL may be needed when codegen'ing a callable.
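
Roughly, the stack layouts the call expects look like this (a simplified illustration of the two orderings described above, not taken from the diff):

```
# <= 3.12: NULL is pushed first, then the callable, then the arguments
PUSH_NULL; <callable>; <args...>; CALL n

# 3.13: the callable is pushed first, then NULL (or self), then the arguments
<callable>; PUSH_NULL; <args...>; CALL n
```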

This abstraction also reduces the need for the `push_null=True` option in `create_call_function`, and with it the need to rotate a NULL into the right place on the stack with a sequence of `SWAP` instructions.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129172
Approved by: https://github.com/jansel
2024-06-22 17:25:23 +00:00

# mypy: ignore-errors

from inspect import getattr_static

from ..bytecode_transformation import create_call_function
from ..exc import Unsupported
from .base import VariableTracker


class SDPAParamsVariable(VariableTracker):
    """Represents the c++ params struct for scaled dot product attention.
    This is a read-only container."""

    @staticmethod
    def create(tx, value, source):
        from torch.backends.cuda import SDPAParams

        from ..source import AttrSource
        from .builder import VariableBuilder
        from .torch import TorchInGraphFunctionVariable

        query_var = VariableBuilder(tx, AttrSource(source, "query"))(value.query)
        key_var = VariableBuilder(tx, AttrSource(source, "key"))(value.key)
        value_var = VariableBuilder(tx, AttrSource(source, "value"))(value.value)
        attn_mask_var = VariableBuilder(tx, AttrSource(source, "attn_mask"))(
            value.attn_mask
        )
        dropout_var = VariableBuilder(tx, AttrSource(source, "dropout"))(value.dropout)
        is_causal_var = VariableBuilder(tx, AttrSource(source, "is_causal"))(
            value.is_causal
        )
        param_vars = [
            query_var,
            key_var,
            value_var,
            attn_mask_var,
            dropout_var,
            is_causal_var,
        ]
        return TorchInGraphFunctionVariable(SDPAParams).call_function(
            tx, param_vars, {}
        )

    def __init__(self, proxy, param_vars, **kwargs):
        self.proxy = proxy
        self.param_vars = param_vars
        super().__init__(**kwargs)

    def reconstruct(self, codegen):
        assert self.source is None
        assert self.param_vars is not None
        codegen.add_push_null(
            lambda: codegen.load_import_from("torch._C", "_SDPAParams")
        )
        codegen.foreach(self.param_vars)
        codegen.extend_output(create_call_function(len(self.param_vars), False))

    def as_proxy(self):
        return self.proxy

    def var_getattr(self, tx, name: str) -> VariableTracker:
        import torch._C

        from ..source import AttrSource
        from .builder import wrap_fx_proxy
        from .misc import GetAttrVariable

        try:
            getattr_static(torch._C._SDPAParams, name)
        except AttributeError:
            # Using raise from is too verbose here
            raise Unsupported(
                f"Unsupported torch._C._SDPAParams attribute {name}"
            ) from None

        proxy = GetAttrVariable.create_getattr_proxy(self.as_proxy(), name)
        if self.source is not None:
            return wrap_fx_proxy(
                tx=tx, proxy=proxy, source=AttrSource(self.source, name)
            )
        else:
            return wrap_fx_proxy(tx=tx, proxy=proxy)

    @staticmethod
    def is_sdpa_params(value):
        from torch.backends.cuda import SDPAParams

        return value is SDPAParams