Summary: The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers. Ordinarily, Dynamo will optimistically try to specialize on integers; for KJTs, however, we know that these integers will definitely vary from run to run. Furthermore, we would ordinarily also specialize these integers if they are 0/1, yet we frequently expect features in KJTs to be 0/1. The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when treating these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize no matter the length of values on features. This is enough to trace through the APS sparse arch, torchrec_dlrm and some small split-cat examples.

The special integer behavior is triggered by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we set when we wrap a KJT. There probably are other ways to do this, but this was simple and worked.

Test Plan:

```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```

From aakhundov:

1. First build feed_lower_benchmark:

```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```

2. Then run the lowering of the model with it:

```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```

cf. https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0

From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/

From ge0405:
baseline (without your diff): f477293168
your diff: f477292363

```
buck2 test //caffe2/test/dynamo:test_dynamo_torchrec
buck2 run 'fbcode//mode/opt' fbcode//pytorch/benchmark/fb/test_gpu:run_test_gpu -- 'pytorch.benchmark.fb.test_gpu.test_gpu.TestBenchmarkFbGpu.test_train_blue_reels_vdd_v3_inductor_speedup'
```

Differential Revision: D49236757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109216
Approved by: https://github.com/voznesenskym
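To make the mechanism above concrete, here is a minimal sketch (not part of the diff) of how the `force_unspec_int_unbacked_size_like_on_torchrec_kjt` config flag is meant to be exercised from user code, assuming torchrec is installed; the function name `pooled_sum` and the sample lengths are illustrative only. It mirrors what `test_pooled` in the file below verifies: with the flag on, the per-key lengths of a synced KeyedJaggedTensor are traced as unbacked integers, so a single compiled graph can serve batches with different (including 0/1) feature lengths.

```
# Illustrative sketch only; assumes torchrec is installed.
import torch
import torch._dynamo.config
from torch._dynamo.testing import CompileCounter
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

counter = CompileCounter()


@torch.compile(backend=counter, fullgraph=True, dynamic=True)
def pooled_sum(kjt: "KeyedJaggedTensor"):
    # to_dict() consumes length_per_key(), which is exactly where the
    # integer specialization described above would normally kick in.
    return {key: jt.values().sum() for key, jt in kjt.to_dict().items()}


with torch._dynamo.config.patch(
    force_unspec_int_unbacked_size_like_on_torchrec_kjt=True
):
    for lengths in ([2, 0, 1, 3], [1, 1, 4, 0]):
        kjt = KeyedJaggedTensor(
            keys=["f1", "f2"],
            values=torch.randn(sum(lengths)),
            lengths=torch.tensor(lengths),
        ).sync()
        pooled_sum(kjt)

# If the unbacked-int treatment generalizes as intended, both calls hit the
# same compiled frame (this is what test_pooled below asserts for synced KJTs).
print(counter.frame_count)
```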
# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch

import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]

        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work

        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too

        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync

        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)

        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)

        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()