mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[op-bench] check device attribute in user inputs
Summary: The device attribute in the op benchmark can only include 'cpu' or 'cuda'. So adding a check in this diff. Test Plan: buck run caffe2/benchmarks/operator_benchmark:benchmark_all_test -- --warmup_iterations 1 --iterations 1 Reviewed By: ngimel Differential Revision: D22538252 fbshipit-source-id: 3e5af72221fc056b8d867321ad22e35a2557b8c3
This commit is contained in:
parent
a0f110190c
commit
4ddf27ba48
|
|
@ -17,6 +17,7 @@ This module contains utilities for writing microbenchmark tests.
|
||||||
|
|
||||||
# Here are the reserved keywords in the benchmark suite
|
# Here are the reserved keywords in the benchmark suite
|
||||||
_reserved_keywords = {"probs", "total_samples", "tags"}
|
_reserved_keywords = {"probs", "total_samples", "tags"}
|
||||||
|
_supported_devices = {"cpu", "cuda"}
|
||||||
|
|
||||||
def shape_to_string(shape):
|
def shape_to_string(shape):
|
||||||
return ', '.join([str(x) for x in shape])
|
return ', '.join([str(x) for x in shape])
|
||||||
|
|
@ -108,6 +109,7 @@ def cross_product_configs(**configs):
|
||||||
({'M': 2}, {'N' : 4}),
|
({'M': 2}, {'N' : 4}),
|
||||||
({'M': 2}, {'N' : 5}))
|
({'M': 2}, {'N' : 5}))
|
||||||
"""
|
"""
|
||||||
|
_validate(configs)
|
||||||
configs_attrs_list = []
|
configs_attrs_list = []
|
||||||
for key, values in configs.items():
|
for key, values in configs.items():
|
||||||
tmp_results = [{key : value} for value in values]
|
tmp_results = [{key : value} for value in values]
|
||||||
|
|
@ -120,6 +122,13 @@ def cross_product_configs(**configs):
|
||||||
return generated_configs
|
return generated_configs
|
||||||
|
|
||||||
|
|
||||||
|
def _validate(configs):
|
||||||
|
""" Validate inputs from users."""
|
||||||
|
if 'device' in configs:
|
||||||
|
for v in configs['device']:
|
||||||
|
assert(v in _supported_devices), "Device needs to be a string."
|
||||||
|
|
||||||
|
|
||||||
def config_list(**configs):
|
def config_list(**configs):
|
||||||
""" Generate configs based on the list of input shapes.
|
""" Generate configs based on the list of input shapes.
|
||||||
This function will take input shapes specified in a list from user. Besides
|
This function will take input shapes specified in a list from user. Besides
|
||||||
|
|
@ -153,6 +162,8 @@ def config_list(**configs):
|
||||||
if any(attr not in configs for attr in reserved_names):
|
if any(attr not in configs for attr in reserved_names):
|
||||||
raise ValueError("Missing attrs in configs")
|
raise ValueError("Missing attrs in configs")
|
||||||
|
|
||||||
|
_validate(configs)
|
||||||
|
|
||||||
cross_configs = None
|
cross_configs = None
|
||||||
if 'cross_product_configs' in configs:
|
if 'cross_product_configs' in configs:
|
||||||
cross_configs = cross_product_configs(**configs['cross_product_configs'])
|
cross_configs = cross_product_configs(**configs['cross_product_configs'])
|
||||||
|
|
@ -318,14 +329,14 @@ def get_operator_range(chars_range):
|
||||||
if chars_range == 'None' or chars_range is None:
|
if chars_range == 'None' or chars_range is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if all(item not in chars_range for item in [',', '-']):
|
if all(item not in chars_range for item in [',', '-']):
|
||||||
raise ValueError("The correct format for operator_range is "
|
raise ValueError("The correct format for operator_range is "
|
||||||
"<start>-<end>, or <point>, <start>-<end>")
|
"<start>-<end>, or <point>, <start>-<end>")
|
||||||
|
|
||||||
ops_start_chars_set = set()
|
ops_start_chars_set = set()
|
||||||
ranges = chars_range.split(',')
|
ranges = chars_range.split(',')
|
||||||
for item in ranges:
|
for item in ranges:
|
||||||
if len(item) == 1:
|
if len(item) == 1:
|
||||||
ops_start_chars_set.add(item.lower())
|
ops_start_chars_set.add(item.lower())
|
||||||
continue
|
continue
|
||||||
start, end = item.split("-")
|
start, end = item.split("-")
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user