[op-bench] check device attribute in user inputs

Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff.

Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1

Reviewed By: ngimel

Differential Revision: D22538252

fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3
This commit is contained in:
Mingzhe Li 2020-07-14 17:15:38 -07:00 committed by Facebook GitHub Bot
parent a0f110190c
commit 4ddf27ba48

View File

@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests.
# Here are the reserved keywords in the benchmark suite # Here are the reserved keywords in the benchmark suite
_reserved_keywords = {"probs", "total_samples", "tags"} _reserved_keywords = {"probs", "total_samples", "tags"}
_supported_devices = {"cpu", "cuda"}
def shape_to_string(shape): def shape_to_string(shape):
return ', '.join([str(x) for x in shape]) return ', '.join([str(x) for x in shape])
@ -108,6 +109,7 @@ def cross_product_configs(**configs):
({'M': 2}, {'N' : 4}), ({'M': 2}, {'N' : 4}),
({'M': 2}, {'N' : 5})) ({'M': 2}, {'N' : 5}))
""" """
_validate(configs)
configs_attrs_list = [] configs_attrs_list = []
for key, values in configs.items(): for key, values in configs.items():
tmp_results = [{key : value} for value in values] tmp_results = [{key : value} for value in values]
@ -120,6 +122,13 @@ def cross_product_configs(**configs):
return generated_configs return generated_configs
def _validate(configs):
""" Validate inputs from users."""
if 'device' in configs:
for v in configs['device']:
assert(v in _supported_devices), "Device needs to be a string."
def config_list(**configs): def config_list(**configs):
""" Generate configs based on the list of input shapes. """ Generate configs based on the list of input shapes.
This function will take input shapes specified in a list from user. Besides This function will take input shapes specified in a list from user. Besides
@ -153,6 +162,8 @@ def config_list(**configs):
if any(attr not in configs for attr in reserved_names): if any(attr not in configs for attr in reserved_names):
raise ValueError("Missing attrs in configs") raise ValueError("Missing attrs in configs")
_validate(configs)
cross_configs = None cross_configs = None
if 'cross_product_configs' in configs: if 'cross_product_configs' in configs:
cross_configs = cross_product_configs(**configs['cross_product_configs']) cross_configs = cross_product_configs(**configs['cross_product_configs'])
@ -318,14 +329,14 @@ def get_operator_range(chars_range):
if chars_range == 'None' or chars_range is None: if chars_range == 'None' or chars_range is None:
return None return None
if all(item not in chars_range for item in [',', '-']): if all(item not in chars_range for item in [',', '-']):
raise ValueError("The correct format for operator_range is " raise ValueError("The correct format for operator_range is "
"<start>-<end>, or <point>, <start>-<end>") "<start>-<end>, or <point>, <start>-<end>")
ops_start_chars_set = set() ops_start_chars_set = set()
ranges = chars_range.split(',') ranges = chars_range.split(',')
for item in ranges: for item in ranges:
if len(item) == 1: if len(item) == 1:
ops_start_chars_set.add(item.lower()) ops_start_chars_set.add(item.lower())
continue continue
start, end = item.split("-") start, end = item.split("-")