mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
kjt pytree registration (#161114)
Differential Revision: D80656182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161114 Approved by: https://github.com/henryoier
This commit is contained in:
parent
49d30f9a23
commit
783985e9fe
|
|
@ -4,6 +4,7 @@
|
|||
#include <fmt/format.h>
|
||||
|
||||
#include <c10/util/Enumerate.h>
|
||||
#include <torch/custom_class.h>
|
||||
#include <torch/nativert/detail/ITree.h>
|
||||
|
||||
namespace torch::nativert::detail {
|
||||
|
|
@ -1147,4 +1148,200 @@ TEST(ITreeTest, ToAtenType) {
|
|||
c10::TypeKind::AnyType);
|
||||
}
|
||||
|
||||
TEST(ITreeTest, KeyedJaggedTensorUnflatten) {
|
||||
// Test KeyedJaggedTensor pytree node registration
|
||||
// KeyedJaggedTensor has 6 tensor fields: _values, _weights, _lengths,
|
||||
// _offsets, _stride_per_key_per_rank, _inverse_indices
|
||||
auto jsonSpec = R"(
|
||||
[
|
||||
1,
|
||||
{
|
||||
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||
"context": "[\"key1\", \"key2\"]",
|
||||
"children_spec": [
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)";
|
||||
|
||||
auto [graph, valuePtrs] = makeValues(6);
|
||||
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||
|
||||
// Create mock tensor values for the 6 fields
|
||||
std::vector<c10::IValue> flats = {
|
||||
c10::IValue(1), // _values
|
||||
c10::IValue(2), // _weights
|
||||
c10::IValue(3), // _lengths
|
||||
c10::IValue(4), // _offsets
|
||||
c10::IValue(5), // _stride_per_key_per_rank
|
||||
c10::IValue(6), // _inverse_indices tensor part
|
||||
};
|
||||
|
||||
// Test unflatten - this will create a generic tuple since we don't have
|
||||
// the actual KeyedJaggedTensor constructor available in tests
|
||||
auto itree = itreeUnflatten(flats, spec);
|
||||
EXPECT_TRUE(itree.isTuple());
|
||||
EXPECT_EQ(itree.toTupleRef().elements().size(), 6);
|
||||
|
||||
// Verify the values match what we put in
|
||||
for (size_t i = 0; i < 6; i++) {
|
||||
EXPECT_EQ(itree.toTupleRef().elements()[i], flats[i]);
|
||||
}
|
||||
|
||||
// Verify spec has correct number of children and structure
|
||||
EXPECT_EQ(spec.children().size(), 6);
|
||||
EXPECT_EQ(spec.numIValues(), 6);
|
||||
EXPECT_FALSE(spec.isIValue());
|
||||
EXPECT_EQ(
|
||||
spec.uniformName(), "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
|
||||
}
|
||||
|
||||
TEST(ITreeTest, KeyedJaggedTensorNodeRegistration) {
|
||||
// Test that KeyedJaggedTensor pytree node is properly registered
|
||||
|
||||
// Verify the KeyedJaggedTensor node is in the registry by attempting
|
||||
// to load a spec that references it
|
||||
auto jsonSpec = R"(
|
||||
[
|
||||
1,
|
||||
{
|
||||
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||
"context": "[\"key1\", \"key2\"]",
|
||||
"children_spec": [
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)";
|
||||
|
||||
auto [graph, valuePtrs] = makeValues(6);
|
||||
|
||||
// This should not throw - if KeyedJaggedTensor wasn't registered,
|
||||
// we'd get an exception about "Unknown pytree node type"
|
||||
EXPECT_NO_THROW({
|
||||
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||
|
||||
// Verify the spec loaded correctly
|
||||
EXPECT_FALSE(spec.isIValue());
|
||||
EXPECT_EQ(
|
||||
spec.uniformName(), "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
|
||||
EXPECT_EQ(spec.children().size(), 6);
|
||||
EXPECT_EQ(spec.numIValues(), 6);
|
||||
|
||||
// Verify context is parsed correctly
|
||||
EXPECT_FALSE(spec.context().is_null());
|
||||
EXPECT_TRUE(spec.context().is_array());
|
||||
EXPECT_EQ(spec.context().size(), 2);
|
||||
});
|
||||
}
|
||||
|
||||
TEST(ITreeTest, JaggedTensorNodeRegistration) {
|
||||
// Test that JaggedTensor pytree node is also properly registered
|
||||
|
||||
auto jsonSpec = R"(
|
||||
[
|
||||
1,
|
||||
{
|
||||
"type": "torchrec.sparse.jagged_tensor.JaggedTensor",
|
||||
"context": "null",
|
||||
"children_spec": [
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
},
|
||||
{
|
||||
"type": null,
|
||||
"context": null,
|
||||
"children_spec": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
)";
|
||||
|
||||
auto [graph, valuePtrs] = makeValues(4);
|
||||
|
||||
// This should not throw - if JaggedTensor wasn't registered,
|
||||
// we'd get an exception about "Unknown pytree node type"
|
||||
EXPECT_NO_THROW({
|
||||
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||
|
||||
// Verify the spec loaded correctly
|
||||
EXPECT_FALSE(spec.isIValue());
|
||||
EXPECT_EQ(spec.uniformName(), "torchrec.sparse.jagged_tensor.JaggedTensor");
|
||||
EXPECT_EQ(spec.children().size(), 4);
|
||||
EXPECT_EQ(spec.numIValues(), 4);
|
||||
});
|
||||
}
|
||||
|
||||
} // namespace torch::nativert::detail
|
||||
|
|
|
|||
|
|
@ -172,6 +172,148 @@ class PytreeNodeRegistry {
|
|||
registerNode(
|
||||
"torch.fx.immutable_collections.immutable_dict",
|
||||
getNodeDef("builtins.dict"));
|
||||
// Register JaggedTensor pytree node
|
||||
registerNode(
|
||||
"torchrec.sparse.jagged_tensor.JaggedTensor",
|
||||
NodeDef{
|
||||
[](const c10::IValue& nested,
|
||||
const ITreeSpec& spec,
|
||||
std::vector<c10::IValue>& ivalues) {
|
||||
// JaggedTensor has 4 fields: _values, _weights, _lengths,
|
||||
// _offsets All fields are optional torch.Tensor except _values
|
||||
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
||||
const auto& obj = nested.toObjectRef();
|
||||
|
||||
// Extract the tensor fields in order: _values, _weights,
|
||||
// _lengths, _offsets
|
||||
TORCH_CHECK(
|
||||
spec.children().size() == 4,
|
||||
"JaggedTensor should have 4 children");
|
||||
|
||||
// Flatten each tensor field
|
||||
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
||||
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
||||
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
||||
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
||||
},
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
// Reconstruct JaggedTensor from flattened tensors
|
||||
// This is a simplified reconstruction - in practice would need
|
||||
// to call the actual JaggedTensor constructor
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
||||
TORCH_CHECK(
|
||||
flats.size() == 4, "JaggedTensor expects 4 tensor fields");
|
||||
|
||||
// Return a generic tuple for now - actual implementation would
|
||||
// need to construct the JaggedTensor custom class
|
||||
return c10::ivalue::Tuple::create(std::move(flats));
|
||||
},
|
||||
[](ITreeMapNoReturnFn fn,
|
||||
const c10::IValue& nested,
|
||||
const ITreeSpec& spec) {
|
||||
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
||||
const auto& obj = nested.toObjectRef();
|
||||
|
||||
TORCH_CHECK(
|
||||
spec.children().size() == 4,
|
||||
"JaggedTensor should have 4 children");
|
||||
|
||||
// Apply function to each tensor field
|
||||
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
||||
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
||||
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
||||
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
||||
}});
|
||||
|
||||
// Register KeyedJaggedTensor pytree node
|
||||
registerNode(
|
||||
"torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||
NodeDef{
|
||||
[](const c10::IValue& nested,
|
||||
const ITreeSpec& spec,
|
||||
std::vector<c10::IValue>& ivalues) {
|
||||
// KeyedJaggedTensor has 6 tensor fields plus keys context
|
||||
// Fields: _values, _weights, _lengths, _offsets,
|
||||
// _stride_per_key_per_rank, _inverse_indices tensor
|
||||
TORCH_CHECK(
|
||||
nested.isObject(), "Expected KeyedJaggedTensor object");
|
||||
const auto& obj = nested.toObjectRef();
|
||||
|
||||
// Extract the tensor fields in order
|
||||
TORCH_CHECK(
|
||||
spec.children().size() == 6,
|
||||
"KeyedJaggedTensor should have 6 children");
|
||||
|
||||
// Flatten each tensor field
|
||||
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
||||
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
||||
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
||||
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
||||
itreeFlatten(
|
||||
obj.getAttr("_stride_per_key_per_rank"),
|
||||
spec.children(4),
|
||||
ivalues);
|
||||
// For _inverse_indices, we need to extract the tensor part
|
||||
// (second element of tuple)
|
||||
auto inverse_indices = obj.getAttr("_inverse_indices");
|
||||
if (!inverse_indices.isNone()) {
|
||||
auto tuple = inverse_indices.toTuple();
|
||||
itreeFlatten(tuple->elements()[1], spec.children(5), ivalues);
|
||||
} else {
|
||||
// Handle None case by adding a null tensor
|
||||
itreeFlatten(c10::IValue(), spec.children(5), ivalues);
|
||||
}
|
||||
},
|
||||
[](std::vector<c10::IValue> flats,
|
||||
const nlohmann::json& obj) -> c10::IValue {
|
||||
// Reconstruct KeyedJaggedTensor from flattened tensors and keys
|
||||
// context
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
|
||||
TORCH_CHECK(
|
||||
flats.size() == 6,
|
||||
"KeyedJaggedTensor expects 6 tensor fields");
|
||||
|
||||
// The context should contain the keys list
|
||||
// Return a generic tuple for now - actual implementation would
|
||||
// need to construct the KeyedJaggedTensor custom class
|
||||
return c10::ivalue::Tuple::create(std::move(flats));
|
||||
},
|
||||
[](ITreeMapNoReturnFn fn,
|
||||
const c10::IValue& nested,
|
||||
const ITreeSpec& spec) {
|
||||
TORCH_CHECK(
|
||||
nested.isObject(), "Expected KeyedJaggedTensor object");
|
||||
const auto& obj = nested.toObjectRef();
|
||||
|
||||
TORCH_CHECK(
|
||||
spec.children().size() == 6,
|
||||
"KeyedJaggedTensor should have 6 children");
|
||||
|
||||
// Apply function to each tensor field
|
||||
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
||||
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
||||
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
||||
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
||||
ivalueApply(
|
||||
fn,
|
||||
obj.getAttr("_stride_per_key_per_rank"),
|
||||
spec.children(4));
|
||||
// For _inverse_indices, we need to apply to the tensor part
|
||||
// (second element of tuple)
|
||||
auto inverse_indices = obj.getAttr("_inverse_indices");
|
||||
if (!inverse_indices.isNone()) {
|
||||
auto tuple = inverse_indices.toTuple();
|
||||
ivalueApply(fn, tuple->elements()[1], spec.children(5));
|
||||
} else {
|
||||
// Handle None case
|
||||
ivalueApply(fn, c10::IValue(), spec.children(5));
|
||||
}
|
||||
},
|
||||
[](std::string_view context) {
|
||||
// Context contains the keys list as JSON
|
||||
return nlohmann::json::parse(context);
|
||||
}});
|
||||
}
|
||||
bool hasNodeDef(std::string_view typeName) const {
|
||||
return registry_.find(std::string{typeName}) != registry_.end();
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user