[op-bench] check device attribute in user inputs

Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff.

Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1

Reviewed By: ngimel

Differential Revision: D22538252

fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3
This commit is contained in:
Mingzhe Li 2020-07-14 17:15:38 -07:00 committed by Facebook GitHub Bot
parent a0f110190c
commit 4ddf27ba48

View File

@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests.
# Here are the reserved keywords in the benchmark suite
_reserved_keywords = {"probs", "total_samples", "tags"}
_supported_devices = {"cpu", "cuda"}
def shape_to_string(shape):
return ', '.join([str(x) for x in shape])
@ -108,6 +109,7 @@ def cross_product_configs(**configs):
({'M': 2}, {'N' : 4}),
({'M': 2}, {'N' : 5}))
"""
_validate(configs)
configs_attrs_list = []
for key, values in configs.items():
tmp_results = [{key : value} for value in values]
@ -120,6 +122,13 @@ def cross_product_configs(**configs):
return generated_configs
def _validate(configs):
""" Validate inputs from users."""
if 'device' in configs:
for v in configs['device']:
assert(v in _supported_devices), "Device needs to be a string."
def config_list(**configs):
""" Generate configs based on the list of input shapes.
This function will take input shapes specified in a list from user. Besides
@ -153,6 +162,8 @@ def config_list(**configs):
if any(attr not in configs for attr in reserved_names):
raise ValueError("Missing attrs in configs")
_validate(configs)
cross_configs = None
if 'cross_product_configs' in configs:
cross_configs = cross_product_configs(**configs['cross_product_configs'])