[BE] Add sharding data by default to metrics (#110035)
Extend the metric library to allow setting global metrics at the process level, which will always be emitted. The current use case is to include shard information every time a metric is emitted by run_test.py.

### 🤖 Generated by Copilot at 0cae92c

> _`run_test` refactored_
> _Sharding metrics in Rockset_
> _Autumn of testing_

Pull Request resolved: https://github.com/pytorch/pytorch/pull/110035
Approved by: https://github.com/clee2000
parent d91492a7a4
commit 1277d0e834
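In short, a process registers key/value pairs once and they ride along with every subsequent metric emission. A minimal sketch of the usage pattern, assuming the module layout from this commit; the metric name and values are illustrative, borrowed from the new tests rather than real CI data:

```python
# Sketch of the usage pattern this commit enables; the metric name and
# values are illustrative, not real CI data.
from tools.stats.upload_metrics import add_global_metric, emit_metric

# Register process-level context once (run_test.py does this for shard
# info right after parsing its CLI options).
add_global_metric("shard", 1)
add_global_metric("num_shards", 4)

# Every later emission now carries shard/num_shards automatically; a key
# passed directly to emit_metric wins if the names collide.
emit_metric("metric_name", {"some_number": 123})
```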
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -13,7 +13,7 @@ import sys
 import tempfile
 import time
 from datetime import datetime
-from typing import Any, cast, Dict, List, NamedTuple, Optional, Union
+from typing import Any, cast, Dict, List, NamedTuple, Optional, Tuple, Union
 
 import pkg_resources
 
@@ -40,7 +40,7 @@ REPO_ROOT = pathlib.Path(__file__).resolve().parent.parent
 # using tools/ to optimize test run.
 sys.path.insert(0, str(REPO_ROOT))
 from tools.stats.export_test_times import TEST_TIMES_FILE
-from tools.stats.upload_metrics import emit_metric
+from tools.stats.upload_metrics import add_global_metric, emit_metric
 from tools.testing.target_determination.determinator import (
     AggregatedHeuristics,
     get_test_prioritizations,
@@ -1438,12 +1438,7 @@ def download_test_times(file: str = TEST_TIMES_FILE) -> Dict[str, float]:
     return test_times_file["default"]["default"]
 
 
-def do_sharding(
-    options,
-    selected_tests: List[str],
-    test_file_times: Dict[str, float],
-    sort_by_time: bool = True,
-) -> List[ShardedTest]:
+def get_sharding_opts(options) -> Tuple[int, int]:
     which_shard, num_shards = 1, 1
     if options.shard:
         assert len(options.shard) == 2, "Unexpected shard format"
@@ -1453,6 +1448,17 @@ def do_sharding(
         which_shard <= num_shards
     ), "Selected shard must be less than or equal to total number of shards"
 
+    return (which_shard, num_shards)
+
+
+def do_sharding(
+    options,
+    selected_tests: List[str],
+    test_file_times: Dict[str, float],
+    sort_by_time: bool = True,
+) -> List[ShardedTest]:
+    which_shard, num_shards = get_sharding_opts(options)
+
     # Do sharding
     shards = calculate_shards(
         num_shards,
@@ -1616,6 +1622,11 @@ def main():
 
     options = parse_args()
 
+    # Include sharding info in all metrics
+    which_shard, num_shards = get_sharding_opts(options)
+    add_global_metric("shard", which_shard)
+    add_global_metric("num_shards", num_shards)
+
     test_directory = str(REPO_ROOT / "test")
     selected_tests = get_selected_tests(options)
 
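For reference, the extracted helper can be exercised on its own. The sketch below is an assumption-laden reconstruction: the diff elides part of the function body, so the `which_shard, num_shards = options.shard` unpacking is inferred from the surrounding asserts, and an argparse `Namespace` stands in for the parsed CLI options:

```python
from argparse import Namespace
from typing import Tuple

def get_sharding_opts(options) -> Tuple[int, int]:
    # Default: a single shard running everything.
    which_shard, num_shards = 1, 1
    if options.shard:
        assert len(options.shard) == 2, "Unexpected shard format"
        # Assumed unpacking; this line falls in a part of the function
        # the diff does not show.
        which_shard, num_shards = options.shard
        assert (
            which_shard <= num_shards
        ), "Selected shard must be less than or equal to total number of shards"
    return (which_shard, num_shards)

print(get_sharding_opts(Namespace(shard=(2, 4))))  # (2, 4)
print(get_sharding_opts(Namespace(shard=None)))    # (1, 1)
```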
--- a/tools/stats/upload_metrics.py
+++ b/tools/stats/upload_metrics.py
@@ -59,6 +59,18 @@ class EnvVarMetric:
         return value
 
 
+global_metrics: Dict[str, Any] = {}
+
+
+def add_global_metric(metric_name: str, metric_value: Any) -> None:
+    """
+    Adds stats that should be emitted with every metric by the current process.
+    If the emit_metrics method specifies a metric with the same name, it will
+    overwrite this value.
+    """
+    global_metrics[metric_name] = metric_value
+
+
 def emit_metric(
     metric_name: str,
     metrics: Dict[str, Any],
@@ -83,6 +95,10 @@ def emit_metric(
     if metrics is None:
         raise ValueError("You didn't ask to upload any metrics!")
 
+    # Merge the given metrics with the global metrics, overwriting any duplicates
+    # with the given metrics.
+    metrics = {**global_metrics, **metrics}
+
     # We use these env vars that to determine basic info about the workflow run.
     # By using env vars, we don't have to pass this info around to every function.
     # It also helps ensure that we only emit metrics during CI
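The one-line merge leans on Python's dict-unpacking rule that later unpacks overwrite earlier ones on duplicate keys, which is exactly the override behavior the `add_global_metric` docstring promises. A tiny illustration:

```python
# Later unpacks win on duplicate keys, so per-call metrics override globals.
global_metrics = {"shard": 1, "num_shards": 4}
metrics = {"some_number": 123, "shard": 99}

merged = {**global_metrics, **metrics}
print(merged)  # {'shard': 99, 'num_shards': 4, 'some_number': 123}
```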
--- a/tools/test/test_upload_stats_lib.py
+++ b/tools/test/test_upload_stats_lib.py
@@ -4,7 +4,7 @@ import unittest
 from typing import Any, Dict
 from unittest import mock
 
-from tools.stats.upload_metrics import emit_metric
+from tools.stats.upload_metrics import add_global_metric, emit_metric
 
 from tools.stats.upload_stats_lib import BATCH_SIZE, upload_to_rockset
 
@@ -85,6 +85,76 @@ class TestUploadStats(unittest.TestCase):
             {**emit_should_include, **emitted_metric},
         )
 
+    @mock.patch("boto3.Session.resource")
+    def test_when_global_metric_specified_then_it_emits_it(
+        self, mock_resource: Any
+    ) -> None:
+        metric = {
+            "some_number": 123,
+        }
+
+        global_metric_name = "global_metric"
+        global_metric_value = "global_value"
+
+        add_global_metric(global_metric_name, global_metric_value)
+
+        emit_should_include = {
+            **metric,
+            global_metric_name: global_metric_value,
+        }
+
+        # Preserve the metric emitted
+        emitted_metric: Dict[str, Any] = {}
+
+        def mock_put_item(Item: Dict[str, Any]) -> None:
+            nonlocal emitted_metric
+            emitted_metric = Item
+
+        mock_resource.return_value.Table.return_value.put_item = mock_put_item
+
+        emit_metric("metric_name", metric)
+
+        self.assertEqual(
+            emitted_metric,
+            {**emitted_metric, **emit_should_include},
+        )
+
+    @mock.patch("boto3.Session.resource")
+    def test_when_local_and_global_metric_specified_then_global_is_overridden(
+        self, mock_resource: Any
+    ) -> None:
+        global_metric_name = "global_metric"
+        global_metric_value = "global_value"
+        local_override = "local_override"
+
+        add_global_metric(global_metric_name, global_metric_value)
+
+        metric = {
+            "some_number": 123,
+            global_metric_name: local_override,
+        }
+
+        emit_should_include = {
+            **metric,
+            global_metric_name: local_override,
+        }
+
+        # Preserve the metric emitted
+        emitted_metric: Dict[str, Any] = {}
+
+        def mock_put_item(Item: Dict[str, Any]) -> None:
+            nonlocal emitted_metric
+            emitted_metric = Item
+
+        mock_resource.return_value.Table.return_value.put_item = mock_put_item
+
+        emit_metric("metric_name", metric)
+
+        self.assertEqual(
+            emitted_metric,
+            {**emitted_metric, **emit_should_include},
+        )
+
     @mock.patch("boto3.Session.resource")
     def test_when_optional_envvar_set_to_actual_value_then_emit_vars_emits_it(
         self, mock_resource: Any
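One detail worth noting in these tests: the assertion `assertEqual(emitted_metric, {**emitted_metric, **emit_should_include})` is a subset check rather than exact equality. It passes only if every expected key/value pair already appears in the emitted dict, which leaves room for the extra fields emit_metric derives from environment variables. A small self-contained demonstration of the idiom (the `extra_field` key is hypothetical):

```python
# x == {**x, **y} holds iff every key/value pair of y is already in x.
emitted = {"some_number": 123, "shard": 1, "extra_field": "added elsewhere"}

assert emitted == {**emitted, **{"some_number": 123, "shard": 1}}  # subset present: passes
assert emitted != {**emitted, **{"shard": 2}}                      # mismatched value: not a subset
```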