From 4ddf27ba48ba3313a20d3316a8929cd42436ddbc Mon Sep 17 00:00:00 2001 From: Mingzhe Li Date: Tue, 14 Jul 2020 17:15:38 -0700 Subject: [PATCH] [op-bench] check device attribute in user inputs Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff. Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1 Reviewed By: ngimel Differential Revision: D22538252 fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3 --- .../operator_benchmark/benchmark_utils.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/benchmarks/operator_benchmark/benchmark_utils.py b/benchmarks/operator_benchmark/benchmark_utils.py index 15d25dc8bb6..980967b6200 100644 --- a/benchmarks/operator_benchmark/benchmark_utils.py +++ b/benchmarks/operator_benchmark/benchmark_utils.py @@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests. # Here are the reserved keywords in the benchmark suite _reserved_keywords = {"probs", "total_samples", "tags"} +_supported_devices = {"cpu", "cuda"} def shape_to_string(shape): return ', '.join([str(x) for x in shape]) @@ -108,6 +109,7 @@ def cross_product_configs(**configs): ({'M': 2}, {'N' : 4}), ({'M': 2}, {'N' : 5})) """ + _validate(configs) configs_attrs_list = [] for key, values in configs.items(): tmp_results = [{key : value} for value in values] @@ -120,6 +122,13 @@ def cross_product_configs(**configs): return generated_configs +def _validate(configs): + """ Validate inputs from users.""" + if 'device' in configs: + for v in configs['device']: + assert(v in _supported_devices), "Device needs to be a string." + + def config_list(**configs): """ Generate configs based on the list of input shapes. This function will take input shapes specified in a list from user. Besides @@ -153,6 +162,8 @@ def config_list(**configs): if any(attr not in configs for attr in reserved_names): raise ValueError("Missing attrs in configs") + _validate(configs) + cross_configs = None if 'cross_product_configs' in configs: cross_configs = cross_product_configs(**configs['cross_product_configs']) @@ -318,14 +329,14 @@ def get_operator_range(chars_range): if chars_range == 'None' or chars_range is None: return None - if all(item not in chars_range for item in [',', '-']): - raise ValueError("The correct format for operator_range is " + if all(item not in chars_range for item in [',', '-']): + raise ValueError("The correct format for operator_range is " "-, or , -") ops_start_chars_set = set() ranges = chars_range.split(',') - for item in ranges: - if len(item) == 1: + for item in ranges: + if len(item) == 1: ops_start_chars_set.add(item.lower()) continue start, end = item.split("-")