mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
kjt pytree registration (#161114)
Differential Revision: D80656182 Pull Request resolved: https://github.com/pytorch/pytorch/pull/161114 Approved by: https://github.com/henryoier
This commit is contained in:
parent
49d30f9a23
commit
783985e9fe
|
|
@ -4,6 +4,7 @@
|
||||||
#include <fmt/format.h>
|
#include <fmt/format.h>
|
||||||
|
|
||||||
#include <c10/util/Enumerate.h>
|
#include <c10/util/Enumerate.h>
|
||||||
|
#include <torch/custom_class.h>
|
||||||
#include <torch/nativert/detail/ITree.h>
|
#include <torch/nativert/detail/ITree.h>
|
||||||
|
|
||||||
namespace torch::nativert::detail {
|
namespace torch::nativert::detail {
|
||||||
|
|
@ -1147,4 +1148,200 @@ TEST(ITreeTest, ToAtenType) {
|
||||||
c10::TypeKind::AnyType);
|
c10::TypeKind::AnyType);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
TEST(ITreeTest, KeyedJaggedTensorUnflatten) {
|
||||||
|
// Test KeyedJaggedTensor pytree node registration
|
||||||
|
// KeyedJaggedTensor has 6 tensor fields: _values, _weights, _lengths,
|
||||||
|
// _offsets, _stride_per_key_per_rank, _inverse_indices
|
||||||
|
auto jsonSpec = R"(
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||||
|
"context": "[\"key1\", \"key2\"]",
|
||||||
|
"children_spec": [
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto [graph, valuePtrs] = makeValues(6);
|
||||||
|
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||||
|
|
||||||
|
// Create mock tensor values for the 6 fields
|
||||||
|
std::vector<c10::IValue> flats = {
|
||||||
|
c10::IValue(1), // _values
|
||||||
|
c10::IValue(2), // _weights
|
||||||
|
c10::IValue(3), // _lengths
|
||||||
|
c10::IValue(4), // _offsets
|
||||||
|
c10::IValue(5), // _stride_per_key_per_rank
|
||||||
|
c10::IValue(6), // _inverse_indices tensor part
|
||||||
|
};
|
||||||
|
|
||||||
|
// Test unflatten - this will create a generic tuple since we don't have
|
||||||
|
// the actual KeyedJaggedTensor constructor available in tests
|
||||||
|
auto itree = itreeUnflatten(flats, spec);
|
||||||
|
EXPECT_TRUE(itree.isTuple());
|
||||||
|
EXPECT_EQ(itree.toTupleRef().elements().size(), 6);
|
||||||
|
|
||||||
|
// Verify the values match what we put in
|
||||||
|
for (size_t i = 0; i < 6; i++) {
|
||||||
|
EXPECT_EQ(itree.toTupleRef().elements()[i], flats[i]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify spec has correct number of children and structure
|
||||||
|
EXPECT_EQ(spec.children().size(), 6);
|
||||||
|
EXPECT_EQ(spec.numIValues(), 6);
|
||||||
|
EXPECT_FALSE(spec.isIValue());
|
||||||
|
EXPECT_EQ(
|
||||||
|
spec.uniformName(), "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ITreeTest, KeyedJaggedTensorNodeRegistration) {
|
||||||
|
// Test that KeyedJaggedTensor pytree node is properly registered
|
||||||
|
|
||||||
|
// Verify the KeyedJaggedTensor node is in the registry by attempting
|
||||||
|
// to load a spec that references it
|
||||||
|
auto jsonSpec = R"(
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
"type": "torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||||
|
"context": "[\"key1\", \"key2\"]",
|
||||||
|
"children_spec": [
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto [graph, valuePtrs] = makeValues(6);
|
||||||
|
|
||||||
|
// This should not throw - if KeyedJaggedTensor wasn't registered,
|
||||||
|
// we'd get an exception about "Unknown pytree node type"
|
||||||
|
EXPECT_NO_THROW({
|
||||||
|
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||||
|
|
||||||
|
// Verify the spec loaded correctly
|
||||||
|
EXPECT_FALSE(spec.isIValue());
|
||||||
|
EXPECT_EQ(
|
||||||
|
spec.uniformName(), "torchrec.sparse.jagged_tensor.KeyedJaggedTensor");
|
||||||
|
EXPECT_EQ(spec.children().size(), 6);
|
||||||
|
EXPECT_EQ(spec.numIValues(), 6);
|
||||||
|
|
||||||
|
// Verify context is parsed correctly
|
||||||
|
EXPECT_FALSE(spec.context().is_null());
|
||||||
|
EXPECT_TRUE(spec.context().is_array());
|
||||||
|
EXPECT_EQ(spec.context().size(), 2);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(ITreeTest, JaggedTensorNodeRegistration) {
|
||||||
|
// Test that JaggedTensor pytree node is also properly registered
|
||||||
|
|
||||||
|
auto jsonSpec = R"(
|
||||||
|
[
|
||||||
|
1,
|
||||||
|
{
|
||||||
|
"type": "torchrec.sparse.jagged_tensor.JaggedTensor",
|
||||||
|
"context": "null",
|
||||||
|
"children_spec": [
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": null,
|
||||||
|
"context": null,
|
||||||
|
"children_spec": []
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
)";
|
||||||
|
|
||||||
|
auto [graph, valuePtrs] = makeValues(4);
|
||||||
|
|
||||||
|
// This should not throw - if JaggedTensor wasn't registered,
|
||||||
|
// we'd get an exception about "Unknown pytree node type"
|
||||||
|
EXPECT_NO_THROW({
|
||||||
|
const auto spec = itreeSpecLoads(jsonSpec, valuePtrs);
|
||||||
|
|
||||||
|
// Verify the spec loaded correctly
|
||||||
|
EXPECT_FALSE(spec.isIValue());
|
||||||
|
EXPECT_EQ(spec.uniformName(), "torchrec.sparse.jagged_tensor.JaggedTensor");
|
||||||
|
EXPECT_EQ(spec.children().size(), 4);
|
||||||
|
EXPECT_EQ(spec.numIValues(), 4);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
} // namespace torch::nativert::detail
|
} // namespace torch::nativert::detail
|
||||||
|
|
|
||||||
|
|
@ -172,6 +172,148 @@ class PytreeNodeRegistry {
|
||||||
registerNode(
|
registerNode(
|
||||||
"torch.fx.immutable_collections.immutable_dict",
|
"torch.fx.immutable_collections.immutable_dict",
|
||||||
getNodeDef("builtins.dict"));
|
getNodeDef("builtins.dict"));
|
||||||
|
// Register JaggedTensor pytree node
|
||||||
|
registerNode(
|
||||||
|
"torchrec.sparse.jagged_tensor.JaggedTensor",
|
||||||
|
NodeDef{
|
||||||
|
[](const c10::IValue& nested,
|
||||||
|
const ITreeSpec& spec,
|
||||||
|
std::vector<c10::IValue>& ivalues) {
|
||||||
|
// JaggedTensor has 4 fields: _values, _weights, _lengths,
|
||||||
|
// _offsets All fields are optional torch.Tensor except _values
|
||||||
|
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
||||||
|
const auto& obj = nested.toObjectRef();
|
||||||
|
|
||||||
|
// Extract the tensor fields in order: _values, _weights,
|
||||||
|
// _lengths, _offsets
|
||||||
|
TORCH_CHECK(
|
||||||
|
spec.children().size() == 4,
|
||||||
|
"JaggedTensor should have 4 children");
|
||||||
|
|
||||||
|
// Flatten each tensor field
|
||||||
|
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
||||||
|
},
|
||||||
|
[](std::vector<c10::IValue> flats,
|
||||||
|
const nlohmann::json& obj) -> c10::IValue {
|
||||||
|
// Reconstruct JaggedTensor from flattened tensors
|
||||||
|
// This is a simplified reconstruction - in practice would need
|
||||||
|
// to call the actual JaggedTensor constructor
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(obj.is_null());
|
||||||
|
TORCH_CHECK(
|
||||||
|
flats.size() == 4, "JaggedTensor expects 4 tensor fields");
|
||||||
|
|
||||||
|
// Return a generic tuple for now - actual implementation would
|
||||||
|
// need to construct the JaggedTensor custom class
|
||||||
|
return c10::ivalue::Tuple::create(std::move(flats));
|
||||||
|
},
|
||||||
|
[](ITreeMapNoReturnFn fn,
|
||||||
|
const c10::IValue& nested,
|
||||||
|
const ITreeSpec& spec) {
|
||||||
|
TORCH_CHECK(nested.isObject(), "Expected JaggedTensor object");
|
||||||
|
const auto& obj = nested.toObjectRef();
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
spec.children().size() == 4,
|
||||||
|
"JaggedTensor should have 4 children");
|
||||||
|
|
||||||
|
// Apply function to each tensor field
|
||||||
|
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
||||||
|
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
||||||
|
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
||||||
|
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
||||||
|
}});
|
||||||
|
|
||||||
|
// Register KeyedJaggedTensor pytree node
|
||||||
|
registerNode(
|
||||||
|
"torchrec.sparse.jagged_tensor.KeyedJaggedTensor",
|
||||||
|
NodeDef{
|
||||||
|
[](const c10::IValue& nested,
|
||||||
|
const ITreeSpec& spec,
|
||||||
|
std::vector<c10::IValue>& ivalues) {
|
||||||
|
// KeyedJaggedTensor has 6 tensor fields plus keys context
|
||||||
|
// Fields: _values, _weights, _lengths, _offsets,
|
||||||
|
// _stride_per_key_per_rank, _inverse_indices tensor
|
||||||
|
TORCH_CHECK(
|
||||||
|
nested.isObject(), "Expected KeyedJaggedTensor object");
|
||||||
|
const auto& obj = nested.toObjectRef();
|
||||||
|
|
||||||
|
// Extract the tensor fields in order
|
||||||
|
TORCH_CHECK(
|
||||||
|
spec.children().size() == 6,
|
||||||
|
"KeyedJaggedTensor should have 6 children");
|
||||||
|
|
||||||
|
// Flatten each tensor field
|
||||||
|
itreeFlatten(obj.getAttr("_values"), spec.children(0), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_weights"), spec.children(1), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_lengths"), spec.children(2), ivalues);
|
||||||
|
itreeFlatten(obj.getAttr("_offsets"), spec.children(3), ivalues);
|
||||||
|
itreeFlatten(
|
||||||
|
obj.getAttr("_stride_per_key_per_rank"),
|
||||||
|
spec.children(4),
|
||||||
|
ivalues);
|
||||||
|
// For _inverse_indices, we need to extract the tensor part
|
||||||
|
// (second element of tuple)
|
||||||
|
auto inverse_indices = obj.getAttr("_inverse_indices");
|
||||||
|
if (!inverse_indices.isNone()) {
|
||||||
|
auto tuple = inverse_indices.toTuple();
|
||||||
|
itreeFlatten(tuple->elements()[1], spec.children(5), ivalues);
|
||||||
|
} else {
|
||||||
|
// Handle None case by adding a null tensor
|
||||||
|
itreeFlatten(c10::IValue(), spec.children(5), ivalues);
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[](std::vector<c10::IValue> flats,
|
||||||
|
const nlohmann::json& obj) -> c10::IValue {
|
||||||
|
// Reconstruct KeyedJaggedTensor from flattened tensors and keys
|
||||||
|
// context
|
||||||
|
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(!obj.is_null());
|
||||||
|
TORCH_CHECK(
|
||||||
|
flats.size() == 6,
|
||||||
|
"KeyedJaggedTensor expects 6 tensor fields");
|
||||||
|
|
||||||
|
// The context should contain the keys list
|
||||||
|
// Return a generic tuple for now - actual implementation would
|
||||||
|
// need to construct the KeyedJaggedTensor custom class
|
||||||
|
return c10::ivalue::Tuple::create(std::move(flats));
|
||||||
|
},
|
||||||
|
[](ITreeMapNoReturnFn fn,
|
||||||
|
const c10::IValue& nested,
|
||||||
|
const ITreeSpec& spec) {
|
||||||
|
TORCH_CHECK(
|
||||||
|
nested.isObject(), "Expected KeyedJaggedTensor object");
|
||||||
|
const auto& obj = nested.toObjectRef();
|
||||||
|
|
||||||
|
TORCH_CHECK(
|
||||||
|
spec.children().size() == 6,
|
||||||
|
"KeyedJaggedTensor should have 6 children");
|
||||||
|
|
||||||
|
// Apply function to each tensor field
|
||||||
|
ivalueApply(fn, obj.getAttr("_values"), spec.children(0));
|
||||||
|
ivalueApply(fn, obj.getAttr("_weights"), spec.children(1));
|
||||||
|
ivalueApply(fn, obj.getAttr("_lengths"), spec.children(2));
|
||||||
|
ivalueApply(fn, obj.getAttr("_offsets"), spec.children(3));
|
||||||
|
ivalueApply(
|
||||||
|
fn,
|
||||||
|
obj.getAttr("_stride_per_key_per_rank"),
|
||||||
|
spec.children(4));
|
||||||
|
// For _inverse_indices, we need to apply to the tensor part
|
||||||
|
// (second element of tuple)
|
||||||
|
auto inverse_indices = obj.getAttr("_inverse_indices");
|
||||||
|
if (!inverse_indices.isNone()) {
|
||||||
|
auto tuple = inverse_indices.toTuple();
|
||||||
|
ivalueApply(fn, tuple->elements()[1], spec.children(5));
|
||||||
|
} else {
|
||||||
|
// Handle None case
|
||||||
|
ivalueApply(fn, c10::IValue(), spec.children(5));
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[](std::string_view context) {
|
||||||
|
// Context contains the keys list as JSON
|
||||||
|
return nlohmann::json::parse(context);
|
||||||
|
}});
|
||||||
}
|
}
|
||||||
bool hasNodeDef(std::string_view typeName) const {
|
bool hasNodeDef(std::string_view typeName) const {
|
||||||
return registry_.find(std::string{typeName}) != registry_.end();
|
return registry_.find(std::string{typeName}) != registry_.end();
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user