diff --git a/tools/stats/upload_stats_lib.py b/tools/stats/upload_stats_lib.py index 126d78233d9..dd48e78ab7a 100644 --- a/tools/stats/upload_stats_lib.py +++ b/tools/stats/upload_stats_lib.py @@ -23,6 +23,8 @@ S3_RESOURCE = boto3.resource("s3") # NB: In CI, a flaky test is usually retried 3 times, then the test file would be rerun # 2 more times MAX_RETRY_IN_NON_DISABLED_MODE = 3 * 3 +# NB: Rockset has an upper limit of 5000 documents in one request +BATCH_SIZE = 5000 def _get_request_headers() -> Dict[str, str]: @@ -116,17 +118,29 @@ def download_gha_artifacts( def upload_to_rockset( - collection: str, docs: List[Any], workspace: str = "commons" + collection: str, + docs: List[Any], + workspace: str = "commons", + client: Any = None, ) -> None: - print(f"Writing {len(docs)} documents to Rockset") - client = rockset.RocksetClient( - host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"] - ) - client.Documents.add_documents( - collection=collection, - data=docs, - workspace=workspace, - ) + if not client: + client = rockset.RocksetClient( + host="api.usw2a1.rockset.com", api_key=os.environ["ROCKSET_API_KEY"] + ) + + index = 0 + while index < len(docs): + from_index = index + to_index = min(from_index + BATCH_SIZE, len(docs)) + print(f"Writing {to_index - from_index} documents to Rockset") + + client.Documents.add_documents( + collection=collection, + data=docs[from_index:to_index], + workspace=workspace, + ) + index += BATCH_SIZE + print("Done!") diff --git a/tools/test/test_upload_stats_lib.py b/tools/test/test_upload_stats_lib.py index 979e62c769a..da7bc400c5e 100644 --- a/tools/test/test_upload_stats_lib.py +++ b/tools/test/test_upload_stats_lib.py @@ -4,7 +4,7 @@ import unittest from typing import Any, Dict from unittest import mock -from tools.stats.upload_stats_lib import emit_metric +from tools.stats.upload_stats_lib import BATCH_SIZE, emit_metric, upload_to_rockset # default values REPO = "some/repo" @@ -109,6 +109,38 @@ class TestUploadStats(unittest.TestCase): self.assertFalse(put_item_invoked) + def test_upload_to_rockset_batch_size(self) -> None: + cases = [ + { + "batch_size": BATCH_SIZE - 1, + "expected_number_of_requests": 1, + }, + { + "batch_size": BATCH_SIZE, + "expected_number_of_requests": 1, + }, + { + "batch_size": BATCH_SIZE + 1, + "expected_number_of_requests": 2, + }, + ] + + for case in cases: + mock_client = mock.Mock() + mock_client.Documents.add_documents.return_value = "OK" + + batch_size = case["batch_size"] + expected_number_of_requests = case["expected_number_of_requests"] + + docs = list(range(batch_size)) + upload_to_rockset( + collection="test", docs=docs, workspace="commons", client=mock_client + ) + self.assertEqual( + mock_client.Documents.add_documents.call_count, + expected_number_of_requests, + ) + if __name__ == "__main__": unittest.main()