mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
rework device type filter rule (#55753)
Summary: Currently common_device_type generates device-specific test based on vague rules. see https://github.com/pytorch/pytorch/issues/55707. This should fix https://github.com/pytorch/pytorch/issues/55707 # Changes included This PR changes the rule: 1. First user provided args (`except_for` and `only_for`) are processed to filter out desired device type from a ALL_AVAILABLE_LIST 2. Then environment variables are processed the exact same way. tests are generated based on the final filtered list. Pull Request resolved: https://github.com/pytorch/pytorch/pull/55753 Test Plan: CI Reviewed By: seemethere, ngimel Differential Revision: D27709192 Pulled By: walterddr fbshipit-source-id: 1d5378ef013b22a7fb9fdae24b486730b2e67401
This commit is contained in:
parent
dab1cdf7cb
commit
d0cd16899f
|
|
@ -427,6 +427,21 @@ def get_device_type_test_bases():
|
|||
device_type_test_bases = get_device_type_test_bases()
|
||||
|
||||
|
||||
def filter_desired_device_types(device_type_test_bases, except_for=None, only_for=None):
|
||||
# device type cannot appear in both except_for and only_for
|
||||
intersect = set(except_for if except_for else []) & set(only_for if only_for else [])
|
||||
assert not intersect, f"device ({intersect}) appeared in both except_for and only_for"
|
||||
|
||||
if except_for is not None:
|
||||
device_type_test_bases = filter(
|
||||
lambda x: x.device_type not in except_for, device_type_test_bases)
|
||||
if only_for is not None:
|
||||
device_type_test_bases = filter(
|
||||
lambda x: x.device_type in only_for, device_type_test_bases)
|
||||
|
||||
return list(device_type_test_bases)
|
||||
|
||||
|
||||
# Note [How to extend DeviceTypeTestBase to add new test device]
|
||||
# The following logic optionally allows downstream projects like pytorch/xla to
|
||||
# add more test devices.
|
||||
|
|
@ -474,31 +489,28 @@ def instantiate_device_type_tests(generic_test_class, scope, except_for=None, on
|
|||
# Acquires members names
|
||||
# See Note [Overriding methods in generic tests]
|
||||
generic_members = set(generic_test_class.__dict__.keys()) - set(empty_class.__dict__.keys())
|
||||
generic_tests = [x for x in generic_members if x.startswith('test')]
|
||||
generic_tests = [x for x in generic_members if x.startswith('test')]\
|
||||
|
||||
def split_if_not_empty(x):
|
||||
# Filter out the device types based on user inputs
|
||||
desired_device_type_test_bases = filter_desired_device_types(device_type_test_bases,
|
||||
except_for, only_for)
|
||||
|
||||
def split_if_not_empty(x: str):
|
||||
return x.split(",") if len(x) != 0 else []
|
||||
|
||||
# Derive defaults from environment variables if available, default is still none
|
||||
# Filter out the device types based on environment variables if available
|
||||
# Usage:
|
||||
# export PYTORCH_TESTING_DEVICE_ONLY_FOR=cuda,cpu
|
||||
# export PYTORCH_TESTING_DEVICE_EXCEPT_FOR=xla
|
||||
if only_for is None:
|
||||
only_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_ONLY_FOR", ''))
|
||||
env_only_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_ONLY_FOR", ''))
|
||||
env_except_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_EXCEPT_FOR", ''))
|
||||
|
||||
desired_device_type_test_bases = filter_desired_device_types(desired_device_type_test_bases,
|
||||
env_except_for, env_only_for)
|
||||
|
||||
if except_for is None:
|
||||
except_for = split_if_not_empty(os.getenv("PYTORCH_TESTING_DEVICE_EXCEPT_FOR", ''))
|
||||
|
||||
# Creates device-specific test cases
|
||||
for base in device_type_test_bases:
|
||||
# Skips bases listed in except_for
|
||||
if except_for and only_for:
|
||||
assert base.device_type not in except_for or base.device_type not in only_for,\
|
||||
"same device cannot appear in except_for and only_for"
|
||||
if except_for and base.device_type in except_for:
|
||||
continue
|
||||
if only_for and base.device_type not in only_for:
|
||||
continue
|
||||
for base in desired_device_type_test_bases:
|
||||
# Special-case for ROCm testing -- only test for 'cuda' i.e. ROCm device by default
|
||||
# The except_for and only_for cases were already checked above. At this point we only need to check 'cuda'.
|
||||
if TEST_WITH_ROCM and base.device_type != 'cuda':
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user