[BE] Enable ruff's UP rules and autoformat utils/ (#105424)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105424 Approved by: https://github.com/ezyang, https://github.com/malfet
This commit is contained in:
parent 91ab32e4b1
commit abc1cadddb
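
For context, ruff's UP rules are the pyupgrade fixes: they rewrite idioms that predate the project's minimum Python version into their modern equivalents. Nearly every hunk below is one of a handful of such rewrites: str.format() and %-interpolation become f-strings, redundant open() mode strings are dropped, EnvironmentError becomes OSError, and class statements lose vestigial Python 2 syntax. A minimal before/after sketch of the most common fix (names here are illustrative, not from the diff):

    name = "utils"
    title_old = "module {}".format(name)   # UP032 flags this .format() call
    title_new = f"module {name}"           # and rewrites it to an f-string
    assert title_old == title_new == "module utils"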

@@ -237,7 +237,7 @@ class Freezer:
         module_mangled_name = "__".join(module_qualname)
         c_name = "M_" + module_mangled_name

-        with open(path, "r") as src_file:
+        with open(path) as src_file:
             co = self.compile_string(src_file.read())

         bytecode = marshal.dumps(co)
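
The "r", "rt", and "wt" mode strings removed throughout this diff are redundant (UP015): open() defaults to reading, and text mode is implied unless "b" is present, so open(path) equals open(path, "rt"), and "wt" is just "w". A quick sanity check with a throwaway temp file:

    import os
    import tempfile

    path = os.path.join(tempfile.mkdtemp(), "demo.txt")
    with open(path, "w") as f:   # identical to the removed "wt"
        f.write("hello")
    with open(path) as f:        # identical to open(path, "r") / open(path, "rt")
        assert f.read() == "hello"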

@@ -1,4 +1,3 @@
-# -*- coding: utf-8 -*-
 """End-to-end example to test a PR for regressions:

 $ python -m examples.end_to_end --pr 39850

@@ -111,7 +110,7 @@ _SUBPROCESS_CMD_TEMPLATE = (

 def construct_stmt_and_label(pr, params):
     if pr == "39850":
-        k0, k1, k2, dim = [params[i] for i in ["k0", "k1", "k2", "dim"]]
+        k0, k1, k2, dim = (params[i] for i in ["k0", "k1", "k2", "dim"])
         state = np.random.RandomState(params["random_value"])
         topk_dim = state.randint(low=0, high=dim)
         dim_size = [k0, k1, k2][topk_dim]
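
Swapping the brackets for parentheses here turns the list comprehension into a generator expression. Tuple unpacking accepts any iterable, so behavior is unchanged, but no intermediate list is built. A self-contained illustration with made-up values:

    params = {"k0": 1, "k1": 2, "k2": 3, "dim": 3, "random_value": 0}
    k0, k1, k2, dim = (params[i] for i in ["k0", "k1", "k2", "dim"])
    assert (k0, k1, k2, dim) == (1, 2, 3, 3)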

@@ -291,7 +290,7 @@ def construct_table(results, device_str, test_variance):
     )

     _, result_log_file = tempfile.mkstemp(suffix=".log")
-    with open(result_log_file, "wt") as f:
+    with open(result_log_file, "w") as f:
         f.write(f"{device_str}\n\n{column_labels}\n")
     print(f"\n{column_labels}\n[First twenty omitted (these tend to be noisy) ]")
     for key, (r_ref, r_pr), rel_diff in results:

@@ -37,13 +37,13 @@ def run(n, stmt, fuzzer_cls):
     assert_dicts_equal(float_params, int_params)
     assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])

-    float_measurement, int_measurement = [
+    float_measurement, int_measurement = (
         Timer(
             stmt,
             globals=tensors,
         ).blocked_autorange(min_run_time=_MEASURE_TIME)
         for tensors in (float_tensors, int_tensors)
-    ]
+    )

     descriptions = []
     for name in float_tensors:

@@ -32,13 +32,13 @@ def run(n, stmt, fuzzer_cls):
     assert_dicts_equal(float_params, int_params)
     assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])

-    float_measurement, int_measurement = [
+    float_measurement, int_measurement = (
         Timer(
             stmt,
             globals=tensors,
         ).blocked_autorange(min_run_time=_MEASURE_TIME)
         for tensors in (float_tensors, int_tensors)
-    ]
+    )

     descriptions = []
     for name in float_tensors:

@@ -27,7 +27,7 @@ def run_benchmark(name: str, function: object, dtype: torch.dtype, seed: int, de
     results = []
     for tensors, tensor_params, params in spectral_fuzzer.take(samples):
         shape = [params['k0'], params['k1'], params['k2']][:params['ndim']]
-        str_shape = ' x '.join(["{:<4}".format(s) for s in shape])
+        str_shape = ' x '.join([f"{s:<4}" for s in shape])
         sub_label = f"{str_shape} {'' if tensor_params['x']['is_contiguous'] else '(discontiguous)'}"
         for dim in _dim_options(params['ndim']):
             for nthreads in (1, 4, 16) if not cuda else (1,):

@@ -325,7 +325,7 @@ def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> st
         if not os.path.exists(owner_file):
             continue

-        with open(owner_file, "rt") as f:
+        with open(owner_file) as f:
             owner_pid = int(f.read())

         if owner_pid == os.getpid():

@@ -349,7 +349,7 @@ def _make_temp_dir(prefix: Optional[str] = None, gc_dev_shm: bool = False) -> st
     os.makedirs(path, exist_ok=False)

     if use_dev_shm:
-        with open(os.path.join(path, "owner.pid"), "wt") as f:
+        with open(os.path.join(path, "owner.pid"), "w") as f:
             f.write(str(os.getpid()))

     return path

@@ -137,7 +137,7 @@ def _compile_template(
     os.makedirs(build_dir, exist_ok=True)

     src_path = os.path.join(build_dir, "timer_src.cpp")
-    with open(src_path, "wt") as f:
+    with open(src_path, "w") as f:
         f.write(src)

     # `cpp_extension` has its own locking scheme, so we don't need our lock.

@@ -154,7 +154,7 @@ def _compile_template(

 def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
     template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
-    with open(template_path, "rt") as f:
+    with open(template_path) as f:
         src: str = f.read()

     module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)

@@ -164,7 +164,7 @@ def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> Time

 def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
     template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
-    with open(template_path, "rt") as f:
+    with open(template_path) as f:
         src: str = f.read()

     target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)

@@ -28,7 +28,10 @@ else:
     CompletedProcessType = subprocess.CompletedProcess


-FunctionCount = NamedTuple("FunctionCount", [("count", int), ("function", str)])
+class FunctionCount(NamedTuple):
+    # TODO(#105471): Rename the count field
+    count: int  # type: ignore[assignment]
+    function: str


 @dataclasses.dataclass(repr=False, eq=False, frozen=True)
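
The functional NamedTuple call above becomes the equivalent class syntax (UP014), which is what makes the per-field comment and type-ignore possible. Both forms create the same tuple subclass; a sketch with illustrative values:

    from typing import NamedTuple

    class FunctionCount(NamedTuple):
        count: int
        function: str

    fc = FunctionCount(12, "aten::add")      # hypothetical sample data
    assert fc.count == 12 and fc[1] == "aten::add"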

@@ -598,7 +601,7 @@ class _ValgrindWrapper:
                 stderr=subprocess.STDOUT,
                 **kwargs,
             )
-            with open(stdout_stderr_log, "rt") as f:
+            with open(stdout_stderr_log) as f:
                 return invocation, f.read()
         finally:
             f_stdout_stderr.close()

@@ -612,7 +615,7 @@ class _ValgrindWrapper:
         )

         script_file = os.path.join(working_dir, "timer_callgrind.py")
-        with open(script_file, "wt") as f:
+        with open(script_file, "w") as f:
             f.write(self._construct_script(
                 task_spec,
                 globals=GlobalsBridge(globals, data_dir),

@@ -652,7 +655,7 @@ class _ValgrindWrapper:
         if valgrind_invocation.returncode:
             error_report = ""
             if os.path.exists(error_log):
-                with open(error_log, "rt") as f:
+                with open(error_log) as f:
                     error_report = f.read()
             if not error_report:
                 error_report = "Unknown error.\n" + valgrind_invocation_output

@@ -724,7 +727,7 @@ class _ValgrindWrapper:
             fpath = f"{callgrind_out}.{i + 1}"  # Callgrind one-indexes files.
             callgrind_out_contents: Optional[str] = None
             if retain_out_file:
-                with open(fpath, "rt") as f:
+                with open(fpath) as f:
                     callgrind_out_contents = f.read()

             return (

@@ -16,7 +16,7 @@ def redirect_argv(new_argv):

 def compiled_with_cuda(sysinfo):
     if sysinfo.cuda_compiled_version:
-        return 'compiled w/ CUDA {}'.format(sysinfo.cuda_compiled_version)
+        return f'compiled w/ CUDA {sysinfo.cuda_compiled_version}'
     return 'not compiled w/ CUDA'


@@ -59,7 +59,7 @@ def run_env_analysis():
         'debug_str': debug_str,
         'pytorch_version': info.torch_version,
         'cuda_compiled': compiled_with_cuda(info),
-        'py_version': '{}.{}'.format(sys.version_info[0], sys.version_info[1]),
+        'py_version': f'{sys.version_info[0]}.{sys.version_info[1]}',
         'cuda_runtime': cuda_avail,
         'pip_version': pip_version,
         'pip_list_output': pip_list_output,

@@ -138,7 +138,7 @@ def print_autograd_prof_summary(prof, mode, sortby='cpu_time', topk=15):

     result = {
         'mode': mode,
-        'description': 'top {} events sorted by {}'.format(topk, sortby),
+        'description': f'top {topk} events sorted by {sortby}',
         'output': torch.autograd.profiler_util._build_table(topk_events),
         'cuda_warning': cuda_warning
     }

@@ -261,11 +261,11 @@ def augment_many_model_functions_with_bundled_inputs(


         if input_list is not None and not isinstance(input_list, Sequence):
-            raise TypeError("Error inputs for function {0} is not a Sequence".format(function_name))
+            raise TypeError(f"Error inputs for function {function_name} is not a Sequence")

         function_arg_types = [arg.type for arg in function.schema.arguments[1:]]  # type: ignore[attr-defined]
         deflated_inputs_type: ListType = ListType(TupleType(function_arg_types))
-        model._c._register_attribute("_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs_type, [])
+        model._c._register_attribute(f"_bundled_inputs_deflated_{function_name}", deflated_inputs_type, [])

         if hasattr(model, "_generate_bundled_inputs_for_" + function_name):
             if input_list is not None:

@@ -290,7 +290,7 @@ def augment_many_model_functions_with_bundled_inputs(
             for inp_idx, args in enumerate(input_list):
                 if not isinstance(args, Tuple) and not isinstance(args, List):  # type: ignore[arg-type]
                     raise TypeError(
-                        "Error bundled input for function {0} idx: {1} is not a Tuple or a List".format(function_name, inp_idx)
+                        f"Error bundled input for function {function_name} idx: {inp_idx} is not a Tuple or a List"
                     )
                 deflated_args = []
                 parts.append("(")

@@ -314,7 +314,7 @@ def augment_many_model_functions_with_bundled_inputs(
         # Back-channel return this expr for debugging.
         if _receive_inflate_expr is not None:
             _receive_inflate_expr.append(expr)
-        setattr(model, "_bundled_inputs_deflated_{name}".format(name=function_name), deflated_inputs)
+        setattr(model, f"_bundled_inputs_deflated_{function_name}", deflated_inputs)
         definition = textwrap.dedent("""
             def _generate_bundled_inputs_for_{name}(self):
                 deflated = self._bundled_inputs_deflated_{name}

@@ -66,7 +66,7 @@ def _get_device_module(device="cuda"):
     return device_module


-class DefaultDeviceType(object):
+class DefaultDeviceType:
     r"""
     A class that manages the default device type for checkpointing.
     If no non-CPU tensors are present, the default device type will
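
Dropping the explicit object base (UP004) is purely cosmetic: every Python 3 class is new-style, so class DefaultDeviceType: and class DefaultDeviceType(object): define identical types. The same fix shows up later for CaptureControl, CaptureLikeMock, and Trie. A one-line check:

    class A: pass
    assert issubclass(A, object) and type(A) is type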

@@ -150,11 +150,11 @@ def _join_rocm_home(*paths) -> str:
     only once we need to get any ROCm-specific path.
     '''
     if ROCM_HOME is None:
-        raise EnvironmentError('ROCM_HOME environment variable is not set. '
-                               'Please set it to your ROCm install root.')
+        raise OSError('ROCM_HOME environment variable is not set. '
+                      'Please set it to your ROCm install root.')
     elif IS_WINDOWS:
-        raise EnvironmentError('Building PyTorch extensions using '
-                               'ROCm and Windows is not supported.')
+        raise OSError('Building PyTorch extensions using '
+                      'ROCm and Windows is not supported.')
     return os.path.join(ROCM_HOME, *paths)

@@ -264,7 +264,7 @@ def _maybe_write(filename, new_content):
     if it already had the right content (to avoid triggering recompile).
     '''
     if os.path.exists(filename):
-        with open(filename, 'r') as f:
+        with open(filename) as f:
             content = f.read()

         if content == new_content:

@@ -2247,8 +2247,8 @@ def _join_cuda_home(*paths) -> str:
     only once we need to get any CUDA-specific path.
     '''
     if CUDA_HOME is None:
-        raise EnvironmentError('CUDA_HOME environment variable is not set. '
-                               'Please set it to your CUDA install root.')
+        raise OSError('CUDA_HOME environment variable is not set. '
+                      'Please set it to your CUDA install root.')
     return os.path.join(CUDA_HOME, *paths)
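
EnvironmentError is rewritten to OSError (UP024) because PEP 3151 merged the OS-related exception types in Python 3.3; the old names survive only as aliases of OSError, so raising and catching behavior is identical:

    assert EnvironmentError is OSError and IOError is OSError
    try:
        raise EnvironmentError("boom")   # actually raises OSError
    except OSError:
        pass                             # caught: same class, two names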

@@ -37,7 +37,7 @@ def _pin_memory_loop(in_queue, out_queue, device_id, done_event, device):
             data = pin_memory(data, device)
         except Exception:
             data = ExceptionWrapper(
-                where="in pin memory thread for device {}".format(device_id))
+                where=f"in pin memory thread for device {device_id}")
         r = (idx, data)
         while not done_event.is_set():
             try:

@@ -76,13 +76,13 @@ class WorkerInfo:

     def __setattr__(self, key, val):
         if self.__initialized:
-            raise RuntimeError("Cannot assign attributes to {} objects".format(self.__class__.__name__))
+            raise RuntimeError(f"Cannot assign attributes to {self.__class__.__name__} objects")
         return super().__setattr__(key, val)

     def __repr__(self):
         items = []
         for k in self.__keys:
-            items.append('{}={}'.format(k, getattr(self, k)))
+            items.append(f'{k}={getattr(self, k)}')
         return '{}({})'.format(self.__class__.__name__, ', '.join(items))

@@ -252,7 +252,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
             fetcher = _DatasetKind.create_fetcher(dataset_kind, dataset, auto_collation, collate_fn, drop_last)
         except Exception:
             init_exception = ExceptionWrapper(
-                where="in DataLoader worker process {}".format(worker_id))
+                where=f"in DataLoader worker process {worker_id}")

         # When using Iterable mode, some worker can exit earlier than others due
         # to the IterableDataset behaving differently for different workers.

@@ -318,7 +318,7 @@ def _worker_loop(dataset_kind, dataset, index_queue, data_queue, done_event,
                     # `ExceptionWrapper` does the correct thing.
                     # See NOTE [ Python Traceback Reference Cycle Problem ]
                     data = ExceptionWrapper(
-                        where="in DataLoader worker process {}".format(worker_id))
+                        where=f"in DataLoader worker process {worker_id}")
                 data_queue.put((idx, data))
                 del data, idx, index, r  # save memory
     except KeyboardInterrupt:

@@ -604,7 +604,7 @@ class _BaseDataLoaderIter:
         self._base_seed = torch.empty((), dtype=torch.int64).random_(generator=loader.generator).item()
         self._persistent_workers = loader.persistent_workers
         self._num_yielded = 0
-        self._profile_name = "enumerate(DataLoader)#{}.__next__".format(self.__class__.__name__)
+        self._profile_name = f"enumerate(DataLoader)#{self.__class__.__name__}.__next__"

     def __iter__(self) -> '_BaseDataLoaderIter':
         return self

@@ -1145,7 +1145,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                 self._mark_worker_as_unavailable(worker_id)
             if len(failed_workers) > 0:
                 pids_str = ', '.join(str(w.pid) for w in failed_workers)
-                raise RuntimeError('DataLoader worker (pid(s) {}) exited unexpectedly'.format(pids_str)) from e
+                raise RuntimeError(f'DataLoader worker (pid(s) {pids_str}) exited unexpectedly') from e
             if isinstance(e, queue.Empty):
                 return (False, None)
             import tempfile

@@ -1281,7 +1281,7 @@ class _MultiProcessingDataLoaderIter(_BaseDataLoaderIter):
                 if success:
                     return data
                 else:
-                    raise RuntimeError('DataLoader timed out after {} seconds'.format(self._timeout))
+                    raise RuntimeError(f'DataLoader timed out after {self._timeout} seconds')
             elif self._pin_memory:
                 while self._pin_memory_thread.is_alive():
                     success, data = self._try_get_data()

@@ -80,7 +80,7 @@ class non_deterministic:
         elif isinstance(arg, Callable):  # type:ignore[arg-type]
             self.deterministic_fn = arg  # type: ignore[assignment, misc]
         else:
-            raise TypeError("{} can not be decorated by non_deterministic".format(arg))
+            raise TypeError(f"{arg} can not be decorated by non_deterministic")

     def __call__(self, *args, **kwargs):
         global _determinism

@@ -234,7 +234,7 @@ class _DataPipeType:
             return issubtype(self.param, other.param)
         if isinstance(other, type):
             return issubtype(self.param, other)
-        raise TypeError("Expected '_DataPipeType' or 'type', but found {}".format(type(other)))
+        raise TypeError(f"Expected '_DataPipeType' or 'type', but found {type(other)}")

     def issubtype_of_instance(self, other):
         return issubinstance(other, self.param)

@@ -279,13 +279,13 @@ class _DataPipeMeta(GenericMeta):
     @_tp_cache
     def _getitem_(self, params):
         if params is None:
-            raise TypeError('{}[t]: t can not be None'.format(self.__name__))
+            raise TypeError(f'{self.__name__}[t]: t can not be None')
         if isinstance(params, str):
             params = ForwardRef(params)
         if not isinstance(params, tuple):
             params = (params, )

-        msg = "{}[t]: t must be a type".format(self.__name__)
+        msg = f"{self.__name__}[t]: t must be a type"
         params = tuple(_type_check(p, msg) for p in params)

         if isinstance(self.type.param, _GenericAlias):

@@ -303,7 +303,7 @@ class _DataPipeMeta(GenericMeta):
              '__type_class__': True})

         if len(params) > 1:
-            raise TypeError('Too many parameters for {} actual {}, expected 1'.format(self, len(params)))
+            raise TypeError(f'Too many parameters for {self} actual {len(params)}, expected 1')

         t = _DataPipeType(params[0])

@@ -36,7 +36,7 @@ def disable_capture():
     CaptureControl.disabled = True


-class CaptureControl():
+class CaptureControl:
     disabled = False


@@ -184,7 +184,7 @@ class CaptureA(CaptureF):
         return value


-class CaptureLikeMock():
+class CaptureLikeMock:
     def __init__(self, name):
         import unittest.mock as mock
         # TODO(VitalyFedyunin): Do not use provate function here, copy own implementation instead.

@@ -232,7 +232,7 @@ class CaptureVariableAssign(CaptureF):
     def __str__(self):
         variable = self.kwargs['variable']
         value = self.kwargs['value']
-        return "{variable} = {value}".format(variable=variable, value=value)
+        return f"{variable} = {value}"

     def execute(self):
         self.kwargs['variable'].calculated_value = self.kwargs['value'].execute()

@@ -272,7 +272,7 @@ class CaptureGetItem(Capture):
         self.key = key

     def __str__(self):
-        return "%s[%s]" % (self.left, get_val(self.key))
+        return f"{self.left}[{get_val(self.key)}]"

     def execute(self):
         left = self.left.execute()

@@ -287,7 +287,7 @@ class CaptureSetItem(Capture):
         self.value = value

     def __str__(self):
-        return "%s[%s] = %s" % (self.left, get_val(self.key), self.value)
+        return f"{self.left}[{get_val(self.key)}] = {self.value}"

     def execute(self):
         left = self.left.execute()

@@ -302,7 +302,7 @@ class CaptureAdd(Capture):
         self.right = right

     def __str__(self):
-        return "%s + %s" % (self.left, self.right)
+        return f"{self.left} + {self.right}"

     def execute(self):
         return get_val(self.left) + get_val(self.right)

@@ -315,7 +315,7 @@ class CaptureMul(Capture):
         self.right = right

     def __str__(self):
-        return "%s * %s" % (self.left, self.right)
+        return f"{self.left} * {self.right}"

     def execute(self):
         return get_val(self.left) * get_val(self.right)

@@ -328,7 +328,7 @@ class CaptureSub(Capture):
         self.right = right

     def __str__(self):
-        return "%s - %s" % (self.left, self.right)
+        return f"{self.left} - {self.right}"

     def execute(self):
         return get_val(self.left) - get_val(self.right)

@@ -341,7 +341,7 @@ class CaptureGetAttr(Capture):
         self.name = name

     def __str__(self):
-        return "%s.%s" % (self.src, self.name)
+        return f"{self.src}.{self.name}"

     def execute(self):
         val = get_val(self.src)

@@ -126,7 +126,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
             functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
             return function
         else:
-            raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
+            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}")

     @classmethod
     def register_function(cls, function_name, function):

@@ -135,7 +135,7 @@ class IterDataPipe(IterableDataset[T_co], metaclass=_IterDataPipeMeta):
     @classmethod
     def register_datapipe_as_function(cls, function_name, cls_to_register, enable_df_api_tracing=False):
         if function_name in cls.functions:
-            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+            raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")

         def class_function(cls, enable_df_api_tracing, source_dp, *args, **kwargs):
             result_pipe = cls(source_dp, *args, **kwargs)

@@ -265,7 +265,7 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
             functools.update_wrapper(wrapper=function, wrapped=f, assigned=("__doc__",))
             return function
         else:
-            raise AttributeError("'{0}' object has no attribute '{1}".format(self.__class__.__name__, attribute_name))
+            raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{attribute_name}")

     @classmethod
     def register_function(cls, function_name, function):

@@ -274,7 +274,7 @@ class MapDataPipe(Dataset[T_co], metaclass=_DataPipeMeta):
     @classmethod
     def register_datapipe_as_function(cls, function_name, cls_to_register):
         if function_name in cls.functions:
-            raise Exception("Unable to add DataPipe function name {} as it is already taken".format(function_name))
+            raise Exception(f"Unable to add DataPipe function name {function_name} as it is already taken")

         def class_function(cls, source_dp, *args, **kwargs):
             result_pipe = cls(source_dp, *args, **kwargs)

@@ -363,7 +363,7 @@ class _DataPipeSerializationWrapper:
             return len(self._datapipe)
         except Exception as e:
             raise TypeError(
-                "{} instance doesn't have valid length".format(type(self).__name__)
+                f"{type(self).__name__} instance doesn't have valid length"
             ) from e

@@ -19,7 +19,7 @@ def gen_from_template(dir: str, template_name: str, output_name: str, replacemen
     template_path = os.path.join(dir, template_name)
     output_path = os.path.join(dir, output_name)

-    with open(template_path, "r") as f:
+    with open(template_path) as f:
         content = f.read()
     for placeholder, lines, indentation in replacements:
         with open(output_path, "w") as f:

@@ -126,7 +126,7 @@ class MapperIterDataPipe(IterDataPipe[T_co]):
         if isinstance(self.datapipe, Sized):
             return len(self.datapipe)
         raise TypeError(
-            "{} instance doesn't have valid length".format(type(self).__name__)
+            f"{type(self).__name__} instance doesn't have valid length"
         )

@@ -48,7 +48,7 @@ class SamplerIterDataPipe(IterDataPipe[T_co]):
         # Dataset has been tested as `Sized`
         if isinstance(self.sampler, Sized):
             return len(self.sampler)
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


 @functional_datapipe('shuffle')

@@ -137,7 +137,7 @@ class ShufflerIterDataPipe(IterDataPipe[T_co]):
     def __len__(self) -> int:
         if isinstance(self.datapipe, Sized):
             return len(self.datapipe)
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

     def reset(self) -> None:
         self._buffer = []

@@ -56,7 +56,7 @@ class ConcaterIterDataPipe(IterDataPipe):
         if all(isinstance(dp, Sized) for dp in self.datapipes):
             return sum(len(dp) for dp in self.datapipes)
         else:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


 @functional_datapipe('fork')

@@ -567,7 +567,7 @@ class MultiplexerIterDataPipe(IterDataPipe):
         if all(isinstance(dp, Sized) for dp in self.datapipes):
             return min(len(dp) for dp in self.datapipes) * len(self.datapipes)
         else:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

     def reset(self) -> None:
         self.buffer = []

@@ -627,4 +627,4 @@ class ZipperIterDataPipe(IterDataPipe[Tuple[T_co]]):
         if all(isinstance(dp, Sized) for dp in self.datapipes):
             return min(len(dp) for dp in self.datapipes)
         else:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@@ -61,5 +61,5 @@ class FileListerIterDataPipe(IterDataPipe[str]):

     def __len__(self):
         if self.length == -1:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
         return self.length

@@ -51,7 +51,7 @@ class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):
         self.encoding: Optional[str] = encoding

         if self.mode not in ('b', 't', 'rb', 'rt', 'r'):
-            raise ValueError("Invalid mode {}".format(mode))
+            raise ValueError(f"Invalid mode {mode}")
         # TODO: enforce typing for each instance based on mode, otherwise
         # `argument_validation` with this DataPipe may be potentially broken

@@ -68,5 +68,5 @@ class FileOpenerIterDataPipe(IterDataPipe[Tuple[str, IOBase]]):

     def __len__(self):
         if self.length == -1:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")
         return self.length

@@ -83,7 +83,7 @@ class BatcherIterDataPipe(IterDataPipe[DataChunk]):
             else:
                 return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
         else:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")


 @functional_datapipe('unbatch')

@@ -62,4 +62,4 @@ class RoutedDecoderIterDataPipe(IterDataPipe[Tuple[str, Any]]):
     def __len__(self) -> int:
         if isinstance(self.datapipe, Sized):
             return len(self.datapipe)
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@@ -80,4 +80,4 @@ class ShardingFilterIterDataPipe(_ShardingIterDataPipe):
         if isinstance(self.source_datapipe, Sized):
             return len(self.source_datapipe) // self.num_of_instances +\
                 (1 if (self.instance_id < len(self.source_datapipe) % self.num_of_instances) else 0)
-        raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+        raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@@ -47,7 +47,7 @@ class ConcaterMapDataPipe(MapDataPipe):
                 return dp[index - offset]
             else:
                 offset += len(dp)
-        raise IndexError("Index {} is out of range.".format(index))
+        raise IndexError(f"Index {index} is out of range.")

     def __len__(self) -> int:
         return sum(len(dp) for dp in self.datapipes)

@@ -64,4 +64,4 @@ class BatcherMapDataPipe(MapDataPipe[DataChunk]):
             else:
                 return (len(self.datapipe) + self.batch_size - 1) // self.batch_size
         else:
-            raise TypeError("{} instance doesn't have valid length".format(type(self).__name__))
+            raise TypeError(f"{type(self).__name__} instance doesn't have valid length")

@@ -305,7 +305,7 @@ class StreamWrapper:
         self.closed = False
         if parent_stream is not None:
             if not isinstance(parent_stream, StreamWrapper):
-                raise RuntimeError('Parent stream should be StreamWrapper, {} was given'.format(type(parent_stream)))
+                raise RuntimeError(f'Parent stream should be StreamWrapper, {type(parent_stream)} was given')
             parent_stream.child_counter += 1
             self.parent_stream = parent_stream
         if StreamWrapper.debug_unclosed_streams:

@@ -137,7 +137,7 @@ class ImageHandler:
     - pilrgba: pil None rgba
     """
     def __init__(self, imagespec):
-        assert imagespec in list(imagespecs.keys()), "unknown image specification: {}".format(imagespec)
+        assert imagespec in list(imagespecs.keys()), f"unknown image specification: {imagespec}"
         self.imagespec = imagespec.lower()

     def __call__(self, extension, data):

@@ -167,14 +167,14 @@ class ImageHandler:
             return img
         elif atype == "numpy":
             result = np.asarray(img)
-            assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
+            assert result.dtype == np.uint8, f"numpy image array should be type uint8, but got {result.dtype}"
             if etype == "uint8":
                 return result
             else:
                 return result.astype("f") / 255.0
         elif atype == "torch":
             result = np.asarray(img)
-            assert result.dtype == np.uint8, "numpy image array should be type uint8, but got {}".format(result.dtype)
+            assert result.dtype == np.uint8, f"numpy image array should be type uint8, but got {result.dtype}"

             if etype == "uint8":
                 result = np.array(result.transpose(2, 0, 1))

@@ -130,7 +130,7 @@ def traverse(datapipe: DataPipe, only_datapipe: Optional[bool] = None) -> DataPi
 # Add cache here to prevent infinite recursion on DataPipe
 def _traverse_helper(datapipe: DataPipe, only_datapipe: bool, cache: Set[int]) -> DataPipeGraph:
     if not isinstance(datapipe, (IterDataPipe, MapDataPipe)):
-        raise RuntimeError("Expected `IterDataPipe` or `MapDataPipe`, but {} is found".format(type(datapipe)))
+        raise RuntimeError(f"Expected `IterDataPipe` or `MapDataPipe`, but {type(datapipe)} is found")

     dp_id = id(datapipe)
     if dp_id in cache:

@@ -102,7 +102,7 @@ def from_dlpack(ext_tensor: Any) -> 'torch.Tensor':
         # device is either CUDA or ROCm, we need to pass the current
         # stream
         if device[0] in (DLDeviceType.kDLGPU, DLDeviceType.kDLROCM):
-            stream = torch.cuda.current_stream('cuda:{}'.format(device[1]))
+            stream = torch.cuda.current_stream(f'cuda:{device[1]}')
             # cuda_stream is the pointer to the stream and it is a public
             # attribute, but it is not documented
             # The array API specify that the default legacy stream must be passed

@@ -46,7 +46,7 @@ if os.path.isfile(rocm_version_h):
     RE_MINOR = re.compile(r"#define\s+ROCM_VERSION_MINOR\s+(\d+)")
     RE_PATCH = re.compile(r"#define\s+ROCM_VERSION_PATCH\s+(\d+)")
     major, minor, patch = 0, 0, 0
-    for line in open(rocm_version_h, "r"):
+    for line in open(rocm_version_h):
         match = RE_MAJOR.search(line)
         if match:
             major = int(match.group(1))

@@ -219,13 +219,13 @@ def compute_stats(stats):
     unsupported_calls = {cuda_call for (cuda_call, _filepath) in stats["unsupported_calls"]}

     # Print the number of unsupported calls
-    print("Total number of unsupported CUDA function calls: {0:d}".format(len(unsupported_calls)))
+    print(f"Total number of unsupported CUDA function calls: {len(unsupported_calls):d}")

     # Print the list of unsupported calls
     print(", ".join(unsupported_calls))

     # Print the number of kernel launches
-    print("\nTotal number of replaced kernel launches: {0:d}".format(len(stats["kernel_launches"])))
+    print("\nTotal number of replaced kernel launches: {:d}".format(len(stats["kernel_launches"])))


 def add_dim3(kernel_string, cuda_kernel):

@@ -254,8 +254,8 @@ def add_dim3(kernel_string, cuda_kernel):
     first_arg_clean = kernel_string[arg_locs[0]['start']:arg_locs[0]['end']].replace("\n", "").strip(" ")
     second_arg_clean = kernel_string[arg_locs[1]['start']:arg_locs[1]['end']].replace("\n", "").strip(" ")

-    first_arg_dim3 = "dim3({})".format(first_arg_clean)
-    second_arg_dim3 = "dim3({})".format(second_arg_clean)
+    first_arg_dim3 = f"dim3({first_arg_clean})"
+    second_arg_dim3 = f"dim3({second_arg_clean})"

     first_arg_raw_dim3 = first_arg_raw.replace(first_arg_clean, first_arg_dim3)
     second_arg_raw_dim3 = second_arg_raw.replace(second_arg_clean, second_arg_dim3)

@@ -269,7 +269,7 @@ RE_KERNEL_LAUNCH = re.compile(r'([ ]+)(detail?)::[ ]+\\\n[ ]+')
 def processKernelLaunches(string, stats):
     """ Replace the CUDA style Kernel launches with the HIP style kernel launches."""
     # Concat the namespace with the kernel names. (Find cleaner way of doing this later).
-    string = RE_KERNEL_LAUNCH.sub(lambda inp: "{0}{1}::".format(inp.group(1), inp.group(2)), string)
+    string = RE_KERNEL_LAUNCH.sub(lambda inp: f"{inp.group(1)}{inp.group(2)}::", string)

     def grab_method_and_template(in_kernel):
         # The positions for relevant kernel components.

@@ -482,7 +482,7 @@ def replace_math_functions(input_string):
     """
     output_string = input_string
     for func in MATH_TRANSPILATIONS:
-        output_string = output_string.replace(r'{}('.format(func), '{}('.format(MATH_TRANSPILATIONS[func]))
+        output_string = output_string.replace(fr'{func}(', f'{MATH_TRANSPILATIONS[func]}(')

     return output_string

@@ -531,7 +531,7 @@ def replace_extern_shared(input_string):
     """
     output_string = input_string
     output_string = RE_EXTERN_SHARED.sub(
-        lambda inp: "HIP_DYNAMIC_SHARED({0} {1}, {2})".format(
+        lambda inp: "HIP_DYNAMIC_SHARED({} {}, {})".format(
             inp.group(1) or "", inp.group(2), inp.group(3)), output_string)

     return output_string

@@ -657,7 +657,7 @@ def is_caffe2_gpu_file(rel_filepath):


 # Cribbed from https://stackoverflow.com/questions/42742810/speed-up-millions-of-regex-replacements-in-python-3/42789508#42789508
-class Trie():
+class Trie:
     """Regex::Trie in Python. Creates a Trie out of a list of words. The trie can be exported to a Regex pattern.
     The corresponding Regex should match much faster than a simple Regex union."""

@@ -750,7 +750,7 @@ for mapping in CUDA_TO_HIP_MAPPINGS:
         CAFFE2_TRIE.add(src)
         CAFFE2_MAP[src] = dst
 RE_CAFFE2_PREPROCESSOR = re.compile(CAFFE2_TRIE.pattern())
-RE_PYTORCH_PREPROCESSOR = re.compile(r'(?<=\W)({0})(?=\W)'.format(PYTORCH_TRIE.pattern()))
+RE_PYTORCH_PREPROCESSOR = re.compile(fr'(?<=\W)({PYTORCH_TRIE.pattern()})(?=\W)')

 RE_QUOTE_HEADER = re.compile(r'#include "([^"]+)"')
 RE_ANGLE_HEADER = re.compile(r'#include <([^>]+)>')

@@ -789,7 +789,7 @@ def preprocessor(

     rel_filepath = os.path.relpath(filepath, output_directory)

-    with open(fin_path, 'r', encoding='utf-8') as fin:
+    with open(fin_path, encoding='utf-8') as fin:
         if fin.readline() == HIPIFY_C_BREADCRUMB:
             hipify_result.hipified_path = None
             hipify_result.status = "[ignored, input is hipified output]"

@@ -929,7 +929,7 @@ def preprocessor(

     do_write = True
     if os.path.exists(fout_path):
-        with open(fout_path, 'r', encoding='utf-8') as fout_old:
+        with open(fout_path, encoding='utf-8') as fout_old:
             do_write = fout_old.read() != output_source
     if do_write:
         try:

@@ -956,7 +956,7 @@ def file_specific_replacement(filepath, search_string, replace_string, strict=Fa
     with openf(filepath, "r+") as f:
         contents = f.read()
         if strict:
-            contents = re.sub(r'\b({0})\b'.format(re.escape(search_string)), lambda x: replace_string, contents)
+            contents = re.sub(fr'\b({re.escape(search_string)})\b', lambda x: replace_string, contents)
         else:
             contents = contents.replace(search_string, replace_string)
         f.seek(0)

@@ -968,8 +968,8 @@ def file_add_header(filepath, header):
     with openf(filepath, "r+") as f:
         contents = f.read()
         if header[0] != "<" and header[-1] != ">":
-            header = '"{0}"'.format(header)
-        contents = ('#include {0} \n'.format(header)) + contents
+            header = f'"{header}"'
+        contents = (f'#include {header} \n') + contents
         f.seek(0)
         f.write(contents)
         f.truncate()
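
Note the asymmetry in the compute_stats hunk above: the first print() becomes an f-string, while the second merely drops its explicit positional index ({0:d} to {:d}) and keeps .format(). A plausible reason is that its argument contains double quotes (stats["kernel_launches"]), which could not be nested inside a double-quoted f-string before Python 3.12, so only the conservative rewrite is safe. Illustrative:

    stats = {"kernel_launches": [None] * 3}
    msg = "\nTotal number of replaced kernel launches: {:d}".format(len(stats["kernel_launches"]))
    assert msg.strip().endswith("3")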

@@ -11,7 +11,7 @@ def extract_ir(filename: str) -> List[str]:
     pfx = None
     current = ""
     graphs = []
-    with open(filename, "r") as f:
+    with open(filename) as f:
         split_strs = f.read().split(BEGIN)
     for i, split_str in enumerate(split_strs):
         if i == 0:

@@ -31,7 +31,7 @@ def optimize_for_mobile(
     """
     if not isinstance(script_module, torch.jit.ScriptModule):
         raise TypeError(
-            'Got {}, but ScriptModule is expected.'.format(type(script_module)))
+            f'Got {type(script_module)}, but ScriptModule is expected.')

     if optimization_blocklist is None:
         optimization_blocklist = set()

@@ -86,7 +86,7 @@ def generate_mobile_module_lints(script_module: torch.jit.ScriptModule):
     """
     if not isinstance(script_module, torch.jit.ScriptModule):
         raise TypeError(
-            'Got {}, but ScriptModule is expected.'.format(type(script_module)))
+            f'Got {type(script_module)}, but ScriptModule is expected.')

     lint_list = []

@@ -232,7 +232,7 @@ def _add_gradient_scope(shapes, blob_name_tracker, ops):

     def f(name):
         if "_grad" in name:
-            return "GRADIENTS/{}".format(name)
+            return f"GRADIENTS/{name}"
         else:
             return name


@@ -317,7 +317,7 @@ def _tf_device(device_option):
     ):
         return "/cpu:*"
     if device_option.device_type == caffe2_pb2.CUDA:
-        return "/gpu:{}".format(device_option.device_id)
+        return f"/gpu:{device_option.device_id}"
     raise Exception("Unhandled device", device_option)

@@ -62,7 +62,7 @@ def make_sprite(label_img, save_path):

 def get_embedding_info(metadata, label_img, subdir, global_step, tag):
     info = EmbeddingInfo()
-    info.tensor_name = "{}:{}".format(tag, str(global_step).zfill(5))
+    info.tensor_name = f"{tag}:{str(global_step).zfill(5)}"
     info.tensor_path = _gfile_join(subdir, "tensors.tsv")
     if metadata is not None:
         info.metadata_path = _gfile_join(subdir, "metadata.tsv")

@@ -275,7 +275,7 @@ def parse(graph, trace, args=None, omit_useless_nodes=True):
                     parent_scope, attr_scope, attr_name
                 )
             else:
-                attr_to_scope[attr_key] = "__module.{}".format(attr_name)
+                attr_to_scope[attr_key] = f"__module.{attr_name}"
             # We don't need classtype nodes; scope will provide this information
             if node.output().type().kind() != CLASSTYPE_KIND:
                 node_py = NodePyOP(node)

@@ -286,7 +286,7 @@ def parse(graph, trace, args=None, omit_useless_nodes=True):

     for i, node in enumerate(graph.outputs()):  # Create sink nodes for output ops
         node_pyio = NodePyIO(node, "output")
-        node_pyio.debugName = "output.{}".format(i + 1)
+        node_pyio.debugName = f"output.{i + 1}"
         node_pyio.inputs = [node.debugName()]
         nodes_py.append(node_pyio)

@@ -302,7 +302,7 @@ def parse(graph, trace, args=None, omit_useless_nodes=True):
     for name, module in trace.named_modules(prefix="__module"):
         mod_name = parse_traced_name(module)
         attr_name = name.split(".")[-1]
-        alias_to_name[name] = "{}[{}]".format(mod_name, attr_name)
+        alias_to_name[name] = f"{mod_name}[{attr_name}]"

     for node in nodes_py.nodes_op:
         module_aliases = node.scopeName.split("/")

@@ -953,7 +953,7 @@ class SummaryWriter:

         # Maybe we should encode the tag so slashes don't trip us up?
         # I don't think this will mess us up, but better safe than sorry.
-        subdir = "%s/%s" % (str(global_step).zfill(5), self._encode(tag))
+        subdir = f"{str(global_step).zfill(5)}/{self._encode(tag)}"
         save_path = os.path.join(self._get_file_writer().get_logdir(), subdir)

         fs = tf.io.gfile

@@ -18,10 +18,10 @@ def format_time(time_us=None, time_ms=None, time_s=None):
         raise AssertionError("Shouldn't reach here :)")

     if time_us >= US_IN_SECOND:
-        return '{:.3f}s'.format(time_us / US_IN_SECOND)
+        return f'{time_us / US_IN_SECOND:.3f}s'
     if time_us >= US_IN_MS:
-        return '{:.3f}ms'.format(time_us / US_IN_MS)
-    return '{:.3f}us'.format(time_us)
+        return f'{time_us / US_IN_MS:.3f}ms'
+    return f'{time_us:.3f}us'


 class ExecutionStats:

@@ -52,8 +52,8 @@ class ExecutionStats:
     def __str__(self):
         return '\n'.join([
             "Average latency per example: " + format_time(time_ms=self.latency_avg_ms),
-            "Total number of iterations: {}".format(self.num_iters),
-            "Total number of iterations per second (across all threads): {:.2f}".format(self.iters_per_second),
+            f"Total number of iterations: {self.num_iters}",
+            f"Total number of iterations per second (across all threads): {self.iters_per_second:.2f}",
             "Total time: " + format_time(time_s=self.total_time_seconds)
         ])

@@ -220,29 +220,29 @@ def object_annotation(obj):
     if isinstance(obj, BASE_TYPES):
         return repr(obj)
     if type(obj).__name__ == 'function':
-        return "function\n{}".format(obj.__name__)
+        return f"function\n{obj.__name__}"
     elif isinstance(obj, types.MethodType):
         try:
             func_name = obj.__func__.__qualname__
         except AttributeError:
             func_name = "<anonymous>"
-        return "instancemethod\n{}".format(func_name)
+        return f"instancemethod\n{func_name}"
     elif isinstance(obj, list):
         return f"[{format_sequence(obj)}]"
     elif isinstance(obj, tuple):
         return f"({format_sequence(obj)})"
     elif isinstance(obj, dict):
-        return "dict[{}]".format(len(obj))
+        return f"dict[{len(obj)}]"
     elif isinstance(obj, types.ModuleType):
-        return "module\n{}".format(obj.__name__)
+        return f"module\n{obj.__name__}"
     elif isinstance(obj, type):
-        return "type\n{}".format(obj.__name__)
+        return f"type\n{obj.__name__}"
     elif isinstance(obj, weakref.ref):
         referent = obj()
         if referent is None:
             return "weakref (dead referent)"
         else:
-            return "weakref to id 0x{:x}".format(id(referent))
+            return f"weakref to id 0x{id(referent):x}"
     elif isinstance(obj, types.FrameType):
         filename = obj.f_code.co_filename
         if len(filename) > FRAME_FILENAME_LIMIT:

@@ -4,7 +4,6 @@ import weakref
 from weakref import ref
-from _weakrefset import _IterationGuard  # type: ignore[attr-defined]
 from collections.abc import MutableMapping, Mapping
 from typing import Dict
 from torch import Tensor
 import collections.abc as _collections_abc

@@ -83,7 +82,7 @@ class WeakIdRef(weakref.ref):

 # This is directly adapted from cpython/Lib/weakref.py
 class WeakIdKeyDictionary(MutableMapping):
-    data: Dict[WeakIdRef, object]
+    data: dict[WeakIdRef, object]

     def __init__(self, dict=None):
         self.data = {}

@@ -144,7 +143,7 @@ class WeakIdKeyDictionary(MutableMapping):
         return len(self.data) - len(self._pending_removals)

     def __repr__(self):
-        return "<%s at %#x>" % (self.__class__.__name__, id(self))
+        return f"<{self.__class__.__name__} at {id(self):#x}>"

     def __setitem__(self, key, value):
         self.data[WeakIdRef(key, self._remove)] = value  # CHANGED
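
The last hunk also swaps typing.Dict for the builtin generic dict in an annotation (UP006, per PEP 585). Builtin generics are subscriptable at runtime on Python 3.9+, and on older versions inside any module using from __future__ import annotations. A minimal sketch with a hypothetical class:

    from __future__ import annotations   # makes dict[...] safe pre-3.9 as well

    class Registry:
        data: dict[str, object]          # PEP 585 style, no typing.Dict import

        def __init__(self) -> None:
            self.data = {}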