pytorch/test/dynamo/test_torchrec.py
Edward Yang 88600e7d2e [RELAND] Force synced KJT to trace unbacked SymInt (#108960) (#109216)
Summary:

The basic concept behind this diff is to modify Dynamo's tracing behavior when it encounters a KeyedJaggedTensor (KJT) that is synced (i.e., has `_length_per_key` and `_offset_per_key` populated). These fields are lists of integers. Ordinarily, Dynamo optimistically tries to specialize on integers; however, for KJTs, we know these integers will definitely vary from run to run. Furthermore, Dynamo would ordinarily also specialize integers that happen to be 0/1, but features in KJTs are frequently expected to have lengths of 0 or 1.
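
For concreteness, here is what those fields look like on a small synced KJT (mirroring the `test_simple` case in the file below); this sketch assumes torchrec is installed, and the printed values are illustrative:

```
import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# Two keys, six bags; several bags are empty or singleton, so the per-key
# totals are frequently 0 or 1 and change from batch to batch.
kjt = KeyedJaggedTensor(
    values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
    keys=["index_0", "index_1"],
    lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
).sync()  # sync() populates _length_per_key / _offset_per_key

print(kjt.length_per_key())  # [1, 5] -- plain Python ints
print(kjt.offset_per_key())  # e.g. [0, 1, 6]
```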

The fix is to detect KJTs and treat these integers as *unbacked integers*. This is NOT a universally sound optimization: when we treat these integers as unbacked, we never report them as equal to zero or one. In return, we always generate graphs that generalize regardless of the length of each feature's values. This is enough to trace through the APS sparse arch, torchrec_dlrm, and some small split-cat examples.
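
To make the trade-off concrete, here is a minimal, hypothetical sketch (not from this diff) of how the default 0/1 specialization plays out for a plain Python int input; the expected frame counts are illustrative and depend on Dynamo's configuration:

```
import torch
from torch._dynamo.testing import CompileCounter

counter = CompileCounter()


# Hypothetical toy function standing in for code that consumes a per-key length.
@torch.compile(backend=counter, dynamic=True)
def pad_to(n: int):
    return torch.zeros(n + 1)


pad_to(0)  # 0 is one of the values Dynamo specializes on by default
pad_to(3)  # a generic value can instead be traced as a dynamic SymInt
# Under default settings we would expect two compiled frames here; treating the
# integer as unbacked instead aims for one general graph that serves both calls.
print(counter.frame_count)
```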

The special integer behavior is controlled by a dynamically scoped `force_unspec_int_unbacked_size_like` variable on TracingContext, which we set when we wrap a KJT. There are probably other ways to do this, but this was simple and it worked.
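
The user-visible switch in the test below is the Dynamo config flag `force_unspec_int_unbacked_size_like_on_torchrec_kjt`; here is a minimal sketch of scoping it with `torch._dynamo.config.patch` (the helper name and call pattern are illustrative, not part of this diff):

```
import torch
import torch._dynamo.config


def compile_with_kjt_unbacked(sparse_arch: torch.nn.Module, synced_kjt):
    # Hypothetical helper: scope the flag around compilation so that the
    # lengths/offsets of a synced KJT are wrapped as unbacked SymInts.
    with torch._dynamo.config.patch(
        force_unspec_int_unbacked_size_like_on_torchrec_kjt=True
    ):
        compiled = torch.compile(sparse_arch, fullgraph=True, dynamic=True)
        return compiled(synced_kjt)
```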

Test Plan:
```
buck2 test mode/dev-nosan //pytorch/benchmark/fb/test_gpu:run_test_gpu
```

From aakhundov:

1. first build feed_lower_benchmark:
```
buck2 build --show-output mode/opt -c python.package_style=inplace -c fbcode.enable_gpu_sections=true -c fbcode.platform=platform010 -c fbcode.split-dwarf=true hpc/new/models/feed/benchmark:feed_lower_benchmark
```
2. then run the lowering of the model with it:
```
TORCHINDUCTOR_MAX_AUTOTUNE=1 TORCHINDUCTOR_UNIQUE_KERNEL_NAMES=1 TORCH_LOGS="output_code,graph_code" TORCH_COMPILE_DEBUG=1 ../buck-out/v2/gen/fbcode/79c6b019ee0f9469/hpc/new/models/feed/benchmark/__feed_lower_benchmark__/feed_lower_benchmark.par --load=manifold://ig_inference_model/tree/user/facebook/fblearner/predictor/960999465/60/gpu_lowering/input.predictor --skip-trt --skip-ait --sync-mode=0 --enable-aot-inductor --lower-presets="ig_stories" --gpu-trace
```
cf https://docs.google.com/document/d/1yD30xYrdmM8r2HTdmXnZTg0-MHVexfVrAa0294m1AUE/edit?pli=1#heading=h.qiv3fp7e6zg0

From torchrec: https://www.internalfb.com/intern/wiki/Torchrec/Development/Testing_production_models/

From ge0405:
baseline (without your diff): f477293168
your diff: f477292363

```
buck2 test //caffe2/test/dynamo:test_dynamo_torchrec
buck2 run 'fbcode//mode/opt' fbcode//pytorch/benchmark/fb/test_gpu:run_test_gpu -- 'pytorch.benchmark.fb.test_gpu.test_gpu.TestBenchmarkFbGpu.test_train_blue_reels_vdd_v3_inductor_speedup'
```

Differential Revision: D49236757

Pull Request resolved: https://github.com/pytorch/pytorch/pull/109216
Approved by: https://github.com/voznesenskym
2023-09-18 14:39:44 +00:00


# Owner(s): ["module: dynamo"]
import sys
import unittest
from typing import Dict, List

import torch
import torch._dynamo.config
import torch._dynamo.test_case
from torch import nn
from torch._dynamo.test_case import TestCase
from torch._dynamo.testing import CompileCounter
from torch.testing._internal.common_utils import NoTest

try:
    from torchrec.datasets.random import RandomRecDataset
    from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor

    HAS_TORCHREC = True
except ImportError:
    HAS_TORCHREC = False


@torch._dynamo.config.patch(force_unspec_int_unbacked_size_like_on_torchrec_kjt=True)
class BucketizeMod(torch.nn.Module):
    def __init__(self, feature_boundaries: Dict[str, List[float]]):
        super().__init__()
        self.bucket_w = torch.nn.ParameterDict()
        self.boundaries_dict = {}
        for key, boundaries in feature_boundaries.items():
            self.bucket_w[key] = torch.nn.Parameter(
                torch.empty([len(boundaries) + 1]).fill_(1.0),
                requires_grad=True,
            )
            buf = torch.tensor(boundaries, requires_grad=False)
            self.register_buffer(
                f"{key}_boundaries",
                buf,
                persistent=False,
            )
            self.boundaries_dict[key] = buf

    def forward(self, features: "KeyedJaggedTensor") -> "KeyedJaggedTensor":
        weights_list = []
        for key, boundaries in self.boundaries_dict.items():
            jt = features[key]
            bucketized = torch.bucketize(jt.weights(), boundaries)
            # doesn't super matter I guess
            # hashed = torch.ops.fb.index_hash(bucketized, seed=0, modulo=len(boundaries))
            hashed = bucketized
            weights = torch.gather(self.bucket_w[key], dim=0, index=hashed)
            weights_list.append(weights)
        return KeyedJaggedTensor(
            keys=features.keys(),
            values=features.values(),
            weights=torch.cat(weights_list),
            lengths=features.lengths(),
            offsets=features.offsets(),
            stride=features.stride(),
            length_per_key=features.length_per_key(),
        )


if not HAS_TORCHREC:
    print("torchrec not available, skipping tests", file=sys.stderr)
    TestCase = NoTest  # noqa: F811


@unittest.skipIf(not HAS_TORCHREC, "these tests require torchrec")
class TorchRecTests(TestCase):
    def test_pooled(self):
        tables = [
            (nn.EmbeddingBag(2000, 8), ["a0", "b0"]),
            (nn.EmbeddingBag(2000, 8), ["a1", "b1"]),
            (nn.EmbeddingBag(2000, 8), ["b2"]),
        ]
        embedding_groups = {
            "a": ["a0", "a1"],
            "b": ["b0", "b1", "b2"],
        }

        counter = CompileCounter()

        @torch.compile(backend=counter, fullgraph=True, dynamic=True)
        def f(id_list_features: KeyedJaggedTensor):
            id_list_jt_dict: Dict[str, JaggedTensor] = id_list_features.to_dict()
            pooled_embeddings = {}
            # TODO: run feature processor
            for emb_module, feature_names in tables:
                features_dict = id_list_jt_dict
                for feature_name in feature_names:
                    f = features_dict[feature_name]
                    pooled_embeddings[feature_name] = emb_module(
                        f.values(), f.offsets()
                    )

            pooled_embeddings_by_group = {}
            for group_name, group_embedding_names in embedding_groups.items():
                group_embeddings = [
                    pooled_embeddings[name] for name in group_embedding_names
                ]
                pooled_embeddings_by_group[group_name] = torch.cat(
                    group_embeddings, dim=1
                )

            return pooled_embeddings_by_group

        dataset = RandomRecDataset(
            keys=["a0", "a1", "b0", "b1", "b2"],
            batch_size=4,
            hash_size=2000,
            ids_per_feature=3,
            num_dense=0,
        )
        di = iter(dataset)

        # unsync should work
        d1 = next(di).sparse_features.unsync()
        d2 = next(di).sparse_features.unsync()
        d3 = next(di).sparse_features.unsync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)
        counter.frame_count = 0

        # sync should work too
        d1 = next(di).sparse_features.sync()
        d2 = next(di).sparse_features.sync()
        d3 = next(di).sparse_features.sync()

        r1 = f(d1)
        r2 = f(d2)
        r3 = f(d3)

        self.assertEqual(counter.frame_count, 1)

        # export only works with unsync
        gm = torch._dynamo.export(f)(next(di).sparse_features.unsync()).graph_module
        gm.print_readable()

        self.assertEqual(gm(d1), r1)
        self.assertEqual(gm(d2), r2)
        self.assertEqual(gm(d3), r3)

    def test_bucketize(self):
        mod = BucketizeMod({"f1": [0.0, 0.5, 1.0]})
        features = KeyedJaggedTensor.from_lengths_sync(
            keys=["f1"],
            values=torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]),
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
            weights=torch.tensor([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7]),
        ).unsync()

        def f(x):
            # This is a trick to populate the computed cache and instruct
            # ShapeEnv that they're all sizey
            x.to_dict()
            return mod(x)

        torch._dynamo.export(f, aten_graph=True)(features).graph_module.print_readable()

    @unittest.expectedFailure
    def test_simple(self):
        jag_tensor1 = KeyedJaggedTensor(
            values=torch.tensor([3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([0, 0, 1, 1, 1, 3]),
        ).sync()

        # ordinarily, this would trigger one specialization
        self.assertEqual(jag_tensor1.length_per_key(), [1, 5])

        counter = CompileCounter()

        @torch._dynamo.optimize(counter, nopython=True)
        def f(jag_tensor):
            # The indexing here requires more symbolic reasoning
            # and doesn't work right now
            return jag_tensor["index_0"].values().sum()

        f(jag_tensor1)
        self.assertEqual(counter.frame_count, 1)

        jag_tensor2 = KeyedJaggedTensor(
            values=torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
            keys=["index_0", "index_1"],
            lengths=torch.tensor([2, 0, 1, 1, 1, 3]),
        ).sync()

        f(jag_tensor2)
        self.assertEqual(counter.frame_count, 1)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()