[nativert] Downcast triton double arguments to floats (#166620)

This diff fixes a limitation in the Sigmoid + Triton interaction where float arguments are not passed correctly: NativeRT passes float arguments as doubles, while Triton kernels read them as floats, resulting in wrong values.

---

## Limitations in (de)serialization

In Triton, float arguments to a kernel are encoded as "fp32" ([code](https://github.com/triton-lang/triton-cpu/blob/main-merged/python/triton/runtime/jit.py#L310-L326)):
```
        elif isinstance(arg, float):
            return ("fp32", None)
```
But torch export serde uses double ([code](d2eff5d454/torch/_export/serde/export_schema.thrift (L149))) because Thrift only has a double type:
```
union Argument {
  10: bool as_none;
  20: TensorArgument as_tensor;
  30: list<TensorArgument> as_tensors;
  50: i64 as_int;
  70: list<i64> as_ints;
  80: double as_float;   ===> actually double
...
```
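Widening to double for serialization is itself lossless: every fp32 value is exactly representable as fp64, so a float that round-trips through a double can be recovered bit-for-bit by a downcast. A minimal standalone sketch (illustrative only, not NativeRT code):
```
#include <cassert>

int main() {
  float original = 3.14f;                           // the fp32 value Triton saw
  double serialized = original;                     // widened for the Thrift schema
  float restored = static_cast<float>(serialized);  // downcast at load time
  assert(restored == original);  // exact: fp32 -> fp64 -> fp32 loses nothing
  return 0;
}
```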
The `TritonKernel` constructor loads attributes from a node, where `Constant` is the variant type used for attribute values, and it only has `double` ([code](d2eff5d454/torch/nativert/graph/Graph.h (L86))):
```
using Constant = std::variant<
    None,
    int64_t,
    std::vector<int64_t>,
    double,    ===> triton float is loaded as double
```

So NativeRT passes arguments that were floats in the original Triton code as doubles to Triton kernels, while all of the Triton backends (NVIDIA, AMD, and CPU) read them as floats because the signature still says `fp32`.
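
To make the failure mode concrete, here is a minimal standalone sketch (illustrative only, not NativeRT code) of what happens when a value stored as a double is handed to a launcher whose signature says fp32:
```
#include <cstdio>
#include <cstring>

int main() {
  // The attribute arrives from deserialization as a double.
  double stored = 3.14;

  // The launcher passes the address of the double, but a kernel whose
  // signature says "fp32" reads only sizeof(float) bytes from it.
  float misread;
  std::memcpy(&misread, &stored, sizeof(float));
  std::printf("misread: %g\n", misread);  // garbage, not 3.14

  // The fix applied in this diff: downcast on the host and pass a
  // pointer to the resulting float instead.
  float downcast = static_cast<float>(stored);
  std::printf("downcast: %g\n", downcast);  // 3.14
  return 0;
}
```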

D84423898 was the existing workaround: wrapping float arguments in tensors.

## The Fix

Fixing the Thrift definition isn't viable because Thrift only supports a double type. It would also be possible to fix this on the Triton side by downcasting from double to float there, but that would require changing all backends.

Instead, this diff takes what I think is the most effective approach: when building `TritonKernel`, downcast the double arguments to float right after loading them (see the `TritonKernel` constructor change below).

Test Plan:
```
buck test fbcode//mode/opt-amd-gpu fbcode//caffe2/test:test_export --
```

Differential Revision: D85747160

Pull Request resolved: https://github.com/pytorch/pytorch/pull/166620
Approved by: https://github.com/XueningXu
Author: Minjang Kim, 2025-10-31 03:52:20 +00:00 (committed by PyTorch MergeBot)
Parent: 267d0197bf
Commit: 85b035ca9c
3 changed files with 50 additions and 16 deletions


```
@@ -600,6 +600,8 @@ def forward(self, x):
     in_ptr1,
     out_ptr,
     n_elements,
+    fval,
+    ival,
     BLOCK_SIZE: "tl.constexpr",
 ):
     pid = tl.program_id(axis=0)
@@ -608,7 +610,7 @@ def forward(self, x):
     mask = offsets < n_elements
     x = tl.load(in_ptr0 + offsets, mask=mask)
     y = tl.load(in_ptr1 + offsets, mask=mask)
-    output = x + y
+    output = x + y + fval + ival
     tl.store(out_ptr + offsets, output, mask=mask)

 def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -618,7 +620,9 @@ def forward(self, x):
     def grid(meta):
         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

-    wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
+    wrap_triton(add_kernel)[grid](
+        x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16
+    )
     return output
@@ -633,7 +637,9 @@ def forward(self, x):
     def grid(meta):
         return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)

-    wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
+    wrap_triton(add_kernel)[grid](
+        x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8
+    )
     return output
@@ -661,35 +667,44 @@ def forward(self, x):
         self.assertIsNotNone(triton_node)

         args = []
-        kwargs = []
+        kwargs = {}
         for arg in triton_node.inputs:
             if arg.kind == ArgumentKind.POSITIONAL:
                 args.append(arg.arg)
             elif arg.kind == ArgumentKind.KEYWORD:
-                kwargs.append(arg.arg)
+                kwargs[arg.name] = arg.arg

-        self.assertEqual(len(args), 4)
-        self.assertEqual(len(kwargs), 5)
+        self.assertEqual(len(args), 6)
+        # Always: name, grid, output_indices and num_warps are
+        # Triton version dependent: num_cpu_threads, shared_memory_bytes
+        self.assertTrue(len(kwargs) >= 4)
         for i in range(3):
             self.assertIsNotNone(args[i].as_tensor)
         self.assertEqual(args[3].as_int, 3)
+        self.assertAlmostEqual(args[4].as_float, 3.14, places=2)
+        self.assertEqual(args[5].as_int, 42)

-        kernel_name = kwargs[0].as_string
+        kernel_name = kwargs["name"].as_string
         symbol_name = kernel_name.rpartition("_")[0]
-        self.assertEqual(symbol_name, "add_kernel")  # symbol name
-        self.assertEqual(kwargs[1].as_ints, [1, 1, 1])  # grid
-        self.assertEqual(kwargs[2].as_ints, [2])  # output indices
+        self.assertEqual(symbol_name, "add_kernel")
+        self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1])
+        self.assertEqual(kwargs["output_indices"].as_ints, [2])
         self.assertEqual(
-            kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
-        )  # num warps
-        self.assertEqual(kwargs[4].as_int, 0)  # shared mem bytes
+            kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4
+        )
+        if "num_cpu_threads" in kwargs:
+            self.assertEqual(kwargs["num_cpu_threads"].as_int, 0)
+        if "shared_memory_bytes" in kwargs:
+            self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0)

         self.assertEqual(len(triton_node.outputs), 1)
         self.assertIsNotNone(triton_node.outputs[0].as_tensors)
         self.assertEqual(
-            len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
+            len(triton_node.outputs[0].as_tensors),
+            len(kwargs["output_indices"].as_ints),
         )
         self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")
```


```
@@ -39,13 +39,30 @@ TritonKernel::TritonKernel(
   std::string kernel_name{};
   std::string symbol_name{};
   bool found_grid = false;
+
+  // To prevent vector reallocation and dangling pointers
+  size_t num_double_attrs = 0;
+  for (const auto& attr : node_->attributes()) {
+    if (attr.name.empty() && std::holds_alternative<double>(attr.value)) {
+      ++num_double_attrs;
+    }
+  }
+  float_attrs_.reserve(num_double_attrs);
+
   for (const auto& attr : node_->attributes()) {
     if (attr.name.empty()) {
       attr_ptrs_.emplace_back(std::visit(
-          [](auto&& arg) -> void* {
+          [this](auto&& arg) -> void* {
             using T = std::decay_t<decltype(arg)>;
             if constexpr (std::is_same_v<T, None>) {
               return nullptr;
+            } else if constexpr (std::is_same_v<T, double>) {
+              // Triton always uses fp32 for floats. See create_specialize_impl
+              // in jit.py. However, due to the Thrift schema, floats are
+              // serialized as doubles here. But, Triton kernels read them as
+              // floats. So, we need to downcast double to float here.
+              float_attrs_.push_back(static_cast<float>(arg));
+              return static_cast<void*>(&float_attrs_.back());
             }
             return static_cast<void*>(const_cast<T*>(&arg));
           },
```
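
The two-pass structure above exists because `attr_ptrs_` stores raw pointers into `float_attrs_`: if a later `push_back` triggered a reallocation, every previously stored pointer would dangle. A minimal standalone illustration of the hazard (not NativeRT code):
```
#include <cstdio>
#include <vector>

int main() {
  // Unsafe pattern: push_back may reallocate, invalidating earlier pointers.
  std::vector<float> unsafe;
  std::vector<float*> ptrs;
  for (int i = 0; i < 100; ++i) {
    unsafe.push_back(static_cast<float>(i));
    ptrs.push_back(&unsafe.back());  // ptrs[0] may dangle after a realloc
  }

  // Safe pattern (what the constructor does): reserve the final size first,
  // so no reallocation happens and all stored pointers stay valid.
  std::vector<float> safe;
  safe.reserve(100);
  float* first = nullptr;
  for (int i = 0; i < 100; ++i) {
    safe.push_back(static_cast<float>(i));
    if (i == 0) first = &safe.back();
  }
  std::printf("%g\n", *first);  // valid: prints 0
  return 0;
}
```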


```
@@ -24,6 +24,8 @@ class TritonKernel : public OpKernel {
   // unnamed node attributes will be passed as arguments to the kernel
   std::vector<void*> attr_ptrs_;
+  // Storage for float attributes that were serialized as doubles
+  std::vector<float> float_attrs_;
   std::vector<int64_t> output_indices_;
   LaunchParams launch_params_;
 };
```