[BE]: Apply PERF401 autofixes from ruff (#140980)
* Automatically applies ruff rule PERF401, turning loops into equivalent list comprehensions, which are faster and do not leak the loop variable into the enclosing scope (a minimal illustrative sketch of the pattern follows the commit metadata below).
* List comprehensions not only often type-check better, but are also 50+% faster than for loops in terms of loop overhead; they preserve length information and are easier for the interpreter to optimize.
* Manually went back and made mypy happy after the change.
* Also fixed style lints in files covered by flake8 but not by pyfmt.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/140980
Approved by: https://github.com/justinchuby, https://github.com/malfet
This commit is contained in:
parent 8d708090c0
commit 12e95aa4ee
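The rewrite applied throughout the hunks below is mechanical. The following is a minimal sketch only, not code from the PR; the names and numbers are illustrative, and the measured speedup will vary with the Python build and the work done per item, so the "50+%" figure is a rough expectation rather than a guarantee.

import timeit

data = list(range(10_000))

def with_loop():
    result = []
    for x in data:
        result.append(x * 2)  # one attribute lookup + method call per item, in interpreted bytecode
    return result

def with_comprehension():
    # PERF401-style rewrite: same result, built with the specialized LIST_APPEND
    # bytecode, and the loop variable does not leak into the enclosing scope.
    return [x * 2 for x in data]

if __name__ == "__main__":
    assert with_loop() == with_comprehension()
    print("loop          :", timeit.timeit(with_loop, number=1_000))
    print("comprehension :", timeit.timeit(with_comprehension, number=1_000))

When the target list already exists, the diff uses the second pattern visible in the hunks below, e.g. existing.extend(x * 2 for x in data), rather than creating a new list.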
@@ -67,8 +67,7 @@ def pretty_print_buckets(buckets: List[Bucket], bucket_bytes_cap: int):
     for idx, bucket in enumerate(reversed(buckets)):
         if len(bucket.params) > 0:
             rows.append((idx, bucket.size, bucket.params[0]))
-            for param in bucket.params[1:]:
-                rows.append((None, None, param))
+            rows.extend((None, None, param) for param in bucket.params[1:])
         if bucket.opcount_increased_to_capture_external_output > 0:
             extended_buckets.append(
                 (
@@ -1140,11 +1140,12 @@ def update_offsets(instructions) -> None:

 def debug_bytes(*args) -> str:
     index = range(max(map(len, args)))
-    result = []
-    for arg in (
-        [index] + list(args) + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
-    ):
-        result.append(" ".join(f"{x:03}" for x in arg))
+    result = [
+        " ".join(f"{x:03}" for x in arg)
+        for arg in [index]
+        + list(args)
+        + [[int(a != b) for a, b in zip(args[-1], args[-2])]]
+    ]

     return "bytes mismatch\n" + "\n".join(result)
@@ -478,10 +478,8 @@ class AutogradCompilerInstance:

     @staticmethod
     def get_all_nodes(args):
-        nodes = []
-        for n in args:
-            if type(n) is torch.fx.Node:  # filter out non-Node args, like None
-                nodes.append(n)
+        # filter out non-Node args, like None
+        nodes = [n for n in args if type(n) is torch.fx.Node]
         return nodes

     @staticmethod
@@ -671,13 +669,15 @@ class AutogradCompilerInstance:
         input_nodes_and_users = []
         input_nodes_and_users.extend(list(input_nodes))
         for input_node in input_nodes:
-            for user in list(input_node.users.keys()):
+            input_nodes_and_users.extend(
+                user
+                for user in list(input_node.users.keys())
                 if not (
                     user.op == "call_function"
                     and user.target == call_hook
                     and node.kwargs.get("hook_type", None) == "post_hook"
-                ):
-                    input_nodes_and_users.append(user)
+                )
+            )

         arg = max(input_nodes_and_users)  # last input users
         if (
@@ -2526,7 +2526,6 @@ def make_torch_function_mode_stack_guard(intial_stack):


 def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
-    duplicate_tensors = []
     global_scope = dict(guard_manager.global_scope)
     ids_to_source = collections.defaultdict(list)
     for tensor_source in guard_manager.no_tensor_aliasing_sources:  # type: ignore[attr-defined]
@@ -2534,9 +2533,9 @@ def recompilation_reason_for_no_tensor_aliasing_guard(guard_manager, scope):
         tensor_id = id(eval(tensor_source, global_scope, scope))
         ids_to_source[tensor_id].append(tensor_source)

-    for key in ids_to_source:
-        if len(ids_to_source[key]) > 1:
-            duplicate_tensors.append(f"{ids_to_source[key]}")
+    duplicate_tensors = [
+        f"{ids_to_source[key]}" for key in ids_to_source if len(ids_to_source[key]) > 1
+    ]

     reason = ", ".join(duplicate_tensors)
     return [f"Duplicate tensors found: {reason}"]
@@ -1463,9 +1463,7 @@ class OutputGraph:
         return compiled_fn

     def example_inputs(self) -> List[torch.Tensor]:
-        result = []
-        for arg in self.graphargs:
-            result.append(arg.example)
+        result = [arg.example for arg in self.graphargs]
         return result

     def remove_unused_graphargs(self) -> None:
@@ -252,8 +252,7 @@ def _filter_iter(l1, l2, cond):
 def _load_tuple_and_call(tup):
     insts: List[Instruction] = []
     _initial_push_null(insts)
-    for val in tup:
-        insts.append(create_instruction("LOAD_CONST", argval=val))
+    insts.extend(create_instruction("LOAD_CONST", argval=val) for val in tup)
     insts.extend(create_call_function(len(tup), False))
     return insts

@@ -95,9 +95,7 @@ def collect_results(
     results.append(buffers)
     for example in example_inputs:
         if isinstance(example, (tuple, list)):
-            for inp in example:
-                if isinstance(inp, torch.Tensor):
-                    results.append(inp.grad)
+            results.extend(inp.grad for inp in example if isinstance(inp, torch.Tensor))
         else:
             if isinstance(example, torch.Tensor):
                 results.append(example.grad)
@@ -1535,9 +1535,10 @@ def checkpoint_params(gm):
     rng_state = torch.clone(torch.random.get_rng_state())
     if torch.cuda.is_available():
         cuda_rng_state = torch.clone(torch.cuda.get_rng_state())
-    saved_state = []
-    for param in itertools.chain(gm.parameters(), gm.buffers()):
-        saved_state.append((param, param._version, torch.clone(param)))
+    saved_state = [
+        (param, param._version, torch.clone(param))
+        for param in itertools.chain(gm.parameters(), gm.buffers())
+    ]

     def restore():
         with torch.no_grad():
@@ -1765,10 +1765,9 @@ class BuiltinVariable(VariableTracker):
                 # tracked fakes to produce incorrect guards. This is sound because the TensorVariable
                 # coming out of set_() below will be a new one, and get
                 # installed in tracked fakes.
-                to_remove = []
-                for tf in tx.output.tracked_fakes:
-                    if tf.source == obj.source:
-                        to_remove.append(tf)
+                to_remove = [
+                    tf for tf in tx.output.tracked_fakes if tf.source == obj.source
+                ]
                 for tf in to_remove:
                     tx.output.tracked_fakes.remove(tf)

@@ -1026,17 +1026,16 @@ class SDPAKernelVariable(ContextWrappingVariable):

     @staticmethod
     def _backends_to_nodes(tx, backends):
-        nodes = []
-        for backend in backends:
-            # convert to/from string in order to bake the backend into FX graph
-            nodes.append(
-                tx.output.create_node(
-                    "call_function",
-                    torch.nn.attention._backend_from_string,
-                    (backend.name,),
-                    {},
-                )
-            )
+        # convert to/from string in order to bake the backend into FX graph
+        nodes = [
+            tx.output.create_node(
+                "call_function",
+                torch.nn.attention._backend_from_string,
+                (backend.name,),
+                {},
+            )
+            for backend in backends
+        ]
         return nodes

     def enter(self, tx):
@@ -48,9 +48,9 @@ class ItertoolsVariable(VariableTracker):
             and all(arg.has_unpack_var_sequence(tx) for arg in args)
         ):
             seqs = [arg.unpack_var_sequence(tx) for arg in args]
-            items = []
-            for item in itertools.product(*seqs):
-                items.append(variables.TupleVariable(list(item)))
+            items = [
+                variables.TupleVariable(list(item)) for item in itertools.product(*seqs)
+            ]
             return variables.ListIteratorVariable(
                 items, mutation_type=ValueMutationNew()
             )
@@ -776,14 +776,13 @@ class UserDefinedObjectVariable(UserDefinedVariable):
         ):
             assert self.source  # OrderedDict, dict subtypes must always have source
             assert not (args or kwargs)
-            items = []
             keys = self.call_method(tx, "keys", [], {})
-            for key in keys.force_unpack_var_sequence(tx):
-                items.append(
-                    TupleVariable(
-                        [key, self.odict_getitem(tx, key)],
-                    )
-                )
+            items = [
+                TupleVariable(
+                    [key, self.odict_getitem(tx, key)],
+                )
+                for key in keys.force_unpack_var_sequence(tx)
+            ]
             tx.output.guard_on_key_order.add(self.source.name())
             return TupleVariable(items)

@@ -833,9 +833,7 @@ class TS2FXGraphConverter:
             self._convert_prim_iterator(node)

     def _convert_prim_iterator(self, node: torch._C.Node):
-        output_list = []
-        for inp in node.inputs():
-            output_list.append(self.get_fx_value_by_ir_value(inp))
+        output_list = [self.get_fx_value_by_ir_value(inp) for inp in node.inputs()]

         output_name = node.output().debugName()
         self.name_to_node[output_name] = output_list
@@ -7,9 +7,8 @@ class StaticForLoop(torch.nn.Module):
     """

     def forward(self, x):
-        ret = []
-        for i in range(10):  # constant
-            ret.append(i + x)
+        # constant
+        ret = [i + x for i in range(10)]
         return ret

 example_args = (torch.randn(3, 2),)
@@ -1780,9 +1780,7 @@ class GraphModuleDeserializer(metaclass=Final):
             ) from e

         # Outputs: convert to a single `output` node.
-        outputs = []
-        for output in serialized_graph.outputs:
-            outputs.append(self.deserialize_graph_output(output))
+        outputs = [self.deserialize_graph_output(output) for output in serialized_graph.outputs]

         if serialized_graph.is_single_tensor_return:
             assert len(outputs) == 1
@@ -2149,9 +2147,7 @@ class GraphModuleDeserializer(metaclass=Final):
         if len(value) == 0:
             return []
         elif typ_ == "as_tensors":
-            result = []
-            for arg in value:
-                result.append(self.serialized_name_to_node[arg.name])
+            result = [self.serialized_name_to_node[arg.name] for arg in value]
             return result
         elif typ_ in ("as_ints", "as_floats", "as_bools", "as_strings"):
             # convert from serialized.python.types.List to python list
@@ -1306,8 +1306,9 @@ def merge_view_inputs(
                 mutated_input_info[inpt_idx].mutates_data
                 for inpt_idx in aliased_input_indices
             ):
-                for curr_idx in aliased_input_indices:
-                    other_args.append(fwd_inputs[curr_idx])
+                other_args.extend(
+                    fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
+                )
                 continue

             # Here, we attempt to do a more complicated check to detect false aliasing
@@ -1320,8 +1321,9 @@ def merge_view_inputs(
                 fwd_inputs, aliased_input_indices
             )
             if len(aliased_input_indices_no_false_sharing) <= 1:
-                for curr_idx in aliased_input_indices:
-                    other_args.append(fwd_inputs[curr_idx])
+                other_args.extend(
+                    fwd_inputs[curr_idx] for curr_idx in aliased_input_indices
+                )
                 continue

             # We detected an input that was mutated, AND aliases with another input.
@@ -217,9 +217,7 @@ def trace_map(proxy_mode, func_overload, f, xs, pos_args):

 @map_impl.py_impl(DispatchKey.CompositeExplicitAutograd)
 def map_dense(f, xs, pos_args):
-    pytrees = []
-    for inp in _unstack_pytree(xs):
-        pytrees.append(f(*inp, *pos_args))
+    pytrees = [f(*inp, *pos_args) for inp in _unstack_pytree(xs)]
     return _stack_pytree(pytrees)

@@ -543,8 +543,7 @@ def analyze_kernel_mutations(
             )
             stack.extend(arg for arg, mutated in zip(op.args, mutations) if mutated)
         else:
-            for idx in MUTATION_OPS.get(op.name, []):
-                stack.append(op.args[idx])
+            stack.extend(op.args[idx] for idx in MUTATION_OPS.get(op.name, []))

     # The following is an iterative DFS algorithm
     mutated = [False] * num_args
@@ -388,9 +388,7 @@ def _unstack_pytree(xs):

     a = zip(*flat_xs)

-    pytrees = []
-    for tuple in a:
-        pytrees.append(pytree.tree_unflatten(tuple, inspec))
+    pytrees = [pytree.tree_unflatten(tuple, inspec) for tuple in a]
     return pytrees

@@ -2102,13 +2102,14 @@ class AotCodeCompiler:
         aot_constants = struct.pack("qq", consts_size + 8, magic_number)

         consts_o = _compile_consts(aot_constants, sys.platform)
-        kernels_o = []
         gpu_codecache: Union[ROCmCodeCache, CUDACodeCache] = (
             ROCmCodeCache() if torch.version.hip else CUDACodeCache()
         )
-        for entry in gpu_codecache.cache.values():
-            if entry.output_path.endswith(".o"):
-                kernels_o.append(entry.output_path)
+        kernels_o = [
+            entry.output_path
+            for entry in gpu_codecache.cache.values()
+            if entry.output_path.endswith(".o")
+        ]
         kernels_o = " ".join(kernels_o)

         output_name, output_dir = get_name_and_dir_from_output_file_path(output_so)
@@ -4800,10 +4800,9 @@ class LoopLevel:

     def split_with_tiling(self, depth, factor):
         def clone_inner():
-            inner = []
+            inner: List[LoopLevel] = []
             if self.inner:
-                for loop in self.inner:
-                    inner.append(loop.clone())
+                inner.extend(loop.clone() for loop in self.inner)
             return inner

         def do_split_with_tiling():
@@ -150,18 +150,16 @@ class DebugPrinterManager:
         # get the list of args_to_print_or_save
         # TODO: Find a more reliable way to detect kernel args types to print for extern kernel calls
         if kernel_type == "extern":
-            args_to_print_or_save_extern = []
-            for arg in args_to_print_or_save:
-                if arg.startswith(("buf", "arg")):
-                    args_to_print_or_save_extern.append(arg)
+            args_to_print_or_save_extern = [
+                arg for arg in args_to_print_or_save if arg.startswith(("buf", "arg"))
+            ]
             self.args_to_print_or_save = args_to_print_or_save_extern
         elif kernel_type == "cpp":
-            args_to_print_or_save_cpp = []
-            for arg in args_to_print_or_save:
-                if arg.startswith(("buf", "arg")):
-                    args_to_print_or_save_cpp.append(
-                        f"convert_arrayref_tensor_to_tensor({arg})"
-                    )
+            args_to_print_or_save_cpp = [
+                f"convert_arrayref_tensor_to_tensor({arg})"
+                for arg in args_to_print_or_save
+                if arg.startswith(("buf", "arg"))
+            ]
             self.args_to_print_or_save = args_to_print_or_save_cpp
         else:
             self.args_to_print_or_save = args_to_print_or_save
@@ -1392,19 +1392,19 @@ class HalideKernel(SIMDKernel):
             result.append((call_str, arg))
             if isinstance(arg, TensorArg):
                 assert arg.offset == 0 and arg.alias_of is None
-                for alias in self.buffer_aliases.get(arg.name, ()):
-                    result.append(
-                        (
-                            None,
-                            TensorArg(
-                                alias,
-                                arg.buffer,
-                                arg.dtype,
-                                arg.offset,
-                                alias_of=arg.name,
-                            ),
-                        )
-                    )
+                result.extend(
+                    (
+                        None,
+                        TensorArg(
+                            alias,
+                            arg.buffer,
+                            arg.dtype,
+                            arg.offset,
+                            alias_of=arg.name,
+                        ),
+                    )
+                    for alias in self.buffer_aliases.get(arg.name, ())
+                )
         return result

     def halide_kernel_meta(self) -> HalideMeta:
@@ -72,10 +72,11 @@ def get_all_call_args(call_args_list, arg_types_list):


 def get_numel_argdefs(kernel):
-    numel_argdefs = []
-    for tree in kernel.range_trees:
-        if tree.prefix != "r" or kernel.inside_reduction:
-            numel_argdefs.append(f"{tree.prefix}numel")
+    numel_argdefs = [
+        f"{tree.prefix}numel"
+        for tree in kernel.range_trees
+        if tree.prefix != "r" or kernel.inside_reduction
+    ]

     return numel_argdefs

@@ -76,7 +76,7 @@ def _default_custom_combo_kernel_horizontal_partition(
     tilings = [node_info_map[n][1] for n in nodes]

     max_dims = max(len(t) for t in tilings)
-    nodes_per_ndim = []
+    nodes_per_ndim: List[List[BaseSchedulerNode]] = []
     for i in range(2, max_dims + 1):
         group_per_dim = [n for n, t in zip(nodes, tilings) if len(t) == i]
         reduction = [
@@ -111,12 +111,11 @@ def _default_custom_combo_kernel_horizontal_partition(
                 len(large_pointwise),
             )
             not_reduction = [n for n in not_reduction if n not in large_pointwise]
-            for node in large_pointwise:
-                nodes_per_ndim.append([node])
+            nodes_per_ndim.extend([node] for node in large_pointwise)

-        for g in (not_reduction, short_reduction, long_reduction):
-            if g:
-                nodes_per_ndim.append(g)
+        nodes_per_ndim.extend(
+            g for g in (not_reduction, short_reduction, long_reduction) if g
+        )

     assert sum(len(p) for p in nodes_per_ndim) == len(nodes)
     return nodes_per_ndim
@@ -288,8 +288,7 @@ def get_compiler_version_info(compiler: str) -> str:

 # =============================== cpp builder ===============================
 def _append_list(dest_list: List[str], src_list: List[str]) -> None:
-    for item in src_list:
-        dest_list.append(copy.deepcopy(item))
+    dest_list.extend(copy.deepcopy(item) for item in src_list)


 def _remove_duplication_in_list(orig_list: List[str]) -> List[str]:
@@ -742,12 +741,11 @@ def _setup_standard_sys_libs(


 def _get_build_args_of_chosen_isa(vec_isa: VecISA) -> Tuple[List[str], List[str]]:
-    macros = []
-    build_flags = []
+    macros: List[str] = []
+    build_flags: List[str] = []
     if vec_isa != invalid_vec_isa:
         # Add Windows support later.
-        for x in vec_isa.build_macro():
-            macros.append(copy.deepcopy(x))
+        macros.extend(copy.deepcopy(x) for x in vec_isa.build_macro())

         build_flags = [vec_isa.build_arch_flags()]

@@ -398,9 +398,11 @@ def valid_vec_isa_list() -> List[VecISA]:
     arch value is x86_64 on Linux, and the value is AMD64 on Windows.
     """
     _cpu_supported_x86_isa = x86_isa_checker()
-    for isa in supported_vec_isa_list:
-        if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa:
-            isa_list.append(isa)
+    isa_list.extend(
+        isa
+        for isa in supported_vec_isa_list
+        if all(flag in _cpu_supported_x86_isa for flag in str(isa).split()) and isa
+    )

     return isa_list

@@ -621,13 +621,12 @@ class CUDAWarmupNode:
         }

         def get_non_cudagraph_inps() -> List[weakref.ReferenceType[UntypedStorage]]:
-            non_cudagraph_inps = []
-            for t in itertools.chain(new_inputs, self.wrapped_function.constants):
-                if (
-                    isinstance(t, torch.Tensor)
-                    and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
-                ):
-                    non_cudagraph_inps.append(weakref.ref(t.untyped_storage()))
+            non_cudagraph_inps = [
+                weakref.ref(t.untyped_storage())
+                for t in itertools.chain(new_inputs, self.wrapped_function.constants)
+                if isinstance(t, torch.Tensor)
+                and t.untyped_storage().data_ptr() not in existing_path_data_ptrs
+            ]
             return non_cudagraph_inps

         non_cudagraph_inps_storages = get_non_cudagraph_inps()
@@ -1709,12 +1708,10 @@ def get_block_addrs(pool_id: Tuple[int, int], live_only: bool = True) -> List[in


 def format_tb(frames: List[Any]) -> str:
-    formatted_traceback = []
-
-    for entry in frames:
-        formatted_traceback.append(
-            traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
-        )
+    formatted_traceback = [
+        traceback.FrameSummary(entry["filename"], entry["line"], entry["name"])
+        for entry in frames
+    ]

     return "".join(traceback.format_list(formatted_traceback))
@@ -127,9 +127,7 @@ class MemoryDep(Dep):
             return None

         stride_to_index = {s: i for i, s in enumerate(self_strides)}
-        order = []
-        for s in other_strides:
-            order.append(stride_to_index[s])
+        order = [stride_to_index[s] for s in other_strides]

         assert set(order) == set(range(0, self.num_vars))
         return order
@@ -547,9 +545,7 @@ def var_builder(prefix: str) -> Tuple[VarRanges, Callable[[sympy.Expr], sympy.Sy

 def index_vars_no_squeeze(*argsizes: Tuple[sympy.Expr, ...], prefix: str):
     var_ranges, add_var = var_builder(prefix)
-    args: List[List[sympy.Symbol]] = []
-    for size in argsizes:
-        args.append(list(map(add_var, size)))
+    args: List[List[sympy.Symbol]] = [list(map(add_var, size)) for size in argsizes]
     return args, var_ranges

@@ -1120,9 +1120,9 @@ class UnbindCatRemover(SplitCatSimplifier):
             return
         # we need to check if the getitem indices from unbind are consecutive and all go to the same cat node
         # before we do the unbind remove, otherwise it will hit the error when we unbind part of them
-        getitem_indices = []
-        for getitem_node in unbind_node.users.keys():
-            getitem_indices.append(getitem_node.args[1])
+        getitem_indices = [
+            getitem_node.args[1] for getitem_node in unbind_node.users.keys()
+        ]
         if not is_sorted_and_consecutive(getitem_indices) or len(  # type: ignore[arg-type]
             getitem_indices
         ) != len(
@@ -1497,9 +1497,8 @@ def merge_getitem_cat(match: Match, split_sections: List[int], dim: int):
         ):
             continue
         # find the index of getitems to be cated/stacked
-        indices = []
-        for arg in cat_user.args[0]:  # type: ignore[union-attr]
-            indices.append(arg.args[1])  # type: ignore[union-attr]
+        # type: ignore[union-attr]
+        indices = [arg.args[1] for arg in cat_user.args[0]]  # type: ignore[union-attr]
         # the gettitems to be merged must be consecutive, otherwise
         # returned sliced tensor could be wrong
         if not is_sorted_and_consecutive(indices):
@@ -1657,15 +1657,14 @@ class GraphLowering(torch.fx.Interpreter):
             new_unbacked_defs |= op.get_unbacked_symbol_defs()

         def format_new_defs() -> str:
-            r = []
-            for buf in self.buffers[buffer_watermark:]:
-                r.append(
-                    f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
-                )
-            for op in self.operations[operation_watermark:]:
-                r.append(
-                    f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
-                )
+            r = [
+                f"unbacked_symbol_defs={buf.get_unbacked_symbol_defs()} in:\n{buf}\n"
+                for buf in self.buffers[buffer_watermark:]
+            ]
+            r.extend(
+                f"unbacked_symbol_defs={op.get_unbacked_symbol_defs()} in:\n{op}\n"
+                for op in self.operations[operation_watermark:]
+            )
             return "***\n".join(r)

         if n.op != "placeholder":
@@ -5484,13 +5484,14 @@ class UserDefinedTritonKernel(ExternKernel):

         kernel = kernel_side_table.get_kernel(self.kernel_idx)
         configs = []
-        restore_value_args = []
+        restore_value_args: List[str] = []
         if isinstance(kernel, Autotuner):
             # https://github.com/triton-lang/triton/pull/5083
             # changes kernel.restore_idx to kernel.restore_value
             if hasattr(kernel, "restore_idx"):
-                for i in kernel.restore_idx:
-                    restore_value_args.append(kernel.fn.arg_names[i])
+                restore_value_args.extend(
+                    kernel.fn.arg_names[i] for i in kernel.restore_idx
+                )
             else:
                 assert hasattr(kernel, "restore_value")
                 restore_value_args.extend(kernel.restore_value)
@@ -1691,9 +1691,7 @@ def split_with_sizes(x, sizes, dim=0):
 def unbind(x, dim=0):
     dim = _validate_dim(x, dim, 0)
     x_size = V.graph.sizevars.evaluate_static_shape(x.get_size()[dim])
-    result = []
-    for i in range(x_size):
-        result.append(select(x, dim, i))
+    result = [select(x, dim, i) for i in range(x_size)]
     return result

@@ -87,8 +87,7 @@ def _prepare_convolution_fusion_create(
             weight_size = []
             weight_size.append(prepacked_weight_size[1] * groups)
             weight_size.append(prepacked_weight_size[0] / groups)
-            for d in range(2, dim):
-                weight_size.append(prepacked_weight_size[d])
+            weight_size.extend(prepacked_weight_size[d] for d in range(2, dim))
         else:
             weight_size = prepacked_weight.transpose(0, 1).size()
         return weight_size
@@ -952,9 +952,10 @@ class PatternPrettyPrinter:
         assert hasattr(obj, "pretty_print")
         out_str = obj.pretty_print(pp=pp)

-        output = []
-        for key in pp.memoized_objs_names:
-            output.append(f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}")
+        output = [
+            f"{pp.memoized_objs_names[key]} = {pp.memoized_objs_pp[key]}"
+            for key in pp.memoized_objs_names
+        ]

         output.append(f"{output_name} = {out_str}")

@@ -1361,9 +1362,7 @@ def register_replacement(
         return False

     def normalize_args(**kwargs: Any) -> List[Any]:
-        args = []
-        for name in argnames_static:
-            args.append(kwargs.pop(name))
+        args = [kwargs.pop(name) for name in argnames_static]
         for i in range(1, len(kwargs) + 1):
             if f"tangents_{i}" not in kwargs:
                 break
@@ -120,7 +120,7 @@ def autotune_hints_to_configs(
     configs to try.
     """
     xyz_options: Tuple[Tuple[int, Optional[int], Optional[int]], ...]
-    configs = []
+    configs: List[Config] = []
     warp_size = device_props.warp_size
     # CPU target has no concept of "warp"
     if warp_size is None:
@@ -138,16 +138,16 @@ def autotune_hints_to_configs(
                     (1, block_size // 4, 1),
                     (1, 1, block_size // 4),
                 )
-            for xyz in xyz_options:
-                configs.append(
-                    triton_config(
-                        size_hints,
-                        *xyz,
-                        num_elements_per_warp=(
-                            device_props.warp_size if device_props.warp_size else 32
-                        ),
-                    )
-                )
+            configs.extend(
+                triton_config(
+                    size_hints,
+                    *xyz,
+                    num_elements_per_warp=(
+                        device_props.warp_size if device_props.warp_size else 32
+                    ),
+                )
+                for xyz in xyz_options
+            )

     return configs

@@ -308,9 +308,13 @@ class BaseSchedulerNode:
             dep = deps.pop()
             used_names.add(dep)
             if V.graph.name_to_buffer.get(dep):
-                for alias in V.graph.name_to_buffer[dep].get_inputs_that_alias_output():
-                    if alias not in used_names:
-                        deps.append(alias)
+                deps.extend(
+                    alias
+                    for alias in V.graph.name_to_buffer[
+                        dep
+                    ].get_inputs_that_alias_output()
+                    if alias not in used_names
+                )
         return used_names

     def prune_deps(self) -> None:
@@ -3324,10 +3328,11 @@ class Scheduler:
             node1 = node2
             node2 = tmp

-        deps = []
-        for dep in node1.read_writes.reads | node1.read_writes.writes:
-            if dep in node2.read_writes.reads or dep in node2.read_writes.writes:
-                deps.append(dep)
+        deps = [
+            dep
+            for dep in node1.read_writes.reads | node1.read_writes.writes
+            if dep in node2.read_writes.reads or dep in node2.read_writes.writes
+        ]

         return sum(self.dep_size_hint(dep) for dep in deps)
@@ -176,14 +176,12 @@ def derived_types(
     ]

     if list_base:
-        for seq_typ in derived_seq_types(base_type):
-            result.append((seq_typ, f"{cpp_type}[]"))  # type: ignore[valid-type]
+        result.extend((seq_typ, f"{cpp_type}[]") for seq_typ in derived_seq_types(base_type))  # type: ignore[valid-type]
     if optional_base_list:
-        for seq_typ in derived_seq_types(typing.Optional[base_type]):
-            result.append((seq_typ, f"{cpp_type}?[]"))  # type: ignore[valid-type]
+        result.extend((seq_typ, f"{cpp_type}?[]") for seq_typ in derived_seq_types(typing.Optional[base_type]))  # type: ignore[valid-type]
     if optional_list_base:
-        for seq_typ in derived_seq_types(base_type):  # type: ignore[valid-type]
-            result.append((typing.Optional[seq_typ], f"{cpp_type}[]?"))  # type: ignore[valid-type]
+        # type: ignore[valid-type]
+        result.extend((typing.Optional[seq_typ], f"{cpp_type}[]?") for seq_typ in derived_seq_types(base_type))  # type: ignore[valid-type]
     return result

@@ -279,17 +279,17 @@ def requires_set_python_module() -> bool:

 def handle_dispatch_mode(curr_mode, op_overload, *args, **kwargs):
     assert isinstance(curr_mode, torch.utils._python_dispatch.TorchDispatchMode)
-    overload_types = []
     args_flattened, _ = torch.utils._pytree.tree_flatten((args, kwargs.values()))
-    for a in args_flattened:
-        # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
-        # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
-        # where in one case we only include tensors with the python key, and in another
-        # we include **all** tensors.
-        if isinstance(a, torch.Tensor) and torch._C._dispatch_keys(a).has(
-            torch._C.DispatchKey.Python
-        ):
-            overload_types.append(type(a))
+    # TODO: need to double check the semantics of the "types" argument to torch_dispatch.
+    # It's generated in PyInterpreter.cpp, but seems to be generated in two places,
+    # where in one case we only include tensors with the python key, and in another
+    # we include **all** tensors.
+    overload_types = [
+        type(a)
+        for a in args_flattened
+        if isinstance(a, torch.Tensor)
+        and torch._C._dispatch_keys(a).has(torch._C.DispatchKey.Python)
+    ]
     # TODO: check that I got these args correct (in C++, we pass in "0000"??)

     return curr_mode.__torch_dispatch__(op_overload, overload_types, args, kwargs)
@@ -43,15 +43,14 @@ def dump_file(filename: str) -> None:


 def from_traceback(tb: Sequence[traceback.FrameSummary]) -> List[Dict[str, Any]]:
-    r = []
-    for frame in tb:
-        # dict naming convention here coincides with
-        # python/combined_traceback.cpp
-        r.append(
-            {
-                "line": frame.lineno,
-                "name": frame.name,
-                "filename": intern_string(frame.filename),
-            }
-        )
+    # dict naming convention here coincides with
+    # python/combined_traceback.cpp
+    r = [
+        {
+            "line": frame.lineno,
+            "name": frame.name,
+            "filename": intern_string(frame.filename),
+        }
+        for frame in tb
+    ]
     return r
@@ -312,10 +312,11 @@ def _make_prim(
         prim_autograd_impl.impl(name, _autograd_impl)
         prim_meta_impl.impl(name, meta)
     else:
-        mutates_args = []
-        for arg in cpp_schema.arguments:
-            if arg.alias_info is not None and arg.alias_info.is_write:
-                mutates_args.append(arg.name)
+        mutates_args = [
+            arg.name
+            for arg in cpp_schema.arguments
+            if arg.alias_info is not None and arg.alias_info.is_write
+        ]
         prim_def = torch.library.custom_op(
             "prims::" + name,
             _prim_impl,
@@ -3055,9 +3055,7 @@ def chunk(a: TensorLikeType, chunks: int, dim: int = 0) -> Tuple[TensorLikeType,
     full_chunks = math.floor(length / chunk_size)
     tail_chunk_size = length % chunk_size

-    result = []
-    for i in range(full_chunks):
-        result.append(narrow(a, dim, i * chunk_size, chunk_size))
+    result = [narrow(a, dim, i * chunk_size, chunk_size) for i in range(full_chunks)]

     if tail_chunk_size != 0:
         result.append(narrow(a, dim, full_chunks * chunk_size, tail_chunk_size))
@@ -585,14 +585,13 @@ def has_meta(func):
     lambda func: is_builtin(func) and "foreach" in func.name() and has_meta(func)
 )
 def foreach_run_and_map_input_device(fake_mode, func, *args, **kwargs):
-    tensor_lists = []
-    for arg in itertools.chain(args, kwargs.values()):
-        if (
-            isinstance(arg, (list, tuple))
-            and len(arg)
-            and isinstance(arg[0], torch.Tensor)
-        ):
-            tensor_lists.append(arg)
+    tensor_lists = [
+        arg
+        for arg in itertools.chain(args, kwargs.values())
+        if isinstance(arg, (list, tuple))
+        and len(arg)
+        and isinstance(arg[0], torch.Tensor)
+    ]

     try:
         with in_kernel_invocation_manager(fake_mode):
@@ -171,8 +171,7 @@ def get_plain_tensors(
             continue

         inner_keys, _ = curr.__tensor_flatten__()
-        for key in reversed(inner_keys):
-            todo.append(getattr(curr, key))
+        todo.extend(getattr(curr, key) for key in reversed(inner_keys))

     return out

@@ -1629,13 +1628,12 @@ class FakeTensorMode(TorchDispatchMode):
             )

         if isinstance(output, tuple):
-            output_infos = []
-            for out_elem in output:
-                output_infos.append(
-                    self._get_output_info_for_cache_entry(
-                        state, key, func, args, kwargs, out_elem
-                    )
-                )
+            output_infos = [
+                self._get_output_info_for_cache_entry(
+                    state, key, func, args, kwargs, out_elem
+                )
+                for out_elem in output
+            ]
             return _DispatchCacheEntry(
                 output_infos=tuple(output_infos), is_output_tuple=True
             )
@@ -1727,17 +1725,16 @@ class FakeTensorMode(TorchDispatchMode):
         """

         if entry.is_output_tuple:
-            outputs = []
-            for output_info in entry.output_infos:
-                outputs.append(
-                    self._get_output_tensor_from_cache_entry(
-                        state,
-                        output_info,
-                        key,
-                        func,
-                        args,
-                    )
-                )
+            outputs = [
+                self._get_output_tensor_from_cache_entry(
+                    state,
+                    output_info,
+                    key,
+                    func,
+                    args,
+                )
+                for output_info in entry.output_infos
+            ]
             return tuple(outputs)
         else:
             return self._get_output_tensor_from_cache_entry(
@@ -389,17 +389,17 @@ class LSTM(torch.nn.Module):
                 **factory_kwargs,
             )
         ]
-        for layer in range(1, num_layers):
-            layers.append(
-                _LSTMLayer(
-                    self.hidden_size,
-                    self.hidden_size,
-                    self.bias,
-                    batch_first=False,
-                    bidirectional=self.bidirectional,
-                    **factory_kwargs,
-                )
-            )
+        layers.extend(
+            _LSTMLayer(
+                self.hidden_size,
+                self.hidden_size,
+                self.bias,
+                batch_first=False,
+                bidirectional=self.bidirectional,
+                **factory_kwargs,
+            )
+            for layer in range(1, num_layers)
+        )
         self.layers = torch.nn.ModuleList(layers)

     def forward(self, x: Tensor, hidden: Optional[Tuple[Tensor, Tensor]] = None):
@@ -32,8 +32,7 @@ def _reverse_repeat_padding(padding: List[int]) -> List[int]:
     _reversed_padding_repeated_twice: List[int] = []
     N = len(padding)
     for idx in range(N):
-        for _ in range(2):
-            _reversed_padding_repeated_twice.append(padding[N - idx - 1])
+        _reversed_padding_repeated_twice.extend(padding[N - idx - 1] for _ in range(2))
     return _reversed_padding_repeated_twice

@@ -568,9 +568,9 @@ def create_one_transformed_and_logged_copy_of_subgraph(
             and len(arg)
             and isinstance(arg[0], Node)
         ):
-            for inner_arg in arg:
-                if isinstance(inner_arg, Node):
-                    new_args.append(inner_arg)
+            new_args.extend(
+                inner_arg for inner_arg in arg if isinstance(inner_arg, Node)
+            )

     new_kwargs = {}
     for name, old_kwarg in first_node.kwargs.items():
@@ -312,10 +312,7 @@ def get_arg_indices_of_inputs_to_log(node: Node) -> List[int]:
         node.target in (torch.add, torch.ops.quantized.add, operator.add)
         or node.target in (torch.mul, torch.ops.quantized.mul, operator.mul)
     ):
-        result = []
-        for i in range(2):
-            if type(node.args[i]) == Node:
-                result.append(i)
+        result = [i for i in range(2) if type(node.args[i]) == Node]
         return result
     return [0]

@@ -191,8 +191,7 @@ def expand_groups_in_paired_modules_list(paired_modules_list):
         elif len(group) == 2:
             new_list.append(group)
         elif len(group) > 2:
-            for i in range(len(group) - 1):
-                new_list.append([group[i], group[i + 1]])
+            new_list.extend([group[i], group[i + 1]] for i in range(len(group) - 1))

     return new_list

@@ -155,14 +155,14 @@ def _get_binary_op_configs(
             (op_with_quantized_bop_scalar_variant, torch.relu),
             op_with_quantized_bop_scalar_variant,
         ]
-        for bop_pattern in bop_patterns:
-            binary_op_configs.append(
-                BackendPatternConfig(bop_pattern)
-                .set_dtype_configs(dtype_configs)  # noqa: E131
-                ._set_num_tensor_args_to_observation_type(
-                    num_tensor_args_to_observation_type_mapping
-                )
-            )
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )
     # matmul
     binary_op_configs.append(
         BackendPatternConfig(torch.matmul).set_dtype_configs(
@@ -502,7 +502,6 @@ def _get_ln_configs(dtype_configs: List[DTypeConfig]) -> List[BackendPatternConf
 def _get_default_op_configs(
     dtype_configs: List[DTypeConfig],
 ) -> List[BackendPatternConfig]:
-    configs = []
     default_ops = [
         torch.nn.ELU,
         torch.nn.LeakyReLU,
@@ -517,14 +516,14 @@ def _get_default_op_configs(
         torch.nn.functional.leaky_relu,
         torch.nn.functional.dropout,
     ]
-    for op in default_ops:
-        configs.append(
-            BackendPatternConfig(op)
-            .set_observation_type(
-                ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
-            )  # noqa: E131
-            .set_dtype_configs(dtype_configs)
-        )
+    configs = [
+        BackendPatternConfig(op)
+        .set_observation_type(
+            ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT
+        )  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        for op in default_ops
+    ]

     configs.append(
         BackendPatternConfig(torch.nn.functional.group_norm)
@@ -159,14 +159,14 @@ def get_binary_op_configs():
             # TODO: remove when functionalization is supported in pt2_mode
             (op_with_quantized_bop_scalar_variant, torch.ops.aten.relu_.default),
         ]
-        for bop_pattern in bop_patterns:
-            binary_op_configs.append(
-                BackendPatternConfig(bop_pattern)
-                .set_dtype_configs(dtype_configs)  # noqa: E131
-                ._set_num_tensor_args_to_observation_type(
-                    num_tensor_args_to_observation_type_mapping
-                )
-            )
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )

     return binary_op_configs

@@ -325,14 +325,14 @@ def _get_binary_ops_configs() -> List[BackendPatternConfig]:
             (op, torch.relu),
             op,
         ]
-        for bop_pattern in bop_patterns:
-            binary_op_configs.append(
-                BackendPatternConfig(bop_pattern)
-                .set_dtype_configs(dtype_configs)  # noqa: E131
-                ._set_num_tensor_args_to_observation_type(
-                    num_tensor_args_to_observation_type_mapping
-                )
-            )
+        binary_op_configs.extend(
+            BackendPatternConfig(bop_pattern)
+            .set_dtype_configs(dtype_configs)  # noqa: E131
+            ._set_num_tensor_args_to_observation_type(
+                num_tensor_args_to_observation_type_mapping
+            )
+            for bop_pattern in bop_patterns
+        )
     return binary_op_configs

@@ -385,13 +385,12 @@ def _get_share_qparams_ops_configs() -> List[BackendPatternConfig]:
         "squeeze_",
         "leaky_relu",
     ]
-    share_qparams_op_configs: List[BackendPatternConfig] = []
-    for op in share_qparams_ops:
-        share_qparams_op_configs.append(
-            BackendPatternConfig(op)
-            .set_observation_type(observation_type)  # noqa: E131
-            .set_dtype_configs(dtype_configs)
-        )
+    share_qparams_op_configs: List[BackendPatternConfig] = [
+        BackendPatternConfig(op)
+        .set_observation_type(observation_type)  # noqa: E131
+        .set_dtype_configs(dtype_configs)
+        for op in share_qparams_ops
+    ]
     return share_qparams_op_configs

@@ -36,18 +36,20 @@ def get_pattern_to_dtype_configs(


 def get_qat_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
-    qat_module_classes = []
-    for config in backend_config.configs:
-        if config.qat_module is not None:
-            qat_module_classes.append(config.qat_module)
+    qat_module_classes = [
+        config.qat_module
+        for config in backend_config.configs
+        if config.qat_module is not None
+    ]
     return tuple(set(qat_module_classes))


 def get_fused_module_classes(backend_config: BackendConfig) -> Tuple[type, ...]:
-    fused_module_classes = []
-    for config in backend_config.configs:
-        if config.fused_module is not None:
-            fused_module_classes.append(config.fused_module)
+    fused_module_classes = [
+        config.fused_module
+        for config in backend_config.configs
+        if config.fused_module is not None
+    ]
     return tuple(set(fused_module_classes))

@@ -92,9 +92,7 @@ def _fuse_modules_helper(
     additional_fuser_method_mapping = fuse_custom_config_dict.get(
         "additional_fuser_method_mapping", {}
     )
-    mod_list = []
-    for item in modules_to_fuse:
-        mod_list.append(_get_module(model, item))
+    mod_list = [_get_module(model, item) for item in modules_to_fuse]

     # Fuse list of modules
     new_mod_list = fuser_func(mod_list, is_qat, additional_fuser_method_mapping)
@@ -266,9 +266,7 @@ def _get_valid_patterns(op_pattern):
     """
     result: List[Any]
     if isinstance(op_pattern, (tuple, list)):
-        sub_combs = []
-        for sub_pattern in op_pattern:
-            sub_combs.append(_get_valid_patterns(sub_pattern))
+        sub_combs = [_get_valid_patterns(sub_pattern) for sub_pattern in op_pattern]
         result = list(itertools.product(*sub_combs))
     else:
         result = [op_pattern, MatchAllNode]
@ -527,8 +527,9 @@ class ModelReportVisualizer:
# gather the x_data and multiple y_data
# gather the x_data and multiple y_data
# calculate the number of channels
# calculate the number of channels
num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
num_channels: int = max(row[self.CHANNEL_NUM_INDEX] for row in table) + 1
for channel in range(num_channels):
y_data.extend(
y_data.append([]) # separate data list per channel
[] for channel in range(num_channels)
) # separate data list per channel
for table_row_num, row in enumerate(table):
for table_row_num, row in enumerate(table):
# get x_value to append
# get x_value to append
@ -1116,10 +1116,9 @@ def convert(
# for dynamic quant ops or weight only quant ops
# for dynamic quant ops or weight only quant ops
_run_weight_observers(model, backend_config)
_run_weight_observers(model, backend_config)
graph_inputs: List[str] = []
graph_inputs: List[str] = [
for node in model.graph.nodes:
node.name for node in model.graph.nodes if node.op == "placeholder"
if node.op == "placeholder":
]
graph_inputs.append(node.name)
# additional state to override inputs to be quantized, if specified
# additional state to override inputs to be quantized, if specified
# by the user
# by the user
@ -78,8 +78,7 @@ class DefaultFuseHandler(FuseHandler):
|
||||||
n, *args = pattern
|
n, *args = pattern
|
||||||
modules: List[torch.nn.Module] = []
|
modules: List[torch.nn.Module] = []
|
||||||
modules.append(get_modules(n))
|
modules.append(get_modules(n))
|
||||||
for a in args:
|
modules.extend(get_modules(a) for a in args)
|
||||||
modules.append(get_modules(a))
|
|
||||||
return tuple(modules)
|
return tuple(modules)
|
||||||
else:
|
else:
|
||||||
n = pattern
|
n = pattern
|
||||||
|
|
@ -111,9 +110,7 @@ class DefaultFuseHandler(FuseHandler):
|
||||||
# as input
|
# as input
|
||||||
fused_module = fuser_method(is_qat, *matched_modules)
|
fused_module = fuser_method(is_qat, *matched_modules)
|
||||||
setattr(named_modules[module_parent_name], module_name, fused_module)
|
setattr(named_modules[module_parent_name], module_name, fused_module)
|
||||||
extra_args = []
|
extra_args = [load_arg(input) for input in extra_inputs]
|
||||||
for input in extra_inputs:
|
|
||||||
extra_args.append(load_arg(input))
|
|
||||||
node = fused_graph.node_copy(root_node, load_arg)
|
node = fused_graph.node_copy(root_node, load_arg)
|
||||||
args = list(node.args)
|
args = list(node.args)
|
||||||
args.extend(extra_args)
|
args.extend(extra_args)
|
||||||
|
|
|
||||||
|
|
@ -1219,13 +1219,12 @@ def _maybe_insert_observers_before_graph_output(
|
||||||
else:
|
else:
|
||||||
return maybe_node
|
return maybe_node
|
||||||
elif isinstance(maybe_node, (list, tuple)):
|
elif isinstance(maybe_node, (list, tuple)):
|
||||||
results = []
|
results = [
|
||||||
for inner_node in maybe_node:
|
_recursive_maybe_replace_node_with_obs(
|
||||||
results.append(
|
inner_node, model, named_modules, graph
|
||||||
_recursive_maybe_replace_node_with_obs(
|
|
||||||
inner_node, model, named_modules, graph
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for inner_node in maybe_node
|
||||||
|
]
|
||||||
if isinstance(maybe_node, list):
|
if isinstance(maybe_node, list):
|
||||||
return results
|
return results
|
||||||
else:
|
else:
|
||||||
|
|
@ -1244,11 +1243,10 @@ def _maybe_insert_observers_before_graph_output(
|
||||||
"Unhandled type for returned node:", maybe_node
|
"Unhandled type for returned node:", maybe_node
|
||||||
)
|
)
|
||||||
|
|
||||||
new_args = []
|
new_args = [
|
||||||
for old_arg in graph_output_node.args:
|
_recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
|
||||||
new_args.append(
|
for old_arg in graph_output_node.args
|
||||||
_recursive_maybe_replace_node_with_obs(old_arg, model, named_modules, graph)
|
]
|
||||||
)
|
|
||||||
|
|
||||||
graph_output_node.args = tuple(new_args) # type: ignore[assignment]
|
graph_output_node.args = tuple(new_args) # type: ignore[assignment]
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,7 +1,7 @@
|
||||||
# mypy: allow-untyped-defs
|
# mypy: allow-untyped-defs
|
||||||
import itertools
|
import itertools
|
||||||
import operator
|
import operator
|
||||||
from typing import Any, Callable, List, Optional, OrderedDict, Set
|
from typing import Any, Callable, List, Optional, OrderedDict, Sequence, Set
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch.fx import Node
|
from torch.fx import Node
|
||||||
|
|
@ -58,7 +58,7 @@ def update_equivalent_types_dict(customized_equivalent_types=None):
|
||||||
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
|
_EQUIVALENT_TYPES_DICT = _create_equivalent_types_dict()
|
||||||
|
|
||||||
|
|
||||||
def _partitions_sequential(partitions: List[SourcePartition]):
|
def _partitions_sequential(partitions: Sequence[SourcePartition]):
|
||||||
prev_partition = None
|
prev_partition = None
|
||||||
for partition in partitions:
|
for partition in partitions:
|
||||||
if prev_partition is not None and not check_subgraphs_connected(
|
if prev_partition is not None and not check_subgraphs_connected(
|
||||||
|
|
@ -108,8 +108,9 @@ def find_sequential_partitions(
|
||||||
|
|
||||||
typed_partitions_list = list(typed_partitions.values())
|
typed_partitions_list = list(typed_partitions.values())
|
||||||
fusion_candidates = itertools.product(*typed_partitions_list)
|
fusion_candidates = itertools.product(*typed_partitions_list)
|
||||||
fused_partitions = []
|
fused_partitions = [
|
||||||
for candidate in fusion_candidates:
|
candidate
|
||||||
if _partitions_sequential(candidate): # type: ignore[arg-type]
|
for candidate in fusion_candidates
|
||||||
fused_partitions.append(candidate)
|
if _partitions_sequential(candidate)
|
||||||
|
]
|
||||||
return fused_partitions
|
return fused_partitions
|
||||||
|
|
|
||||||
|
|
@ -93,10 +93,10 @@ def _get_supported_symmetric_config_and_operators() -> List[OperatorConfig]:
|
||||||
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
|
get_symmetric_quantization_config(is_per_channel=True, is_qat=True),
|
||||||
]:
|
]:
|
||||||
ops = _supported_symmetric_quantized_operators()
|
ops = _supported_symmetric_quantized_operators()
|
||||||
for pattern_list in ops.values():
|
supported_config_and_operators.extend(
|
||||||
supported_config_and_operators.append(
|
OperatorConfig(quantization_config, pattern_list)
|
||||||
OperatorConfig(quantization_config, pattern_list)
|
for pattern_list in ops.values()
|
||||||
)
|
)
|
||||||
return copy.deepcopy(supported_config_and_operators)
|
return copy.deepcopy(supported_config_and_operators)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,10 +55,11 @@ def _allocate_jacobians_with_inputs(
|
||||||
# of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
|
# of `t.numel` and width of `numel_output`. Otherwise, for each tensor, returns
|
||||||
# a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
|
# a 1-d tensor with size `(t.numel,)`. Each new tensor will be strided and have
|
||||||
# the same dtype and device as those of the corresponding input.
|
# the same dtype and device as those of the corresponding input.
|
||||||
out: List[torch.Tensor] = []
|
out: List[torch.Tensor] = [
|
||||||
for t in input_tensors:
|
t.new_zeros((t.numel(), numel_output), layout=torch.strided)
|
||||||
if _is_float_or_complex_tensor(t) and t.requires_grad:
|
for t in input_tensors
|
||||||
out.append(t.new_zeros((t.numel(), numel_output), layout=torch.strided))
|
if _is_float_or_complex_tensor(t) and t.requires_grad
|
||||||
|
]
|
||||||
return tuple(out)
|
return tuple(out)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -69,11 +70,12 @@ def _allocate_jacobians_with_outputs(
|
||||||
# in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
|
# in `output_tensors`, returns a new zero-filled tensor with height of `dim` and
|
||||||
# width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
|
# width of `t.numel`. Otherwise, for each tensor, returns a 1-d tensor with size
|
||||||
# (t.numel,).
|
# (t.numel,).
|
||||||
out: List[torch.Tensor] = []
|
|
||||||
options = {"dtype": dtype, "device": device, "layout": torch.strided}
|
options = {"dtype": dtype, "device": device, "layout": torch.strided}
|
||||||
for t in output_tensors:
|
out: List[torch.Tensor] = [
|
||||||
if _is_float_or_complex_tensor(t):
|
t.new_zeros((numel_input, t.numel()), **options)
|
||||||
out.append(t.new_zeros((numel_input, t.numel()), **options))
|
for t in output_tensors
|
||||||
|
if _is_float_or_complex_tensor(t)
|
||||||
|
]
|
||||||
return tuple(out)
|
return tuple(out)
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -904,10 +906,10 @@ def _compute_analytical_jacobian_rows(
|
||||||
def _get_analytical_vjps_wrt_specific_output(
|
def _get_analytical_vjps_wrt_specific_output(
|
||||||
vjp_fn, sample_output, v
|
vjp_fn, sample_output, v
|
||||||
) -> List[List[Optional[torch.Tensor]]]:
|
) -> List[List[Optional[torch.Tensor]]]:
|
||||||
vjps: List[List[Optional[torch.Tensor]]] = []
|
|
||||||
grad_inputs = vjp_fn(v.reshape(sample_output.shape))
|
grad_inputs = vjp_fn(v.reshape(sample_output.shape))
|
||||||
for vjp in grad_inputs:
|
vjps: List[List[Optional[torch.Tensor]]] = [
|
||||||
vjps.append([vjp.clone() if isinstance(vjp, torch.Tensor) else None])
|
[vjp.clone() if isinstance(vjp, torch.Tensor) else None] for vjp in grad_inputs
|
||||||
|
]
|
||||||
return vjps
|
return vjps
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -802,9 +802,7 @@ def _register_logging_hooks_on_whole_graph(
|
||||||
log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
|
log_str = f"Executing: {node} with grad_outputs: {grad_outputs_str}"
|
||||||
log.debug(log_str)
|
log.debug(log_str)
|
||||||
|
|
||||||
handles = []
|
handles = [node.register_prehook(prehook) for node in iter_graph(grad_fns)]
|
||||||
for node in iter_graph(grad_fns):
|
|
||||||
handles.append(node.register_prehook(prehook))
|
|
||||||
|
|
||||||
def unregister_hooks() -> None:
|
def unregister_hooks() -> None:
|
||||||
for handle in handles:
|
for handle in handles:
|
||||||
|
|
|
||||||
|
|
@ -856,10 +856,9 @@ def _build_table(
|
||||||
flops_column_width = DEFAULT_COLUMN_WIDTH
|
flops_column_width = DEFAULT_COLUMN_WIDTH
|
||||||
|
|
||||||
src_column_width = None
|
src_column_width = None
|
||||||
stacks = []
|
stacks = [
|
||||||
for evt in events:
|
evt.stack for evt in events if evt.stack is not None and len(evt.stack) > 0
|
||||||
if evt.stack is not None and len(evt.stack) > 0:
|
]
|
||||||
stacks.append(evt.stack)
|
|
||||||
has_stack = len(stacks) > 0
|
has_stack = len(stacks) > 0
|
||||||
if has_stack:
|
if has_stack:
|
||||||
src_column_width = (
|
src_column_width = (
|
||||||
|
|
@ -947,10 +946,7 @@ def _build_table(
if with_flops:
if with_flops:
# Auto-scaling of flops header
# Auto-scaling of flops header
raw_flops = []
raw_flops = [evt.flops for evt in events if evt.flops > 0]
for evt in events:
if evt.flops > 0:
raw_flops.append(evt.flops)
if len(raw_flops) != 0:
if len(raw_flops) != 0:
(flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
(flops_scale, flops_header) = auto_scale_flops(min(raw_flops))
headers.append(f"Total {flops_header}")
headers.append(f"Total {flops_header}")
@ -322,9 +322,7 @@ def _lazy_init():
# However, we must not let any *other* threads in!
# However, we must not let any *other* threads in!
_tls.is_initializing = True
_tls.is_initializing = True
for calls in _lazy_seed_tracker.get_calls():
_queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
if calls:
_queued_calls.append(calls)
try:
try:
for queued_call, orig_traceback in _queued_calls:
for queued_call, orig_traceback in _queued_calls:
@ -44,9 +44,7 @@ def get_rng_state(device: Union[int, str, torch.device] = "cuda") -> Tensor:
def get_rng_state_all() -> List[Tensor]:
def get_rng_state_all() -> List[Tensor]:
r"""Return a list of ByteTensor representing the random number states of all devices."""
r"""Return a list of ByteTensor representing the random number states of all devices."""
results = []
results = [get_rng_state(i) for i in range(device_count())]
for i in range(device_count()):
results.append(get_rng_state(i))
return results
return results
@ -31,8 +31,9 @@ class ShardedOptimizer(optim.Optimizer):
|
||||||
tensors: List[Tensor] = []
|
tensors: List[Tensor] = []
|
||||||
for value in named_params.values():
|
for value in named_params.values():
|
||||||
if isinstance(value, ShardedTensor):
|
if isinstance(value, ShardedTensor):
|
||||||
for local_shard in value.local_shards():
|
tensors.extend(
|
||||||
tensors.append(local_shard.tensor)
|
local_shard.tensor for local_shard in value.local_shards()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
tensors.append(value)
|
tensors.append(value)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -99,9 +99,10 @@ def sharded_type_as(args, kwargs, pg):
|
||||||
tensor = args[1]
|
tensor = args[1]
|
||||||
if isinstance(tensor, ShardedTensor):
|
if isinstance(tensor, ShardedTensor):
|
||||||
tensor = tensor.local_tensor()
|
tensor = tensor.local_tensor()
|
||||||
new_local_shards = []
|
new_local_shards = [
|
||||||
for shard in st.local_shards():
|
Shard(shard.tensor.type_as(tensor), shard.metadata)
|
||||||
new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
|
for shard in st.local_shards()
|
||||||
|
]
|
||||||
st_meta = copy.deepcopy(st._metadata)
|
st_meta = copy.deepcopy(st._metadata)
|
||||||
st_meta.tensor_properties.dtype = tensor.dtype
|
st_meta.tensor_properties.dtype = tensor.dtype
|
||||||
return new_local_shards, st_meta
|
return new_local_shards, st_meta
|
||||||
|
|
|
||||||
|
|
@ -191,9 +191,9 @@ def reshard_local_shard(
|
||||||
)
|
)
|
||||||
|
|
||||||
# Compute expected size
|
# Compute expected size
|
||||||
input_split_sizes = []
|
input_split_sizes = [
|
||||||
for metadata in shards_metadata:
|
metadata.shard_sizes[reshard_dim] for metadata in shards_metadata
|
||||||
input_split_sizes.append(metadata.shard_sizes[reshard_dim])
|
]
|
||||||
rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
|
rearrange_input = any(ranks[i] > ranks[i + 1] for i in range(len(ranks) - 1))
|
||||||
|
|
||||||
if rearrange_input:
|
if rearrange_input:
|
||||||
|
|
|
||||||
|
|
@ -142,8 +142,7 @@ def _run_trainer(emb_rref_list, rank):
)
)
# model.parameters() only includes local parameters.
# model.parameters() only includes local parameters.
for param in model.parameters():
model_parameter_rrefs.extend(RRef(param) for param in model.parameters())
model_parameter_rrefs.append(RRef(param))
# Setup distributed optimizer
# Setup distributed optimizer
opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)
opt = DistributedOptimizer(optim.SGD, model_parameter_rrefs, lr=0.05)
@ -207,8 +207,10 @@ def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
|
||||||
if isinstance(obj, DTensor):
|
if isinstance(obj, DTensor):
|
||||||
requests.append(_create_write_items_for_dtensor(fqn, obj))
|
requests.append(_create_write_items_for_dtensor(fqn, obj))
|
||||||
elif isinstance(obj, ShardedTensor):
|
elif isinstance(obj, ShardedTensor):
|
||||||
for shard_md in obj.metadata().shards_metadata:
|
requests.extend(
|
||||||
requests.append(_create_write_item_for_shard(fqn, obj, shard_md))
|
_create_write_item_for_shard(fqn, obj, shard_md)
|
||||||
|
for shard_md in obj.metadata().shards_metadata
|
||||||
|
)
|
||||||
elif isinstance(obj, torch.Tensor):
|
elif isinstance(obj, torch.Tensor):
|
||||||
requests.append(_create_write_item_for_tensor(fqn, obj))
|
requests.append(_create_write_item_for_tensor(fqn, obj))
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
|
|
@ -1371,9 +1371,9 @@ def _get_all_pg_configs() -> List[Dict[str, Any]]:
|
||||||
Return the pg configuration of all the process groups.
|
Return the pg configuration of all the process groups.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
config_info: List[Dict[str, Any]] = []
|
config_info: List[Dict[str, Any]] = [
|
||||||
for pg in _world.pg_map.keys():
|
_get_pg_config(pg) for pg in _world.pg_map.keys()
|
||||||
config_info.append(_get_pg_config(pg))
|
]
|
||||||
return config_info
|
return config_info
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -2508,9 +2508,7 @@ def _coalescing_manager(
|
||||||
# - coalesced `reduce_scatter_tensor`
|
# - coalesced `reduce_scatter_tensor`
|
||||||
op0 = op_list[0].op
|
op0 = op_list[0].op
|
||||||
if op0 == all_reduce:
|
if op0 == all_reduce:
|
||||||
tensors = []
|
tensors = [op.tensor for op in op_list]
|
||||||
for op in op_list:
|
|
||||||
tensors.append(op.tensor)
|
|
||||||
all_reduce_opts = AllreduceCoalescedOptions()
|
all_reduce_opts = AllreduceCoalescedOptions()
|
||||||
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
|
all_reduce_opts.reduceOp = not_none(op_list[0].redop)
|
||||||
work = group.allreduce_coalesced(tensors, all_reduce_opts)
|
work = group.allreduce_coalesced(tensors, all_reduce_opts)
|
||||||
|
|
|
||||||
|
|
@ -568,10 +568,11 @@ def _root_pre_forward(
|
||||||
state._needs_buffer_dtype_restore_check = False
|
state._needs_buffer_dtype_restore_check = False
|
||||||
|
|
||||||
if state.forward_prefetch:
|
if state.forward_prefetch:
|
||||||
handles = []
|
handles = [
|
||||||
for fsdp_state in state._all_fsdp_states:
|
fsdp_state._handle
|
||||||
if fsdp_state._handle:
|
for fsdp_state in state._all_fsdp_states
|
||||||
handles.append(fsdp_state._handle)
|
if fsdp_state._handle
|
||||||
|
]
|
||||||
for handle in handles:
|
for handle in handles:
|
||||||
handle._needs_pre_forward_unshard = True
|
handle._needs_pre_forward_unshard = True
|
||||||
handle._prefetched = False
|
handle._prefetched = False
|
||||||
|
|
|
||||||
|
|
@ -110,9 +110,9 @@ def _create_module_with_interface(
|
||||||
|
|
||||||
|
|
||||||
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
|
def _param_rrefs(module_rref, recurse) -> List[rpc.RRef[Parameter]]:
|
||||||
ret: List[rpc.RRef[Parameter]] = []
|
ret: List[rpc.RRef[Parameter]] = [
|
||||||
for param in module_rref.local_value().parameters(recurse):
|
rpc.RRef(param) for param in module_rref.local_value().parameters(recurse)
|
||||||
ret.append(rpc.RRef(param))
|
]
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -146,9 +146,7 @@ class _NamedOptimizer(optim.Optimizer):
|
||||||
|
|
||||||
ret_groups = []
|
ret_groups = []
|
||||||
for group in param_groups:
|
for group in param_groups:
|
||||||
param_keys = []
|
param_keys = [self.ordered_param_keys[param] for param in group["params"]]
|
||||||
for param in group["params"]:
|
|
||||||
param_keys.append(self.ordered_param_keys[param])
|
|
||||||
ret_group = {"params": sorted(param_keys)}
|
ret_group = {"params": sorted(param_keys)}
|
||||||
for k, v in group.items():
|
for k, v in group.items():
|
||||||
if k != "params":
|
if k != "params":
|
||||||
|
|
|
||||||
|
|
@ -245,13 +245,12 @@ class DistributedOptimizer:
|
||||||
else _local_optimizer_step
|
else _local_optimizer_step
|
||||||
)
|
)
|
||||||
|
|
||||||
rpc_futs = []
|
rpc_futs = [
|
||||||
for optimizer in self.remote_optimizers:
|
rpc.rpc_async(
|
||||||
rpc_futs.append(
|
optimizer.owner(),
|
||||||
rpc.rpc_async(
|
optimizer_step_func,
|
||||||
optimizer.owner(),
|
args=(optimizer, context_id),
|
||||||
optimizer_step_func,
|
|
||||||
args=(optimizer, context_id),
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for optimizer in self.remote_optimizers
|
||||||
|
]
|
||||||
_wait_for_all(rpc_futs)
|
_wait_for_all(rpc_futs)
|
||||||
|
|
|
||||||
|
|
@ -781,15 +781,15 @@ class ZeroRedundancyOptimizer(Optimizer, Joinable):
|
||||||
self.process_group, rank
|
self.process_group, rank
|
||||||
)
|
)
|
||||||
for param_group in param_groups:
|
for param_group in param_groups:
|
||||||
for param in param_group["params"]:
|
handles.extend(
|
||||||
handles.append(
|
dist.broadcast(
|
||||||
dist.broadcast(
|
tensor=param.data,
|
||||||
tensor=param.data,
|
src=global_rank,
|
||||||
src=global_rank,
|
group=self.process_group,
|
||||||
group=self.process_group,
|
async_op=True,
|
||||||
async_op=True,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
|
for param in param_group["params"]
|
||||||
|
)
|
||||||
return handles
|
return handles
|
||||||
|
|
||||||
def _sync_params(self):
|
def _sync_params(self):
|
||||||
|
|
|
||||||
|
|
@ -224,9 +224,7 @@ def _shard_dict_of_args(
|
||||||
for chunk_idx in range(real_num_chunks):
|
for chunk_idx in range(real_num_chunks):
|
||||||
chunk_args = {}
|
chunk_args = {}
|
||||||
for key, arg in args_sharded_replicated.items():
|
for key, arg in args_sharded_replicated.items():
|
||||||
arg_single_chunk = []
|
arg_single_chunk = [v_flat[chunk_idx] for v_flat in arg]
|
||||||
for v_flat in arg:
|
|
||||||
arg_single_chunk.append(v_flat[chunk_idx])
|
|
||||||
chunk_args[key] = arg_single_chunk
|
chunk_args[key] = arg_single_chunk
|
||||||
chunks_flat.append(chunk_args)
|
chunks_flat.append(chunk_args)
|
||||||
|
|
||||||
|
|
@ -340,9 +338,10 @@ def split_args_kwargs_into_chunks(
|
||||||
f"{len(args_split_dict)}, {len(kwargs_split)}"
|
f"{len(args_split_dict)}, {len(kwargs_split)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
args_split = []
|
args_split = [
|
||||||
for chunk_args in args_split_dict:
|
tuple(chunk_args[i] for i in range(len(chunk_args)))
|
||||||
args_split.append(tuple(chunk_args[i] for i in range(len(chunk_args))))
|
for chunk_args in args_split_dict
|
||||||
|
]
|
||||||
|
|
||||||
return args_split, kwargs_split
|
return args_split, kwargs_split
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1702,16 +1702,14 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store the list of operations used for that rank
|
# Store the list of operations used for that rank
|
||||||
rank_ops: List[Optional[_Action]] = []
|
|
||||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||||
for _ in range(rank):
|
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
|
||||||
rank_ops.append(None)
|
|
||||||
|
|
||||||
for stage_index in stage_indices:
|
for stage_index in stage_indices:
|
||||||
for mb_index in range(self._n_microbatches):
|
rank_ops.extend(
|
||||||
rank_ops.append(
|
_Action(stage_index, _ComputationType.FORWARD, mb_index)
|
||||||
_Action(stage_index, _ComputationType.FORWARD, mb_index)
|
for mb_index in range(self._n_microbatches)
|
||||||
)
|
)
|
||||||
|
|
||||||
# wait for the first backward to trickle up
|
# wait for the first backward to trickle up
|
||||||
# which is 2 for every hop away
|
# which is 2 for every hop away
|
||||||
|
|
@ -1719,10 +1717,10 @@ class ScheduleLoopedBFS(PipelineScheduleMulti):
|
||||||
rank_ops.extend([None] * post_warmup_ops)
|
rank_ops.extend([None] * post_warmup_ops)
|
||||||
|
|
||||||
for stage_index in reversed(stage_indices):
|
for stage_index in reversed(stage_indices):
|
||||||
for mb_index in reversed(range(self._n_microbatches)):
|
rank_ops.extend(
|
||||||
rank_ops.append(
|
_Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
|
||||||
_Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
|
for mb_index in reversed(range(self._n_microbatches))
|
||||||
)
|
)
|
||||||
return rank_ops
|
return rank_ops
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -1744,10 +1742,8 @@ def _get_1f1b_rank_ops(
|
||||||
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
weight_stage_mb_index: Dict[int, int] = defaultdict(int)
|
||||||
|
|
||||||
# Store the list of operations used for that rank
|
# Store the list of operations used for that rank
|
||||||
rank_ops: List[Optional[_Action]] = []
|
|
||||||
# Pre-padding, rank starts with no-ops based on the warmup.
|
# Pre-padding, rank starts with no-ops based on the warmup.
|
||||||
for _ in range(rank):
|
rank_ops: List[Optional[_Action]] = [None for _ in range(rank)]
|
||||||
rank_ops.append(None)
|
|
||||||
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
|
# These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
|
||||||
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
|
# when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
|
||||||
# Formula:
|
# Formula:
|
||||||
|
|
|
||||||
|
|
@ -195,8 +195,7 @@ def fill_empty_tensor_to_shards(
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
size if idx != shard_dim else 0 for idx, size in enumerate(tensor_size)
]
]
tensor = shards[0].new_zeros(tensor_size)
tensor = shards[0].new_zeros(tensor_size)
for _ in range(num_empty_tensors):
shards.extend(tensor for _ in range(num_empty_tensors))
shards.append(tensor)
return shards
return shards
@ -167,9 +167,7 @@ def gen_einsum_strategies(
|
||||||
# (i.e. for Shard, tensor dim size must > mesh size)
|
# (i.e. for Shard, tensor dim size must > mesh size)
|
||||||
all_strategies = []
|
all_strategies = []
|
||||||
for strategy_comb in strategy_combs:
|
for strategy_comb in strategy_combs:
|
||||||
spec_list = []
|
spec_list = [DTensorSpec(mesh, tuple(specs)) for specs in zip(*strategy_comb)]
|
||||||
for specs in zip(*strategy_comb):
|
|
||||||
spec_list.append(DTensorSpec(mesh, tuple(specs)))
|
|
||||||
strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
|
strat = PlacementStrategy(output_specs=spec_list[0], input_specs=spec_list[1:])
|
||||||
all_strategies.append(strat)
|
all_strategies.append(strat)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -43,18 +43,17 @@ def default_strategy(mesh: DeviceMesh, op_schema: OpSchema) -> StrategyType:
|
||||||
# Default strategy by default just propagate the first input strategy
|
# Default strategy by default just propagate the first input strategy
|
||||||
select_strategy = op_schema.args_schema[0]
|
select_strategy = op_schema.args_schema[0]
|
||||||
assert isinstance(select_strategy, OpStrategy)
|
assert isinstance(select_strategy, OpStrategy)
|
||||||
default_strategy = []
|
# we create new DTensorSpecs even for default strategy to assure that
|
||||||
for strategy in select_strategy.strategies:
|
# the tensor metas are distinct between the arguments and outputs
|
||||||
# we create new DTensorSpecs even for default strategy to assure that
|
default_strategy = [
|
||||||
# the tensor metas are distinct between the arguments and outputs
|
PlacementStrategy(
|
||||||
default_strategy.append(
|
output_specs=DTensorSpec(
|
||||||
PlacementStrategy(
|
mesh=strategy.output_spec.mesh,
|
||||||
output_specs=DTensorSpec(
|
placements=strategy.output_spec.placements,
|
||||||
mesh=strategy.output_spec.mesh,
|
|
||||||
placements=strategy.output_spec.placements,
|
|
||||||
)
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
for strategy in select_strategy.strategies
|
||||||
|
]
|
||||||
return OpStrategy(default_strategy)
|
return OpStrategy(default_strategy)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -209,9 +209,10 @@ def map_placements_after_broadcast(
|
||||||
def generate_redistribute_costs(
|
def generate_redistribute_costs(
|
||||||
src_strategy: OpStrategy, dst_spec: DTensorSpec
|
src_strategy: OpStrategy, dst_spec: DTensorSpec
|
||||||
) -> List[float]:
|
) -> List[float]:
|
||||||
redistribute_costs: List[float] = []
|
redistribute_costs: List[float] = [
|
||||||
for strat in src_strategy.strategies:
|
redistribute_cost(strat.output_spec, dst_spec)
|
||||||
redistribute_costs.append(redistribute_cost(strat.output_spec, dst_spec))
|
for strat in src_strategy.strategies
|
||||||
|
]
|
||||||
|
|
||||||
return redistribute_costs
|
return redistribute_costs
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -152,9 +152,11 @@ class InterpreterModule(torch.nn.Module):
|
||||||
# the keys in the kwarg dict.
|
# the keys in the kwarg dict.
|
||||||
arg_list = list(args)
|
arg_list = list(args)
|
||||||
kwarg_names = self.arg_names[len(arg_list) :]
|
kwarg_names = self.arg_names[len(arg_list) :]
|
||||||
for kwarg_name in kwarg_names:
|
arg_list.extend(
|
||||||
if kwarg_name in kwargs:
|
kwargs[kwarg_name]
|
||||||
arg_list.append(kwargs[kwarg_name])
|
for kwarg_name in kwarg_names
|
||||||
|
if kwarg_name in kwargs
|
||||||
|
)
|
||||||
|
|
||||||
# Assert that the kwargs passed in exactly match the positional
|
# Assert that the kwargs passed in exactly match the positional
|
||||||
# arguments specified by the GraphModule. This should be
|
# arguments specified by the GraphModule. This should be
|
||||||
|
|
@ -933,9 +935,10 @@ class _ModuleFrame:
|
||||||
assert kwargs_spec.context is not None
|
assert kwargs_spec.context is not None
|
||||||
|
|
||||||
with self.graph.inserting_after(None):
|
with self.graph.inserting_after(None):
|
||||||
arg_nodes = []
|
arg_nodes = [
|
||||||
for idx in range(args_spec.num_children):
|
self.graph.placeholder(f"_positional_arg_{idx}")
|
||||||
arg_nodes.append(self.graph.placeholder(f"_positional_arg_{idx}"))
|
for idx in range(args_spec.num_children)
|
||||||
|
]
|
||||||
kwarg_nodes = {}
|
kwarg_nodes = {}
|
||||||
for name in kwargs_spec.context:
|
for name in kwargs_spec.context:
|
||||||
kwarg_nodes[name] = self.graph.placeholder(name)
|
kwarg_nodes[name] = self.graph.placeholder(name)
|
||||||
|
|
|
||||||
|
|
@ -1015,10 +1015,11 @@ class Partitioner:
|
||||||
# Keep tracking the partition pair of node pair
|
# Keep tracking the partition pair of node pair
|
||||||
partition_pair: List[Partition] = []
|
partition_pair: List[Partition] = []
|
||||||
# Collect all the op nodes from the graph
|
# Collect all the op nodes from the graph
|
||||||
op_nodes = []
|
op_nodes = [
|
||||||
for n in self.graph_module.graph.nodes:
|
n
|
||||||
if n.op not in {"placeholder", "get_attr", "output"}:
|
for n in self.graph_module.graph.nodes
|
||||||
op_nodes.append(n)
|
if n.op not in {"placeholder", "get_attr", "output"}
|
||||||
|
]
|
||||||
for node in op_nodes:
|
for node in op_nodes:
|
||||||
# Find which partition the current node belongs
|
# Find which partition the current node belongs
|
||||||
p0_index = self.node_to_partition[node]
|
p0_index = self.node_to_partition[node]
|
||||||
|
|
|
||||||
|
|
@ -648,11 +648,7 @@ def generate_reshape(constraint, counter):
|
||||||
|
|
||||||
# then there must be exactly one occurrence of dyn
|
# then there must be exactly one occurrence of dyn
|
||||||
else:
|
else:
|
||||||
new_target = []
|
new_target = [n for n in target if n != Dyn]
|
||||||
|
|
||||||
for n in target:
|
|
||||||
if n != Dyn:
|
|
||||||
new_target.append(n)
|
|
||||||
|
|
||||||
# tensor 1
|
# tensor 1
|
||||||
c3_tensor1 = Disj(
|
c3_tensor1 = Disj(
|
||||||
|
|
@ -886,10 +882,8 @@ def generate_all_int_dyn_dim_possibilities(my_list: List[DVar]):
|
||||||
neq_possibilities = [
|
neq_possibilities = [
|
||||||
BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
|
BinConstraintD(my_list[i], Dyn, op_neq) for i in range(len(my_list))
|
||||||
]
|
]
|
||||||
d_possibilities = []
|
|
||||||
|
|
||||||
for i in zip(eq_possibilities, neq_possibilities):
|
d_possibilities = [list(i) for i in zip(eq_possibilities, neq_possibilities)]
|
||||||
d_possibilities.append(list(i))
|
|
||||||
all_possibilities = list(itertools.product(*d_possibilities))
|
all_possibilities = list(itertools.product(*d_possibilities))
|
||||||
return all_possibilities
|
return all_possibilities
|
||||||
|
|
||||||
|
|
@ -1043,13 +1037,11 @@ def apply_padding(
|
||||||
|
|
||||||
assert len(simulate_padding + d1) == len(d2)
|
assert len(simulate_padding + d1) == len(d2)
|
||||||
|
|
||||||
broadcast_padding = []
|
|
||||||
|
|
||||||
# for every padding size, we also consider broadcasting
|
# for every padding size, we also consider broadcasting
|
||||||
for j in range(len(d2) - i):
|
broadcast_padding = [
|
||||||
broadcast_padding.append(
|
broadcast_dim(simulate_padding, d2, d11, d12, j, True)
|
||||||
broadcast_dim(simulate_padding, d2, d11, d12, j, True)
|
for j in range(len(d2) - i)
|
||||||
)
|
]
|
||||||
|
|
||||||
# we consider the possibilities for broadcasting for every dimension. Since we already
|
# we consider the possibilities for broadcasting for every dimension. Since we already
|
||||||
# padded d1, we do not consider it while broadcasting
|
# padded d1, we do not consider it while broadcasting
|
||||||
|
|
|
||||||
|
|
@ -302,11 +302,10 @@ def get_latency_of_partitioned_graph(
|
||||||
"""This function is to return all the partitions without parents
|
"""This function is to return all the partitions without parents
|
||||||
as the starting points of all the paths
|
as the starting points of all the paths
|
||||||
"""
|
"""
|
||||||
top_partitions = []
|
# If a partition has no parents, then it is a top partition
|
||||||
for partition in partitions:
|
top_partitions = [
|
||||||
# If a partition has no parents, then it is a top partition
|
partition for partition in partitions if len(partition.parents) == 0
|
||||||
if len(partition.parents) == 0:
|
]
|
||||||
top_partitions.append(partition)
|
|
||||||
return top_partitions
|
return top_partitions
|
||||||
|
|
||||||
top_partitions = get_top_partitions(partitions)
|
top_partitions = get_top_partitions(partitions)
|
||||||
|
|
|
||||||
|
|
@ -5201,10 +5201,9 @@ class ShapeEnv:
|
||||||
symints = {
|
symints = {
|
||||||
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
|
s.node.expr for s in symints if isinstance(s.node.expr, sympy.Symbol)
|
||||||
}
|
}
|
||||||
guards = []
|
guards = [
|
||||||
for g in self.guards:
|
g for g in self.guards if all(s in symints for s in g.expr.free_symbols)
|
||||||
if all(s in symints for s in g.expr.free_symbols):
|
]
|
||||||
guards.append(g)
|
|
||||||
return guards
|
return guards
|
||||||
|
|
||||||
def bind_symbols(
|
def bind_symbols(
|
||||||
|
|
|
||||||
|
|
@ -220,10 +220,7 @@ class PassManager:
def remove_pass(self, _passes: List[str]):
def remove_pass(self, _passes: List[str]):
if _passes is None:
if _passes is None:
return
return
passes_left = []
passes_left = [ps for ps in self.passes if ps.__name__ not in _passes]
for ps in self.passes:
if ps.__name__ not in _passes:
passes_left.append(ps)
self.passes = passes_left
self.passes = passes_left
self._validated = False
self._validated = False
@ -881,9 +881,7 @@ def infer_methods_to_compile(nn_module):
uniqued_methods.append(name)
uniqued_methods.append(name)
uniquer.add(name)
uniquer.add(name)
stubs = []
stubs = [make_stub_from_method(nn_module, method) for method in uniqued_methods]
for method in uniqued_methods:
stubs.append(make_stub_from_method(nn_module, method))
return overload_stubs + stubs
return overload_stubs + stubs
@ -959,9 +957,10 @@ def interface_script(mod_interface, nn_module):
|
||||||
|
|
||||||
It is used to know which methods need to act as starting points for compilation.
|
It is used to know which methods need to act as starting points for compilation.
|
||||||
"""
|
"""
|
||||||
stubs = []
|
stubs = [
|
||||||
for method in mod_interface.getMethodNames():
|
make_stub_from_method(nn_module, method)
|
||||||
stubs.append(make_stub_from_method(nn_module, method))
|
for method in mod_interface.getMethodNames()
|
||||||
|
]
|
||||||
return stubs
|
return stubs
|
||||||
|
|
||||||
return create_script_module(nn_module, infer_interface_methods_to_compile)
|
return create_script_module(nn_module, infer_interface_methods_to_compile)
|
||||||
|
|
|
||||||
|
|
@ -1493,11 +1493,10 @@ def _get_overloads(obj):
|
||||||
_jit_internal.get_overload_no_implementation_error_message("function", obj)
|
_jit_internal.get_overload_no_implementation_error_message("function", obj)
|
||||||
)
|
)
|
||||||
|
|
||||||
compiled_fns = []
|
compiled_fns = [
|
||||||
for overload_fn in uncompiled_overloads:
|
_compile_function_with_overload(overload_fn, qual_name, obj)
|
||||||
compiled_fns.append(
|
for overload_fn in uncompiled_overloads
|
||||||
_compile_function_with_overload(overload_fn, qual_name, obj)
|
]
|
||||||
)
|
|
||||||
|
|
||||||
if existing_compiled_fns:
|
if existing_compiled_fns:
|
||||||
compiled_fns = existing_compiled_fns + compiled_fns
|
compiled_fns = existing_compiled_fns + compiled_fns
|
||||||
|
|
|
||||||
|
|
@ -77,9 +77,7 @@ def _lazy_init() -> None:
|
||||||
# However, we must not let any *other* threads in!
|
# However, we must not let any *other* threads in!
|
||||||
_tls.is_initializing = True
|
_tls.is_initializing = True
|
||||||
|
|
||||||
for calls in _lazy_seed_tracker.get_calls():
|
_queued_calls.extend(calls for calls in _lazy_seed_tracker.get_calls() if calls)
|
||||||
if calls:
|
|
||||||
_queued_calls.append(calls)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for queued_call, orig_traceback in _queued_calls:
|
for queued_call, orig_traceback in _queued_calls:
|
||||||
|
|
|
||||||
|
|
@ -23,8 +23,7 @@ class Broadcast(Function):
|
||||||
non_differentiables = []
|
non_differentiables = []
|
||||||
for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
|
for idx, input_requires_grad in enumerate(ctx.needs_input_grad[1:]):
|
||||||
if not input_requires_grad:
|
if not input_requires_grad:
|
||||||
for output in outputs:
|
non_differentiables.extend(output[idx] for output in outputs)
|
||||||
non_differentiables.append(output[idx])
|
|
||||||
ctx.mark_non_differentiable(*non_differentiables)
|
ctx.mark_non_differentiable(*non_differentiables)
|
||||||
return tuple([t for tensors in outputs for t in tensors])
|
return tuple([t for tensors in outputs for t in tensors])
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -142,9 +142,7 @@ def conv_backward(func, ctx, grad_output):
|
||||||
ctx.groups,
|
ctx.groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
kernel_size = []
|
kernel_size = [weight_shape[i] for i in range(2, conv_picker(func, 3, 4, 5))]
|
||||||
for i in range(2, conv_picker(func, 3, 4, 5)):
|
|
||||||
kernel_size.append(weight_shape[i])
|
|
||||||
|
|
||||||
batch_size = ctx.batch_size
|
batch_size = ctx.batch_size
|
||||||
results: List[Optional[torch.Tensor]] = []
|
results: List[Optional[torch.Tensor]] = []
|
||||||
|
|
|
||||||
|
|
@ -356,12 +356,14 @@ def quantized_args(
|
||||||
return descriptor and _is_value(arg) and _is_tuple_construct(arg)
|
return descriptor and _is_value(arg) and _is_tuple_construct(arg)
|
||||||
|
|
||||||
# Run regular symbolic function if none of the argument is QTensor.
|
# Run regular symbolic function if none of the argument is QTensor.
|
||||||
is_quantized = []
|
is_quantized: typing.List[bool] = []
|
||||||
for descriptor, arg in descriptor_args:
|
for descriptor, arg in descriptor_args:
|
||||||
# ListConstruct
|
# ListConstruct
|
||||||
if _is_packed_list(arg):
|
if _is_packed_list(arg):
|
||||||
for arg_input in arg.node().inputs():
|
is_quantized.extend(
|
||||||
is_quantized.append(_is_arg_quantized(descriptor, arg_input))
|
_is_arg_quantized(descriptor, arg_input)
|
||||||
|
for arg_input in arg.node().inputs()
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
is_quantized.append(_is_arg_quantized(descriptor, arg))
|
is_quantized.append(_is_arg_quantized(descriptor, arg))
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -67,10 +67,9 @@ def col2im(
|
||||||
stride: Sequence[int],
|
stride: Sequence[int],
|
||||||
):
|
):
|
||||||
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
|
# convert [i0, i1, ..., in] into [i0, i0, i1, i1, ..., in, in]
|
||||||
adjusted_padding = []
|
adjusted_padding: List[int] = []
|
||||||
for pad in padding:
|
for pad in padding:
|
||||||
for _ in range(2):
|
adjusted_padding.extend(pad for _ in range(2))
|
||||||
adjusted_padding.append(pad)
|
|
||||||
|
|
||||||
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
|
num_dimensional_axis = symbolic_helper._get_tensor_sizes(output_size)[0]
|
||||||
if not adjusted_padding:
|
if not adjusted_padding:
|
||||||
|
|
|
||||||
|
|
@ -1397,8 +1397,6 @@ class GraphInfo:
|
||||||
original_outputs = list(graph.outputs())
|
original_outputs = list(graph.outputs())
|
||||||
original_inputs = list(graph.inputs())
|
original_inputs = list(graph.inputs())
|
||||||
|
|
||||||
new_outputs = []
|
|
||||||
|
|
||||||
def _process_bridge_value_for_lower(
|
def _process_bridge_value_for_lower(
|
||||||
graph: torch.Graph, bridge_value: torch.Value
|
graph: torch.Graph, bridge_value: torch.Value
|
||||||
) -> torch.Value:
|
) -> torch.Value:
|
||||||
|
|
@ -1416,9 +1414,9 @@ class GraphInfo:
|
||||||
graph, pivot, process_bridge_value_for_lower
|
graph, pivot, process_bridge_value_for_lower
|
||||||
)
|
)
|
||||||
|
|
||||||
for output in original_outputs:
|
new_outputs = [
|
||||||
if _produced_by(output, lower_nodes):
|
output for output in original_outputs if _produced_by(output, lower_nodes)
|
||||||
new_outputs.append(output)
|
]
|
||||||
for _ in enumerate(original_outputs):
|
for _ in enumerate(original_outputs):
|
||||||
graph.eraseOutput(0)
|
graph.eraseOutput(0)
|
||||||
for output in new_outputs:
|
for output in new_outputs:
|
||||||
|
|
|
||||||
|
|
@ -50,10 +50,11 @@ class DirectoryReader:
|
||||||
def get_all_records(
|
def get_all_records(
|
||||||
self,
|
self,
|
||||||
):
|
):
|
||||||
files = []
|
files = [
|
||||||
for filename in glob(f"{self.directory}/**", recursive=True):
|
filename[len(self.directory) + 1 :]
|
||||||
if not os.path.isdir(filename):
|
for filename in glob(f"{self.directory}/**", recursive=True)
|
||||||
files.append(filename[len(self.directory) + 1 :])
|
if not os.path.isdir(filename)
|
||||||
|
]
|
||||||
return files
|
return files
|
||||||
|
|
||||||
def serialization_id(
|
def serialization_id(
|
||||||
|
|
|
||||||
|
|
@ -53,16 +53,16 @@ class _ExtractModuleReferences(ast.NodeVisitor):
|
||||||
if hasattr(node.func, "id") and node.func.id == "__import__":
|
if hasattr(node.func, "id") and node.func.id == "__import__":
|
||||||
try:
|
try:
|
||||||
name = self._grab_node_str(node.args[0])
|
name = self._grab_node_str(node.args[0])
|
||||||
fromlist = []
|
fromlist: List[str] = []
|
||||||
level = 0
|
level = 0
|
||||||
if len(node.args) > 3:
|
if len(node.args) > 3:
|
||||||
for v in node.args[3].elts:
|
fromlist.extend(self._grab_node_str(v) for v in node.args[3].elts)
|
||||||
fromlist.append(self._grab_node_str(v))
|
|
||||||
elif hasattr(node, "keywords"):
|
elif hasattr(node, "keywords"):
|
||||||
for keyword in node.keywords:
|
for keyword in node.keywords:
|
||||||
if keyword.arg == "fromlist":
|
if keyword.arg == "fromlist":
|
||||||
for v in keyword.value.elts:
|
fromlist.extend(
|
||||||
fromlist.append(self._grab_node_str(v))
|
self._grab_node_str(v) for v in keyword.value.elts
|
||||||
|
)
|
||||||
if len(node.args) > 4:
|
if len(node.args) > 4:
|
||||||
level = self._grab_node_int(node.args[4])
|
level = self._grab_node_int(node.args[4])
|
||||||
elif hasattr(node, "keywords"):
|
elif hasattr(node, "keywords"):
|
||||||
|
|
|
||||||
|
|
@ -97,10 +97,9 @@ class Pattern:
def matched_events(self):
def matched_events(self):
if self.skip:
if self.skip:
return []
return []
matched_events = []
matched_events = [
for event in self.eventTreeTraversal():
event for event in self.eventTreeTraversal() if self.match(event)
if self.match(event):
]
matched_events.append(event)
return matched_events
return matched_events
def root_of(self, event: _ProfilerEvent):
def root_of(self, event: _ProfilerEvent):
Some files were not shown because too many files have changed in this diff.