mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[nativert] Downcast triton double arguments to floats (#166620)
This diff tries to fix a limitation in Sigmoid + Triton interaction, where float arguments are not correctly passed. NativeRT passes float arguments as double, while triton kernels were reading as a float, resulting in wrong values. --- ## Limitations in (de)seriazliation In triton, float arguments to a kernel are encoded as "fp32" ([code](https://github.com/triton-lang/triton-cpu/blob/main-merged/python/triton/runtime/jit.py#L310-L326)): ``` elif isinstance(arg, float): return ("fp32", None) ``` But it seems like that torch export serde uses double ([code](d2eff5d454/torch/_export/serde/export_schema.thrift (L149))) because Thrift only has the double type: ``` union Argument { 10: bool as_none; 20: TensorArgument as_tensor; 30: list<TensorArgument> as_tensors; 50: i64 as_int; 70: list<i64> as_ints; 80: double as_float; ===> actually double ... ``` `TritonKernel` constructor loads attributes from a node, where `Constant` represents the variant type. And it only has `double` ([code](d2eff5d454/torch/nativert/graph/Graph.h (L86))): ``` using Constant = std::variant< None, int64_t, std::vector<int64_t>, double, ===> triton float is loaded as double ``` So, NativeRT passes float arguments (originally in Triton) as double to triton kernels. But, all of the triton backends (nvidia, amd and cpu) are reading them as float because the signature still says `fp32`. D84423898 was the current workaround: wrapping float arguments with tensors. ## The Fix Fixing the thrift definition isn't viable because Thrift only supports double type. It's also possible to fix on the triton side: it can downcast from double to float. But I needed to fix all backends. Instead, I think this diff would be the most effective way: when building `TritonKernel`, have downcasted float values, right after loading double arguments. Test Plan: ``` buck test fbcode//mode/opt-amd-gpu fbcode//caffe2/test:test_export -- ``` Differential Revision: D85747160 Pull Request resolved: https://github.com/pytorch/pytorch/pull/166620 Approved by: https://github.com/XueningXu
This commit is contained in:
parent
267d0197bf
commit
85b035ca9c
|
|
@ -600,6 +600,8 @@ def forward(self, x):
|
|||
in_ptr1,
|
||||
out_ptr,
|
||||
n_elements,
|
||||
fval,
|
||||
ival,
|
||||
BLOCK_SIZE: "tl.constexpr",
|
||||
):
|
||||
pid = tl.program_id(axis=0)
|
||||
|
|
@ -608,7 +610,7 @@ def forward(self, x):
|
|||
mask = offsets < n_elements
|
||||
x = tl.load(in_ptr0 + offsets, mask=mask)
|
||||
y = tl.load(in_ptr1 + offsets, mask=mask)
|
||||
output = x + y
|
||||
output = x + y + fval + ival
|
||||
tl.store(out_ptr + offsets, output, mask=mask)
|
||||
|
||||
def custom_add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
|
@ -618,7 +620,9 @@ def forward(self, x):
|
|||
def grid(meta):
|
||||
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
|
||||
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16)
|
||||
wrap_triton(add_kernel)[grid](
|
||||
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
|
@ -633,7 +637,9 @@ def forward(self, x):
|
|||
def grid(meta):
|
||||
return (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
|
||||
|
||||
wrap_triton(add_kernel)[grid](x, y, output, n_elements, 16, num_warps=8)
|
||||
wrap_triton(add_kernel)[grid](
|
||||
x, y, output, n_elements, 3.14, 42, BLOCK_SIZE=16, num_warps=8
|
||||
)
|
||||
|
||||
return output
|
||||
|
||||
|
|
@ -661,35 +667,44 @@ def forward(self, x):
|
|||
self.assertIsNotNone(triton_node)
|
||||
|
||||
args = []
|
||||
kwargs = []
|
||||
kwargs = {}
|
||||
|
||||
for arg in triton_node.inputs:
|
||||
if arg.kind == ArgumentKind.POSITIONAL:
|
||||
args.append(arg.arg)
|
||||
elif arg.kind == ArgumentKind.KEYWORD:
|
||||
kwargs.append(arg.arg)
|
||||
kwargs[arg.name] = arg.arg
|
||||
|
||||
self.assertEqual(len(args), 4)
|
||||
self.assertEqual(len(kwargs), 5)
|
||||
self.assertEqual(len(args), 6)
|
||||
# Always: name, grid, output_indices and num_warps are
|
||||
# Triton version dependent: num_cpu_threads, shared_memory_bytes
|
||||
self.assertTrue(len(kwargs) >= 4)
|
||||
|
||||
for i in range(3):
|
||||
self.assertIsNotNone(args[i].as_tensor)
|
||||
|
||||
self.assertEqual(args[3].as_int, 3)
|
||||
kernel_name = kwargs[0].as_string
|
||||
self.assertAlmostEqual(args[4].as_float, 3.14, places=2)
|
||||
self.assertEqual(args[5].as_int, 42)
|
||||
kernel_name = kwargs["name"].as_string
|
||||
symbol_name = kernel_name.rpartition("_")[0]
|
||||
self.assertEqual(symbol_name, "add_kernel") # symbol name
|
||||
self.assertEqual(kwargs[1].as_ints, [1, 1, 1]) # grid
|
||||
self.assertEqual(kwargs[2].as_ints, [2]) # output indices
|
||||
self.assertEqual(symbol_name, "add_kernel")
|
||||
self.assertEqual(kwargs["grid"].as_ints, [1, 1, 1])
|
||||
self.assertEqual(kwargs["output_indices"].as_ints, [2])
|
||||
self.assertEqual(
|
||||
kwargs[3].as_int, 8 if isinstance(m, MyModelAutotune) else 4
|
||||
) # num warps
|
||||
self.assertEqual(kwargs[4].as_int, 0) # shared mem bytes
|
||||
kwargs["num_warps"].as_int, 8 if isinstance(m, MyModelAutotune) else 4
|
||||
)
|
||||
|
||||
if "num_cpu_threads" in kwargs:
|
||||
self.assertEqual(kwargs["num_cpu_threads"].as_int, 0)
|
||||
if "shared_memory_bytes" in kwargs:
|
||||
self.assertEqual(kwargs["shared_memory_bytes"].as_int, 0)
|
||||
|
||||
self.assertEqual(len(triton_node.outputs), 1)
|
||||
self.assertIsNotNone(triton_node.outputs[0].as_tensors)
|
||||
self.assertEqual(
|
||||
len(triton_node.outputs[0].as_tensors), len(kwargs[2].as_ints)
|
||||
len(triton_node.outputs[0].as_tensors),
|
||||
len(kwargs["output_indices"].as_ints),
|
||||
)
|
||||
self.assertEqual(triton_node.outputs[0].as_tensors[0].name, "getitem")
|
||||
|
||||
|
|
|
|||
|
|
@ -39,13 +39,30 @@ TritonKernel::TritonKernel(
|
|||
std::string kernel_name{};
|
||||
std::string symbol_name{};
|
||||
bool found_grid = false;
|
||||
|
||||
// To prevent vector reallocation and dangling pointers
|
||||
size_t num_double_attrs = 0;
|
||||
for (const auto& attr : node_->attributes()) {
|
||||
if (attr.name.empty() && std::holds_alternative<double>(attr.value)) {
|
||||
++num_double_attrs;
|
||||
}
|
||||
}
|
||||
float_attrs_.reserve(num_double_attrs);
|
||||
|
||||
for (const auto& attr : node_->attributes()) {
|
||||
if (attr.name.empty()) {
|
||||
attr_ptrs_.emplace_back(std::visit(
|
||||
[](auto&& arg) -> void* {
|
||||
[this](auto&& arg) -> void* {
|
||||
using T = std::decay_t<decltype(arg)>;
|
||||
if constexpr (std::is_same_v<T, None>) {
|
||||
return nullptr;
|
||||
} else if constexpr (std::is_same_v<T, double>) {
|
||||
// Triton always uses fp32 for floats. See create_specialize_impl
|
||||
// in jit.py. However, due to the Thrift schema, floats are
|
||||
// serialized as doubles here. But, Triton kernels read them as
|
||||
// floats. So, we need to downcast double to float here.
|
||||
float_attrs_.push_back(static_cast<float>(arg));
|
||||
return static_cast<void*>(&float_attrs_.back());
|
||||
}
|
||||
return static_cast<void*>(const_cast<T*>(&arg));
|
||||
},
|
||||
|
|
|
|||
|
|
@ -24,6 +24,8 @@ class TritonKernel : public OpKernel {
|
|||
|
||||
// unnamed node attributes will be passed as arguments to the kernel
|
||||
std::vector<void*> attr_ptrs_;
|
||||
// Storage for float attributes that were serialized as doubles
|
||||
std::vector<float> float_attrs_;
|
||||
std::vector<int64_t> output_indices_;
|
||||
LaunchParams launch_params_;
|
||||
};
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user