Add stack trace to Inductor IR nodes if inductor.config.trace.provenance_tracing=True (#158576)
Summary:
- Split `create_node_mapping` into `create_mapping_pre_post_grad_nodes` and `create_node_mapping_kernel_to_post_grad`
- Store a mapping from pre_grad graph node names to stack traces in `_inductor_pre_grad_node_stack_trace`
- Add a `stack_traces` member to `ir.IRNode` and include it in the string representation of the node
- When we create an IR node and `inductor.config.trace.provenance_tracing=True`, we populate `stack_traces` from `origins`. The nodes in `origins` are post_grad graph nodes. If a node has `node.stack_trace`, we store that stack trace directly; this is particularly important for backward graph nodes, because they have no mapping to pre_grad graph nodes. If a node doesn't have `.stack_trace` (such as `linear` -> `addmm` nodes), we use the stack traces of the pre_grad graph nodes it maps to (see the sketch below).
- A post_grad graph node might not have a stack trace if it corresponds to multiple pre_grad graph nodes, e.g. [GroupLinearFusion](a00442421a/torch/_inductor/fx_passes/group_batch_fusion.py (L299))
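A minimal sketch of that population logic, distilled from the `IRNode.__post_init__` change in this PR (`collect_stack_traces` is a hypothetical standalone helper; in the PR this logic lives in `IRNode.__post_init__` and reads the module-level dicts in `torch._inductor.debug`):
```
# Illustrative sketch of the stack-trace population described above.
def collect_stack_traces(origins, post_to_pre, pre_grad_stack_traces):
    stack_traces = {}
    for node in origins:  # post_grad graph nodes
        if node.stack_trace:
            # backward graph nodes keep their own stack trace directly,
            # since they have no mapping back to pre_grad nodes
            stack_traces["post_grad+" + node.name] = node.stack_trace
        else:
            # decomposed nodes (e.g. linear -> addmm) fall back to the
            # stack traces of the pre_grad nodes they map to
            for pre_name in post_to_pre.get(node.name, []):
                trace = pre_grad_stack_traces.get(pre_name)
                if trace:
                    stack_traces["pre_grad+" + pre_name] = trace
    return stack_traces
```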
Example:
```
scheduling ExternKernelOut(
python_kernel_name='extern_kernels.mm',
name=buf0,
layout=FixedLayout('cuda:0', torch.float32, size=[8, 16], stride=[16, 1]),
inputs=[InputBuffer(name='arg2_1', layout=FixedLayout('cuda:0', torch.float32, size=[8, 10], stride=[10, 1])), ReinterpretView(
StorageBox(
ConstantBuffer(name='fc1_weight', layout=FixedLayout('cuda:0', torch.float32, size=[16, 10], stride=[10, 1]))
),
FixedLayout('cuda:0', torch.float32, size=[10, 16], stride=[1, 10]),
origins=OrderedSet([mm_default_1]),
stack_traces = {,
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
x = self.fc1(x),
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
return F.linear(input, self.weight, self.bias),
}
)],
constant_args=(),
kwargs={},
output_view=None,
python_kernel_name=extern_kernels.mm,
cpp_kernel_name=at::mm_out,
ordered_kwargs_for_cpp_kernel=(),
op_overload=None,
arg_properties=[{}, {}],
allarg_properties={},
kwarg_properties=None,
unbacked_bindings={},
mutation_outputs=[],
origin_node=mm_default_1,
origins=OrderedSet([mm_default_1]),
stack_traces = {,
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/scripts/shangdiy/aot.py", line 29, in forward,
x = self.fc1(x),
File "/data/users/shangdiy/fbsource/buck-out/v2/gen/fbcode/7b4b7a52e15abb17/scripts/shangdiy/__aot__/aot#link-tree/torch/nn/modules/linear.py", line 125, in forward,
return F.linear(input, self.weight, self.bias),
}
)
```
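For context, a toy setup along the following lines should produce this kind of annotated scheduling output. The module, shapes, and `torch.compile` entry point are assumptions chosen to match the sizes in the example above; the original output came from an internal AOT script, so the exact IR (e.g. `ConstantBuffer` weights) depends on the compilation path. Note that the code in this PR gates on `config.trace.provenance_tracking`:
```
import torch
from torch._inductor import config

# Enable provenance tracking so IR nodes carry stack_traces
# (this is the flag the code in this PR checks).
config.trace.provenance_tracking = True

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(10, 16)  # matches fc1_weight [16, 10] above

    def forward(self, x):
        return self.fc1(x)

m = M().to("cuda")
x = torch.randn(8, 10, device="cuda")  # matches arg2_1 [8, 10] above
torch.compile(m)(x)  # scheduling logs now include stack_traces per IR node
```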
Test Plan:
```
buck2 run mode/dev-nosan fbcode//caffe2/test/inductor:provenance_tracing
```
Rollback Plan:
Differential Revision: D78365534
Pull Request resolved: https://github.com/pytorch/pytorch/pull/158576
Approved by: https://github.com/angelayi
This commit is contained in:
parent
86dbc0ef67
commit
1e86fa2e5b
```diff
--- a/test/inductor/test_provenance_tracing.py
+++ b/test/inductor/test_provenance_tracing.py
@@ -10,7 +10,10 @@ from pathlib import Path
 
 import torch
 from torch._inductor import config
-from torch._inductor.debug import create_node_mapping
+from torch._inductor.debug import (
+    create_mapping_pre_post_grad_nodes,
+    create_node_mapping_kernel_to_post_grad,
+)
 from torch._inductor.test_case import run_tests, TestCase
 from torch.testing._internal.inductor_utils import HAS_GPU
 from torch.testing._internal.triton_utils import requires_cuda
@@ -386,11 +389,17 @@ class TestProvenanceTracingNodeMapping(TestCase):
             "triton_poi_fused_addmm_relu_sigmoid_0": ["relu", "add_tensor"]
         }
 
-        result = create_node_mapping(
+        result = create_mapping_pre_post_grad_nodes(
             pre_grad_graph_id,
             post_to_pre_grad_nodes_json,
-            triton_kernel_to_post_grad_json,
         )
+        result = {
+            **result,
+            **create_node_mapping_kernel_to_post_grad(
+                triton_kernel_to_post_grad_json,
+            ),
+        }
+
         self.assertEqual(
             result,
             {
--- a/torch/_inductor/compile_fx.py
+++ b/torch/_inductor/compile_fx.py
@@ -1033,17 +1033,13 @@ def _compile_fx_inner(
         provenance_info = torch._inductor.debug.dump_inductor_provenance_info()
         # provenance_info might be None if trace.provenance_tracking is not set
         if provenance_info:
-            (
-                _,
-                node_mappings,
-            ) = provenance_info
             trace_structured(
                 "artifact",
                 metadata_fn=lambda: {
                     "name": "inductor_provenance_tracking_node_mappings",
                     "encoding": "json",
                 },
-                payload_fn=lambda: json.dumps(node_mappings),
+                payload_fn=lambda: json.dumps(provenance_info),
             )
 
     # This message is for printing overview information of inductor mm counts, shapes,etc after lowering
@@ -1299,8 +1295,13 @@ class _InProcessFxCompile(FxCompile):
                         },
                         payload_fn=lambda: json.dumps(provenance_tracking_json),
                     )
+                    from torch._inductor.debug import create_mapping_pre_post_grad_nodes
+
                     torch._inductor.debug._inductor_post_to_pre_grad_nodes = (
-                        provenance_tracking_json
+                        create_mapping_pre_post_grad_nodes(
+                            torch._inductor.debug._pre_grad_graph_id,
+                            provenance_tracking_json,
+                        )
                     )
 
                 metrics_context = get_metrics_context()
@@ -2174,6 +2175,13 @@ def compile_fx(
         )
         torch._inductor.debug._pre_grad_graph_id = id(model_.graph)
 
+        if config.trace.provenance_tracking:
+            for node in model_.graph.nodes:
+                if node.stack_trace:
+                    torch._inductor.debug._inductor_pre_grad_node_stack_trace[
+                        node.name
+                    ] = node.stack_trace
+
         model_ = _recursive_pre_grad_passes(model_, example_inputs_)
         trace_structured(
             "artifact",
--- a/torch/_inductor/debug.py
+++ b/torch/_inductor/debug.py
@@ -316,6 +316,7 @@ def enable_aot_logging() -> Iterator[None]:
 _inductor_post_to_pre_grad_nodes: dict[str, Any] = {}
 _inductor_triton_kernel_to_post_grad_node_info: dict[str, Any] = {}
 _pre_grad_graph_id: Optional[int] = None
+_inductor_pre_grad_node_stack_trace: dict[str, str] = {}
 
 
 @contextlib.contextmanager
@@ -701,23 +702,18 @@ class TensorMetadataHolder:
 save_args_cnt = itertools.count()
 
 
-def create_node_mapping(
-    pre_grad_graph_id: int,
+def create_mapping_pre_post_grad_nodes(
+    pre_grad_graph_id: Optional[int],
     post_to_pre_grad_nodes_json: dict[str, Any],
-    triton_kernel_to_post_grad_json: dict[str, Any],
 ) -> dict[str, dict[str, Any]]:
-    """Create bidirectional mappings between:
-
-    - pre_grad graph nodes and post_grad graph code nodes, and vice versa
-    - triton kernel name and post_grad graph code nodes, and vice versa
     """
+    Create bidirectional mappings between pre_grad graph nodes
+    and post_grad graph code nodes, and vice versa.
+    """
     # return a dummy dict if there's any error
     empty_return: dict[str, dict[str, Any]] = {
         "preToPost": {},
         "postToPre": {},
-        "cppCodeToPost": {},
-        "postToCppCode": {},
     }
 
     log.info("Creating node mappings for provenance tracking")
@@ -726,12 +722,6 @@ def create_node_mapping(
         log.error("Provenance tacking error: post_to_pre_grad_nodes_json is not a dict")
         return empty_return
 
-    if not isinstance(triton_kernel_to_post_grad_json, dict):
-        log.error(
-            "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
-        )
-        return empty_return
-
     if not isinstance(pre_grad_graph_id, int):
         log.error("Provenance tacking error: pre_grad_graph_id is not an int")
         return empty_return
@@ -739,17 +729,7 @@ def create_node_mapping(
     pre_to_post: dict[str, Any] = collections.defaultdict(OrderedSet)
     post_to_pre: dict[str, Any] = collections.defaultdict(OrderedSet)
 
-    post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
-
     try:
-        for outer_key, node_array in triton_kernel_to_post_grad_json.items():
-            if not isinstance(node_array, list):
-                log.error(
-                    "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
-                )
-                return empty_return
-            for curr_node in node_array:
-                post_to_cpp_code[curr_node].add(outer_key)
 
         def check_format(node: dict[str, Any]) -> bool:
             if not isinstance(node, dict):
@@ -799,10 +779,61 @@ def create_node_mapping(
         # convert to list because set is not JSON serializable
         convert_sets_to_lists(pre_to_post)
         convert_sets_to_lists(post_to_pre)
-        convert_sets_to_lists(post_to_cpp_code)
         return {
             "preToPost": pre_to_post,
             "postToPre": post_to_pre,
+        }
+    except Exception as e:
+        # Since this is just logging code, it should never interfere with regular
+        # program execution, so we use this try-except to guard against any error
+        log.error("Unexpected error in create_node_mapping: %s", e)
+        log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
+        log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
+        log.error(traceback.format_exc())
+        return empty_return
+
+
+def create_node_mapping_kernel_to_post_grad(
+    triton_kernel_to_post_grad_json: dict[str, Any],
+) -> dict[str, dict[str, Any]]:
+    """Create bidirectional mappings between triton kernel name and post_grad
+    graph code nodes, and vice versa.
+    """
+
+    # return a dummy dict if there's any error
+    empty_return: dict[str, dict[str, Any]] = {
+        "cppCodeToPost": {},
+        "postToCppCode": {},
+    }
+
+    log.info("Creating node mappings for provenance tracking")
+
+    if not isinstance(triton_kernel_to_post_grad_json, dict):
+        log.error(
+            "Provenance tacking error: triton_kernel_to_post_grad_json is not a dict"
+        )
+        return empty_return
+
+    post_to_cpp_code: dict[str, Any] = collections.defaultdict(OrderedSet)
+
+    try:
+        for outer_key, node_array in triton_kernel_to_post_grad_json.items():
+            if not isinstance(node_array, list):
+                log.error(
+                    "Provenance tacking error: triton_kernel_to_post_grad_json value is not a list"
+                )
+                return empty_return
+            for curr_node in node_array:
+                post_to_cpp_code[curr_node].add(outer_key)
+
+        def convert_sets_to_lists(d: dict[str, Any]) -> None:
+            for key in d:
+                d[key] = list(d[key])
+            d = dict(d)
+
+        # convert to list because set is not JSON serializable
+        convert_sets_to_lists(post_to_cpp_code)
+        return {
             "cppCodeToPost": triton_kernel_to_post_grad_json,
             "postToCppCode": post_to_cpp_code,
         }
@@ -810,37 +841,38 @@ def create_node_mapping(
         # Since this is just logging code, it should never interfere with regular
         # program execution, so we use this try-except to guard against any error
         log.error("Unexpected error in create_node_mapping: %s", e)
-        log.error("post_to_pre_grad_nodes_json: %s", post_to_pre_grad_nodes_json)
         log.error(
             "triton_kernel_to_post_grad_json: %s", triton_kernel_to_post_grad_json
         )
-        log.error("pre_grad_graph_id: %s", pre_grad_graph_id)
         log.error(traceback.format_exc())
         return empty_return
 
 
 def dump_inductor_provenance_info(
     filename: str = "inductor_generated_kernel_to_post_grad_nodes.json",
-) -> tuple[dict[str, list[str]], dict[str, Any]]:
+) -> dict[str, Any]:
     global _pre_grad_graph_id
     global _inductor_post_to_pre_grad_nodes
     global _inductor_triton_kernel_to_post_grad_node_info
-    debug_info = _inductor_triton_kernel_to_post_grad_node_info.copy()
     if config.trace.enabled:
         with V.debug.fopen(filename, "w") as fd:
             log.info("Writing provenance tracing debugging info to %s", fd.name)
-            json.dump(debug_info, fd)
+            json.dump(_inductor_triton_kernel_to_post_grad_node_info, fd)
     node_mapping = {}
     if _pre_grad_graph_id:
-        node_mapping = create_node_mapping(
-            _pre_grad_graph_id, _inductor_post_to_pre_grad_nodes, debug_info
+        node_mapping_kernel = create_node_mapping_kernel_to_post_grad(
+            _inductor_triton_kernel_to_post_grad_node_info
         )
+        node_mapping = {
+            **_inductor_post_to_pre_grad_nodes,
+            **node_mapping_kernel,
+        }
     if config.trace.enabled:
         with V.debug.fopen(
             "inductor_provenance_tracking_node_mappings.json", "w"
         ) as fd:
             json.dump(node_mapping, fd)
-    return debug_info, node_mapping
+    return node_mapping
 
 
 def set_kernel_post_grad_provenance_tracing(
--- a/torch/_inductor/ir.py
+++ b/torch/_inductor/ir.py
@@ -541,12 +541,23 @@ def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]:
 
 
 class IRNode:
+    """Base class for all intermediate representation (IR) nodes in TorchInductor.
+
+    Note:
+        This is an abstract base class. Most methods raise NotImplementedError
+        and must be overridden by concrete subclasses.
+    """
+
     _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
 
     # NB: These are kinda weird,
     origins: OrderedSet[Any] = dataclasses.field(init=False)
+    # traces back to where the IRNode is created in Inductor
     traceback: Optional[list[str]] = dataclasses.field(init=False)
     origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
+    # trace backs to user model code
+    # a single IRNode could correspond to multiple lines of code
+    stack_traces: dict[str, str] = dataclasses.field(init=False)
 
     @staticmethod
     @contextlib.contextmanager
@@ -578,12 +589,41 @@ class IRNode:
         object.__setattr__(self, attr, value)
 
     def __post_init__(self) -> None:
-        self._post_init_setattr("origins", OrderedSet(self._current_origins))
+        origins = OrderedSet(self._current_origins)
+        self._post_init_setattr("origins", origins)
         self._post_init_setattr(
             "traceback", traceback.format_stack() if config.debug_ir_traceback else None
         )
         self._post_init_setattr("origin_node", None)
 
+        # Group nodes by their stack traces to deduplicate
+        nodes_to_stack_trace = {}
+        if config.trace.provenance_tracking:
+            for node in origins:
+                if node.stack_trace:
+                    # nodes in the backward graph don't have mapping to pre_grad_graph
+                    nodes_to_stack_trace["post_grad+" + node.name] = node.stack_trace
+                else:
+                    if (
+                        "postToPre"
+                        not in torch._inductor.debug._inductor_post_to_pre_grad_nodes
+                    ):
+                        continue
+                    node_names = torch._inductor.debug._inductor_post_to_pre_grad_nodes[
+                        "postToPre"
+                    ].get(node.name, None)
+                    if node_names:
+                        for node_name in node_names:
+                            stack_trace = torch._inductor.debug._inductor_pre_grad_node_stack_trace.get(
+                                node_name, None
+                            )
+                            if stack_trace:
+                                nodes_to_stack_trace["pre_grad+" + node_name] = (
+                                    stack_trace
+                                )
+
+        self._post_init_setattr("stack_traces", nodes_to_stack_trace)
+
     def get_read_names(self) -> OrderedSet[str]:
         return OrderedSet(dep.name for dep in self.get_reads())
 
@@ -601,7 +641,15 @@ class IRNode:
         if shorten and len(origins) > 64:
             # this can get *very* long
             origins = f"{origins[:61]}..."
-        return [origins]
+        if not self.stack_traces:
+            return [origins]
+
+        stack_trace_str = []
+        for stack_trace in self.stack_traces.values():
+            stack_trace_str.append("stack_traces = {{")
+            stack_trace_str += stack_trace.split("\n")
+            stack_trace_str.append("}")
+        return [origins] + stack_trace_str
 
     def str_helper(
         self, lines: Sequence[object], shorten: bool = True, multiline: bool = True
```
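For orientation, after this split the merged node-mapping artifact returned by `dump_inductor_provenance_info` has the following shape; the four keys come from the two new functions, while the node and kernel names below are illustrative only:
```
# Schematic shape of the merged node mapping (names are made up):
node_mapping = {
    # from create_mapping_pre_post_grad_nodes
    "preToPost": {"x": ["addmm_default"]},
    "postToPre": {"addmm_default": ["x"]},
    # from create_node_mapping_kernel_to_post_grad
    "cppCodeToPost": {"triton_poi_fused_addmm_0": ["addmm_default"]},
    "postToCppCode": {"addmm_default": ["triton_poi_fused_addmm_0"]},
}
```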