kjt pytree registration (#161114)

Differential Revision: D80656182

Pull Request resolved: https://github.com/pytorch/pytorch/pull/161114
Approved by: https://github.com/henryoier
This commit is contained in:
Georgia Phillips 2025-09-13 03:57:40 +00:00 committed by PyTorch MergeBot
parent 49d30f9a23
commit 783985e9fe
2 changed files with 339 additions and 0 deletions

View File

@ -4,6 +4,7 @@
#include <fmt/format.h> #include <fmt/format.h>
#include <c10/util/Enumerate.h> #include <c10/util/Enumerate.h>
#include <torch/custom_class.h>
#include <torch/nativert/detail/ITree.h> #include <torch/nativert/detail/ITree.h>
namespace torch::nativert::detail { namespace torch::nativert::detail {
@ -1147,4 +1148,200 @@ TEST(ITreeTest, ToAtenType) {
c10::TypeKind::AnyType); c10::TypeKind::AnyType);
} }
TEST(ITreeTest, KeyedJaggedTensorUnflatten) {
  // Loads a spec whose root is the KeyedJaggedTensor pytree node and pushes
  // six flat values through itreeUnflatten. The root declares six leaf
  // children, one per field: _values, _weights, _lengths, _offsets,
  // _stride_per_key_per_rank, _inverse_indices. The context is a JSON-encoded
  // list of two keys.
  auto kjtSpecJson = R"(
[
1,
{
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
"context": "[\"key1\", \"key2\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
)";

  // `graph` is not referenced again; it is bound here alongside `valuePtrs`
  // exactly as makeValues returns it.
  auto [graph, valuePtrs] = makeValues(6);
  const auto kjtSpec = itreeSpecLoads(kjtSpecJson, valuePtrs);

  // Integer placeholders stand in for the six tensor fields.
  std::vector<c10::IValue> flats = {
      c10::IValue(1), // _values
      c10::IValue(2), // _weights
      c10::IValue(3), // _lengths
      c10::IValue(4), // _offsets
      c10::IValue(5), // _stride_per_key_per_rank
      c10::IValue(6), // tensor part of _inverse_indices
  };

  // The registered unflatten function does not build a real
  // KeyedJaggedTensor; the result is expected to be a plain tuple holding the
  // flat values in their original order.
  auto unflattened = itreeUnflatten(flats, kjtSpec);
  EXPECT_TRUE(unflattened.isTuple());
  const auto& tupleElems = unflattened.toTupleRef().elements();
  EXPECT_EQ(tupleElems.size(), 6);
  for (size_t idx = 0; idx < 6; ++idx) {
    EXPECT_EQ(tupleElems[idx], flats[idx]);
  }

  // Shape of the loaded spec: a non-leaf node with six children / six leaf
  // values, carrying the KeyedJaggedTensor type name.
  EXPECT_EQ(kjtSpec.children().size(), 6);
  EXPECT_EQ(kjtSpec.numIValues(), 6);
  EXPECT_FALSE(kjtSpec.isIValue());
  EXPECT_EQ(
      kjtSpec.uniformName(),
      "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
}
TEST(ITreeTest, KeyedJaggedTensorNodeRegistration) {
  // Registration smoke test: a spec that names the KeyedJaggedTensor node
  // type must load without throwing, and the loaded spec must expose the
  // expected type name, child count and parsed context.
  auto kjtSpecJson = R"(
[
1,
{
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
"context": "[\"key1\", \"key2\"]",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
)";

  auto [graph, valuePtrs] = makeValues(6);

  // Loading is the operation under test. Per the original author's note, an
  // unregistered node type would surface here as an exception about an
  // "Unknown pytree node type".
  EXPECT_NO_THROW({
    const auto loadedSpec = itreeSpecLoads(kjtSpecJson, valuePtrs);

    // Basic structure of the loaded spec.
    EXPECT_FALSE(loadedSpec.isIValue());
    EXPECT_EQ(
        loadedSpec.uniformName(),
        "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
    EXPECT_EQ(loadedSpec.children().size(), 6);
    EXPECT_EQ(loadedSpec.numIValues(), 6);

    // The context string should have been parsed into a two-element JSON
    // array (the two keys in the spec above).
    EXPECT_FALSE(loadedSpec.context().is_null());
    EXPECT_TRUE(loadedSpec.context().is_array());
    EXPECT_EQ(loadedSpec.context().size(), 2);
  });
}
TEST(ITreeTest, JaggedTensorNodeRegistration) {
  // Registration smoke test for the JaggedTensor node type: a spec with four
  // leaf children and a "null" context must load without throwing.
  auto jtSpecJson = R"(
[
1,
{
"type": "torchrec.sparse.jagged_tensor.JaggedTensor",
"context": "null",
"children_spec": [
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
},
{
"type": null,
"context": null,
"children_spec": []
}
]
}
]
)";

  auto [graph, valuePtrs] = makeValues(4);

  // Per the original author's note, an unregistered node type would surface
  // here as an exception about an "Unknown pytree node type".
  EXPECT_NO_THROW({
    const auto loadedSpec = itreeSpecLoads(jtSpecJson, valuePtrs);

    // The loaded spec is a non-leaf node of the JaggedTensor type with four
    // children / four leaf values.
    EXPECT_FALSE(loadedSpec.isIValue());
    EXPECT_EQ(
        loadedSpec.uniformName(), "torchrec.sparse.jagged_tensor.JaggedTensor");
    EXPECT_EQ(loadedSpec.children().size(), 4);
    EXPECT_EQ(loadedSpec.numIValues(), 4);
  });
}
} // namespace torch::nativert::detail } // namespace torch::nativert::detail

View File

@ -172,6 +172,148 @@ class PytreeNodeRegistry {
registerNode( registerNode(
"torch.fx.immutable_collections.immutable_dict", "torch.fx.immutable_collections.immutable_dict",
getNodeDef("builtins.dict")); getNodeDef("builtins.dict"));
// Pytree node for torchrec's JaggedTensor. The node has four children, always
// visited in the order _values, _weights, _lengths, _offsets.
registerNode(
    "torchrec.sparse.jagged_tensor.JaggedTensor",
    NodeDef{
        // Flatten: read each of the four attributes off the object and
        // flatten it against the child spec at the same position, appending
        // to `ivalues`.
        [](const c10::IValue& nested,
           const ITreeSpec& spec,
           std::vector<c10::IValue>& ivalues) {
          TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
          const auto& jagged = nested.toObjectRef();
          TORCH_CHECK(
              spec.children().size() == 4,
              "JaggedTensor should have 4 children");
          constexpr size_t kNumFields = 4;
          const char* const kFieldNames[kNumFields] = {
              "_values", "_weights", "_lengths", "_offsets"};
          for (size_t fieldIdx = 0; fieldIdx < kNumFields; ++fieldIdx) {
            itreeFlatten(
                jagged.getAttr(kFieldNames[fieldIdx]),
                spec.children(fieldIdx),
                ivalues);
          }
        },
        // Unflatten: this does not construct a JaggedTensor. It validates
        // the number of flat values and hands them back as a generic tuple.
        // The context is expected to be null (checked in debug builds only).
        [](std::vector<c10::IValue> flats,
           const nlohmann::json& context) -> c10::IValue {
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(context.is_null());
          TORCH_CHECK(
              flats.size() == 4, "JaggedTensor expects 4 tensor fields");
          return c10::ivalue::Tuple::create(std::move(flats));
        },
        // Apply: same traversal as flatten, but invokes `fn` through
        // ivalueApply on each attribute instead of collecting values.
        [](ITreeMapNoReturnFn fn,
           const c10::IValue& nested,
           const ITreeSpec& spec) {
          TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
          const auto& jagged = nested.toObjectRef();
          TORCH_CHECK(
              spec.children().size() == 4,
              "JaggedTensor should have 4 children");
          constexpr size_t kNumFields = 4;
          const char* const kFieldNames[kNumFields] = {
              "_values", "_weights", "_lengths", "_offsets"};
          for (size_t fieldIdx = 0; fieldIdx < kNumFields; ++fieldIdx) {
            ivalueApply(
                fn,
                jagged.getAttr(kFieldNames[fieldIdx]),
                spec.children(fieldIdx));
          }
        }});
// Pytree node for torchrec's KeyedJaggedTensor. The node has six children,
// always visited in the order _values, _weights, _lengths, _offsets,
// _stride_per_key_per_rank, _inverse_indices. The node context is a JSON
// string that is parsed by the context loader at the bottom.
registerNode(
    "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
    NodeDef{
        // Flatten: append the flattened form of each field to `ivalues`.
        // Throws (via TORCH_CHECK) if `nested` is not an object, if the spec
        // does not have exactly six children, or if _inverse_indices is
        // neither None nor a tuple with at least two elements.
        [](const c10::IValue& nested,
           const ITreeSpec& spec,
           std::vector<c10::IValue>& ivalues) {
          TORCH_CHECK(
              nested.isObject(), "Expected KeyedJaggedTensor object");
          const auto& obj = nested.toObjectRef();
          TORCH_CHECK(
              spec.children().size() == 6,
              "KeyedJaggedTensor should have 6 children");
          // The first five fields are flattened as-is.
          itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
          itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
          itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
          itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
          itreeFlatten(
              obj.getAttr("_stride_per_key_per_rank"),
              spec.children(4),
              ivalues);
          // _inverse_indices is either None or a tuple whose second element
          // is the tensor part; only that element is flattened.
          const c10::IValue inverseIndices = obj.getAttr("_inverse_indices");
          if (inverseIndices.isNone()) {
            // None is represented by flattening a default (None) IValue.
            itreeFlatten(c10::IValue(), spec.children(5), ivalues);
          } else {
            TORCH_CHECK(
                inverseIndices.isTuple(),
                "KeyedJaggedTensor _inverse_indices should be a tuple or None");
            const auto& parts = inverseIndices.toTupleRef().elements();
            // Guard the [1] access below: indexing a shorter tuple would
            // read out of bounds.
            TORCH_CHECK(
                parts.size() >= 2,
                "KeyedJaggedTensor _inverse_indices should have at least 2 elements");
            itreeFlatten(parts[1], spec.children(5), ivalues);
          }
        },
        // Unflatten: this does not construct a KeyedJaggedTensor. It
        // validates the number of flat values and hands them back as a
        // generic tuple. The context (`obj`) is expected to be non-null
        // (checked in debug builds only) and is otherwise unused here.
        [](std::vector<c10::IValue> flats,
           const nlohmann::json& obj) -> c10::IValue {
          TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
          TORCH_CHECK(
              flats.size() == 6,
              "KeyedJaggedTensor expects 6 tensor fields");
          return c10::ivalue::Tuple::create(std::move(flats));
        },
        // Apply: same traversal and same validation as flatten, but invokes
        // `fn` through ivalueApply on each field instead of collecting
        // values.
        [](ITreeMapNoReturnFn fn,
           const c10::IValue& nested,
           const ITreeSpec& spec) {
          TORCH_CHECK(
              nested.isObject(), "Expected KeyedJaggedTensor object");
          const auto& obj = nested.toObjectRef();
          TORCH_CHECK(
              spec.children().size() == 6,
              "KeyedJaggedTensor should have 6 children");
          ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
          ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
          ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
          ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
          ivalueApply(
              fn,
              obj.getAttr("_stride_per_key_per_rank"),
              spec.children(4));
          // Only the second element of the _inverse_indices tuple is
          // visited; None is visited as a default (None) IValue.
          const c10::IValue inverseIndices = obj.getAttr("_inverse_indices");
          if (inverseIndices.isNone()) {
            ivalueApply(fn, c10::IValue(), spec.children(5));
          } else {
            TORCH_CHECK(
                inverseIndices.isTuple(),
                "KeyedJaggedTensor _inverse_indices should be a tuple or None");
            const auto& parts = inverseIndices.toTupleRef().elements();
            // Guard the [1] access below: indexing a shorter tuple would
            // read out of bounds.
            TORCH_CHECK(
                parts.size() >= 2,
                "KeyedJaggedTensor _inverse_indices should have at least 2 elements");
            ivalueApply(fn, parts[1], spec.children(5));
          }
        },
        // Context loader: the context string is JSON and is parsed as such.
        [](std::string_view context) {
          return nlohmann::json::parse(context);
        }});
} }
bool hasNodeDef(std::string_view typeName) const { bool hasNodeDef(std::string_view typeName) const {
return registry_.find(std::string{typeName}) != registry_.end(); return registry_.find(std::string{typeName}) != registry_.end();