[nativert] Downcast triton double arguments to floats (#166620)

This diff fixes a limitation in the Sigmoid + Triton interaction where float arguments are not passed correctly: NativeRT passes float arguments as doubles, while Triton kernels read them as floats, resulting in wrong values.

---

## Limitations in (de)serialization

In Triton, float arguments to a kernel are encoded as "fp32" ([code](https://github.com/triton-lang/triton-cpu/blob/main-merged/python/triton/runtime/jit.py#L310-L326)):
```
        elif isinstance(arg, float):
            return ("fp32", None)
```
But torch export serde uses double ([code](d2eff5d454/torch/_export/serde/export_schema.thrift (L149))) because Thrift only has the double type:
```
union Argument {
  10: bool as_none;
  20: TensorArgument as_tensor;
  30: list<TensorArgument> as_tensors;
  50: i64 as_int;
  70: list<i64> as_ints;
  80: double as_float;   ===> actually double
...
```
The `TritonKernel` constructor loads attributes from a node, where `Constant` is the variant type for attribute values, and it only has `double` ([code](d2eff5d454/torch/nativert/graph/Graph.h (L86))):
```
using Constant = std::variant<
    None,
    int64_t,
    std::vector<int64_t>,
    double,    ===> triton float is loaded as double
```

So, NativeRT passes float arguments (floats on the Triton side) as doubles to Triton kernels, but all of the Triton backends (NVIDIA, AMD, and CPU) read them as floats because the signature still says `fp32`.
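
To make the failure mode concrete, here is a small byte-level illustration in plain Python (`struct` only; this is not the actual launcher code): when the host stores an argument as an 8-byte double but the kernel signature says `fp32`, the first 4 bytes of that double get reinterpreted as a float and the kernel sees garbage; downcasting to float before handing over the pointer avoids this.
```
# Not the real launch path: just mimicking the byte-level mismatch.
import struct

value = 3.14

# Today: NativeRT stores the argument as a double (8 bytes), but the Triton
# launcher trusts the "fp32" signature and reads only 4 bytes as a float.
double_bytes = struct.pack("<d", value)
misread = struct.unpack("<f", double_bytes[:4])[0]
print(misread)  # garbage, nowhere near 3.14

# The fix in effect: downcast to float first, then the fp32 read is correct.
float_bytes = struct.pack("<f", value)
print(struct.unpack("<f", float_bytes)[0])  # ~3.14 (up to float precision)
```
The new `float_attrs_` buffer in `TritonKernel` (see the cpp changes below) performs exactly this downcast on the NativeRT side.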

D84423898 was the existing workaround: wrapping float arguments in tensors.
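
D84423898 itself is not reproduced in this diff, but the idea looks roughly like the sketch below (the `scale_kernel` / `scale` names are made up for illustration): instead of passing a Python float, which would be serialized as a double, the caller wraps it in a 0-dim fp32 tensor and the kernel loads the value from memory.
```
import torch
import triton
import triton.language as tl


@triton.jit
def scale_kernel(x_ptr, out_ptr, scale_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    # The scalar arrives as a tensor, so it is loaded instead of passed by value.
    scale = tl.load(scale_ptr)
    tl.store(out_ptr + offsets, x * scale, mask=mask)


def scale(x: torch.Tensor, s: float) -> torch.Tensor:
    out = torch.empty_like(x)
    n_elements = x.numel()
    # Workaround: wrap the float in a 0-dim fp32 tensor instead of passing it directly.
    s_t = torch.tensor(s, dtype=torch.float32, device=x.device)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    scale_kernel[grid](x, out, s_t, n_elements, BLOCK_SIZE=128)
    return out
```
With the downcast in place, this wrapping is no longer necessary; the updated test below passes a plain `3.14` (and an int `42`) straight to the kernel.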

## The Fix

Fixing the Thrift definition isn't viable because Thrift only supports the double type. Fixing it on the Triton side, by downcasting from double to float there, is also possible, but that would require changing all of the backends.

Instead, this diff takes what I think is the most effective approach: when building `TritonKernel`, downcast the double arguments to float right after loading them, and pass pointers to those float values to the kernel.

Test Plan:
```
buck test fbcode//mode/opt-amd-gpu fbcode//caffe2/test:test_export --
```

Differential Revision: D85747160

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166620
Approved by: https://github.com/XueningXu
Minjang Kim 2025-10-31 03:52:20 +00:00 committed by PyTorch MergeBot
parent 267d0197bf
commit 85b035ca9c
3 changed files with 50 additions and 16 deletions


@@ -600,6 +600,8 @@ def forward(self, x):
in_ptr1,
out_ptr,
n_elements,
fval,
ival,
BLOCK_SIZE: "tl.constexpr",
):
pid = tl.program_id(axis=0)
@@ -608,7 +610,7 @@ def forward(self, x):
mask = offsets < n_elements
x = tl.load(in_ptr0 + offsets, mask=mask)
y = tl.load(in_ptr1 + offsets, mask=mask)
output = x + y
output = x + y + fval + ival
tl.store(out_ptr + offsets, output, mask=mask)
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -618,7 +620,9 @@ def forward(self, x):
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
wrap_triton(add_kernel)[grid](
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16
)
return output
@@ -633,7 +637,9 @@ def forward(self, x):
def grid(meta):
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
wrap_triton(add_kernel)[grid](
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8
)
return output
@@ -661,35 +667,44 @@ def forward(self, x):
self.assertIsNotNone(triton_node)
args = []
kwargs = []
kwargs = {}
for arg in triton_node.inputs:
if arg.kind == ArgumentKind.POSITIONAL:
args.append(arg.arg)
elif arg.kind == ArgumentKind.KEYWORD:
kwargs.append(arg.arg)
kwargs[arg.name] = arg.arg
self.assertEqual(len(args), 4)
self.assertEqual(len(kwargs), 5)
self.assertEqual(len(args), 6)
# Always present: name, grid, output_indices, num_warps
# Triton version dependent: num_cpu_threads, shared_memory_bytes
self.assertTrue(len(kwargs) >= 4)
for i in range(3):
self.assertIsNotNone(args[i].as_tensor)
self.assertEqual(args[3].as_int, 3)
kernel_name = kwargs[0].as_string
self.assertAlmostEqual(args[4].as_float, 3.14, places=2)
self.assertEqual(args[5].as_int, 42)
kernel_name = kwargs["name"].as_string
symbol_name = kernel_name.rpartition("_")[0]
self.assertEqual(symbol_name, "add_kernel") # symbol name
self.assertEqual(kwargs[1].as_ints, [1, 1, 1]) # grid
self.assertEqual(kwargs[2].as_ints, [2]) # output indices
self.assertEqual(symbol_name, "add_kernel")
self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1])
self.assertEqual(kwargs["output_indices"].as_ints, [2])
self.assertEqual(
kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
) # num warps
self.assertEqual(kwargs[4].as_int, 0) # shared mem bytes
kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4
)
if "num_cpu_threads" in kwargs:
self.assertEqual(kwargs["num_cpu_threads"].as_int, 0)
if "shared_memory_bytes" in kwargs:
self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0)
self.assertEqual(len(triton_node.outputs), 1)
self.assertIsNotNone(triton_node.outputs[0].as_tensors)
self.assertEqual(
len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
len(triton_node.outputs[0].as_tensors),
len(kwargs["output_indices"].as_ints),
)
self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")


@@ -39,13 +39,30 @@ TritonKernel::TritonKernel(
std::string kernel_name{};
std::string symbol_name{};
bool found_grid = false;
// To prevent vector reallocation and dangling pointers
size_t num_double_attrs = 0;
for (const auto& attr : node_->attributes()) {
if (attr.name.empty() && std::holds_alternative<double>(attr.value)) {
++num_double_attrs;
}
}
float_attrs_.reserve(num_double_attrs);
for (const auto& attr : node_->attributes()) {
if (attr.name.empty()) {
attr_ptrs_.emplace_back(std::visit(
[](auto&& arg) -> void* {
[this](auto&& arg) -> void* {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, None>) {
return nullptr;
} else if constexpr (std::is_same_v<T, double>) {
// Triton always uses fp32 for floats. See create_specialize_impl
// in jit.py. However, due to the Thrift schema, floats are
// serialized as doubles here. But, Triton kernels read them as
// floats. So, we need to downcast double to float here.
float_attrs_.push_back(static_cast<float>(arg));
return static_cast<void*>(&float_attrs_.back());
}
return static_cast<void*>(const_cast<T*>(&arg));
},


@@ -24,6 +24,8 @@ class TritonKernel : public OpKernel {
// unnamed node attributes will be passed as arguments to the kernel
std::vector<void*> attr_ptrs_;
// Storage for float attributes that were serialized as doubles
std::vector<float> float_attrs_;
std::vector<int64_t> output_indices_;
LaunchParams launch_params_;
};