mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Summary: This commit records the global allow_tf32 flag in all torch ops. It also adds a pattern that detects whether the user has set the TF32 flag, and recommends setting it if it is not set. Test Plan: Added a test in test_profiler.py Differential Revision: [D37799414](https://our.internmc.facebook.com/intern/diff/D37799414) Pull Request resolved: https://github.com/pytorch/pytorch/pull/81273 Approved by: https://github.com/robieta
291 lines
9.5 KiB
Python
291 lines
9.5 KiB
Python
from collections import deque
|
|
import re
|
|
from typing import Dict, List, Set
|
|
|
|
import torch
|
|
from torch.profiler import profile
|
|
from torch.profiler._utils import index_of_first_match
|
|
from torch._C._autograd import (_ProfilerEvent, _ExtraFields_TorchOp,
|
|
_ExtraFields_Backend, _ExtraFields_Allocation,
|
|
_ExtraFields_PyCCall, _ExtraFields_PyCall,
|
|
_EventType)
|
|
|
|
|
|
class Pattern:
    '''
    Shared machinery for detectors that scan a profiler event tree.

    A subclass supplies match() and normally replaces ``description``;
    it may also override the ``skip`` property and eventTreeTraversal().
    '''

    def __init__(self, prof: profile):
        self.prof = prof
        self.description = "Please specify a description for pattern"
        assert prof.profiler is not None and prof.profiler.kineto_results is not None
        kineto_results = prof.profiler.kineto_results
        self.event_tree = kineto_results.experimental_event_tree()
        # Top-level events grouped by the thread id they started on, so that
        # siblings_of() can find the neighbours of an event with no parent.
        self.tid_root: Dict[int, List[_ProfilerEvent]] = {}
        for root in self.event_tree:
            tid = root.start_tid
            if tid not in self.tid_root:
                self.tid_root[tid] = []
            self.tid_root[tid].append(root)

    @property
    def skip(self):
        '''When True, matched_events() returns an empty list without scanning.'''
        return False

    def report(self, event: _ProfilerEvent):
        '''Build the message for a matched event: name, description, location.'''
        return f"{event.name()}\n{self.description}\n{source_code_location(event)}"

    def eventTreeTraversal(self):
        '''
        Yield every event, depth first, using an explicit LIFO stack.
        Because entries are popped from the right, the last root and the
        last child are visited first.
        Override this method to customize the traversal order.
        '''
        pending = deque(self.event_tree)
        while pending:
            node = pending.pop()
            yield node
            pending.extend(node.children)

    def match(self, event: _ProfilerEvent):
        '''
        Return True if the event matches the pattern.
        The base class has no pattern; subclasses must override this.
        '''
        raise NotImplementedError

    def matched_events(self):
        '''Collect, in traversal order, the events for which match() is true.'''
        if self.skip:
            return []
        return [
            candidate for candidate in self.eventTreeTraversal()
            if self.match(candidate)
        ]

    def root_of(self, event: _ProfilerEvent):
        '''Follow ``parent`` links upward and return the topmost event.'''
        node = event
        while True:
            parent = node.parent
            if not parent:
                return node
            node = parent

    def siblings_of(self, event: _ProfilerEvent):
        '''
        Return ``(before, after)``: the events that precede and follow
        ``event`` in its parent's child list, or in its thread's list of
        root events when it has no parent.
        '''
        parent = event.parent
        group = parent.children if parent else self.tid_root[event.start_tid]
        position = group.index(event)
        return group[:position], group[position + 1:]

    def next_of(self, event: _ProfilerEvent):
        '''Return the sibling right after ``event``, or None if it is last.'''
        following = self.siblings_of(event)[1]
        return following[0] if following else None

    def prev_of(self, event: _ProfilerEvent):
        '''Return the sibling right before ``event``, or None if it is first.'''
        preceding = self.siblings_of(event)[0]
        return preceding[-1] if preceding else None
|
|
|
|
|
|
# Patterns
|
|
|
|
|
|
class NamePattern(Pattern):
    '''
    Matches every event whose name is hit by a regular-expression search.
    ``name`` is handed to re.search() as the pattern, so it may match
    anywhere inside the event name.
    '''

    def __init__(self, prof: profile, name: str):
        super().__init__(prof)
        self.description = f"Matched Name Event: {name}"
        self.name = name

    def match(self, event: _ProfilerEvent):
        # re.search returns a match object (always truthy) or None.
        found = re.search(self.name, event.name())
        return found is not None
|
|
|
|
|
|
class ExtraCUDACopyPattern(Pattern):
    '''
    Detects a constant tensor that is filled on the CPU and then immediately
    moved to the GPU.
    example: torch.zeros((100, 100)).to("cuda")

    Pattern:
    built-in method                 |built-in method
        ...                         |    aten::to
            aten::fill_/aten::zero_ |        aten::_to_copy

    Algorithm:
    Starting from an aten::to event, step up to its parent and take the
    sibling that comes right before that parent. From there, repeatedly
    descend into the last child, looking for one of the init ops
    (aten::fill_, aten::zero_, aten::normal_, aten::uniform_).
    If any step cannot be taken, the event is not a match.
    '''

    def __init__(self, prof: profile):
        super().__init__(prof)
        self.description = "Filled a CPU tensor and immediately moved it to GPU. Please initalize it on GPU."
        self.init_ops = {
            "aten::fill_", "aten::zero_", "aten::normal_", "aten::uniform_"
        }

    @property
    def skip(self):
        # Only an explicit False disables the pattern.
        with_stack = self.prof.with_stack
        return with_stack is False

    def match(self, event):
        # TODO: We should also check tensor identities
        if event.name() != "aten::to":
            return False

        # The caller of aten::to, one level up.
        caller = event.parent
        if caller is None:
            return False

        # Whatever ran right before that caller is where the fill would be.
        node = self.prev_of(caller)
        if node is None:
            return False

        # Walk the right-most branch. aten::zero_ is a special optimization
        # case where fill_ is not called, so every init op counts as a hit.
        while node.children:
            node = node.children[-1]
            if node.name() in self.init_ops:
                return True
        return node.name() in self.init_ops
        # TODO: Check if tensor is reused
|
|
|
|
|
|
class ForLoopIndexingPattern(Pattern):
    '''
    Detects a Python for loop that indexes a tensor one element at a time
    where a vectorized operation could be used instead.
    example:
    tensor = torch.empty((100, 100))
    for i in range(100):
        tensor[i] = i

    Pattern:
    aten::select | ... | aten::select | ... (Repeat)

    Algorithm:
    Starting from an aten::select event, take the run of sibling ops up to
    the next aten::select as one loop iteration, then count how many times
    that same sequence of op names repeats back to back. Ten or more
    repetitions is reported. The ids of the repeated aten::select events are
    remembered so the same loop is not reported again.
    '''

    def __init__(self, prof: profile):
        super().__init__(prof)
        self.description = "For loop indexing detected. Vectorization recommended."
        self.visited: Set[int] = set()

    def eventTreeTraversal(self):
        '''
        Breadth-first traversal (FIFO queue). This order is needed to avoid
        duplicate matches.
        '''
        queue = deque(self.event_tree)
        while queue:
            node = queue.popleft()
            yield node
            queue.extend(node.children)

    def match(self, event: _ProfilerEvent):
        if event.name() != "aten::select":
            return False
        if event.id in self.visited:
            return False

        following = self.siblings_of(event)[1]
        if len(following) < 2:
            return False

        # Position of the next aten::select among the later siblings; the ops
        # before it, together with ``event``, make up one loop iteration.
        next_select_idx = index_of_first_match(
            following, lambda e: e.name() == "aten::select")
        if next_select_idx is None:
            return False

        iteration_ops = [event] + following[:next_select_idx]
        stride = len(iteration_ops)
        signature = [op.name() for op in iteration_ops]

        # Compare successive windows, starting at the next aten::select,
        # against the op names of the first iteration.
        remainder = following[next_select_idx:]
        repeat_count = 1
        for start in range(0, len(remainder), stride):
            window = remainder[start:start + stride]
            if [op.name() for op in window] != signature:
                break
            repeat_count += 1
            self.visited.add(remainder[start].id)
        return repeat_count >= 10
|
|
|
|
|
|
class FP32MatMulPattern(Pattern):
    '''
    Flags ``aten::mm`` ops that were recorded with TF32 disabled
    (``allow_tf32_cublas`` is False) when the build targets only GPUs that
    support TF32.

    The report is the description alone, so repeated matches collapse into a
    single message once the caller de-duplicates by report text.
    '''

    def __init__(self, prof: profile):
        super().__init__(prof)
        self.description = (
            "You are currently using GPU that supports TF32. "
            "Please enable TF32 by setting 'torch.backends.cuda.matmul.allow_tf32 = True'"
        )

    @property
    def skip(self):
        '''
        Return True unless every architecture from torch.cuda.get_arch_list()
        is a CUDA target with compute capability 80 or newer.

        The pattern is also skipped when the list is empty, since there is no
        GPU to advise about, and when an entry is not of the form
        ``sm_NN`` / ``compute_NN``, since TF32 support cannot be assumed.
        '''
        arch_list = torch.cuda.get_arch_list()
        if not arch_list:
            # all() over an empty list is True, which would wrongly enable
            # the pattern when no CUDA architecture is reported.
            return True
        for arch in arch_list:
            # Parse the number after the prefix instead of slicing arch[3:],
            # which breaks on "compute_NN" and on suffixed names.
            parsed = re.match(r"(?:sm|compute)_(\d+)", arch)
            if parsed is None or int(parsed.group(1)) < 80:
                return True
        return False

    def match(self, event: _ProfilerEvent):
        '''Return True for an ``aten::mm`` torch op that ran with TF32 off.'''
        if event_type(event) != _EventType.TorchOp:
            return False
        assert isinstance(event.extra_fields, _ExtraFields_TorchOp)
        if event.name() != "aten::mm":
            return False
        # Only an explicit False counts; any other value is not a match.
        return event.extra_fields.allow_tf32_cublas is False

    def report(self, event: _ProfilerEvent):
        '''Return the fixed description; the advice is not tied to a location.'''
        return self.description
|
|
|
|
|
|
def source_code_location(event: _ProfilerEvent):
    '''
    Walk from ``event`` up through its ancestors and return
    ``"<file_name>:<line_number>"`` of the caller recorded on the first
    PyCall / PyCCall event found. Returns a fixed placeholder string when
    no such event exists on the path.
    '''
    python_kinds = (_EventType.PyCall, _EventType.PyCCall)
    node = event
    while node:
        if event_type(node) in python_kinds:
            fields = node.extra_fields
            assert isinstance(fields, (_ExtraFields_PyCall, _ExtraFields_PyCCall))
            return f"{node.extra_fields.caller.file_name}:{node.extra_fields.caller.line_number}"
        node = node.parent
    return "No source code location found"
|
|
|
|
|
|
def report_all_anti_patterns(prof):
    '''
    Run every built-in anti-pattern detector over ``prof`` and print a
    report to stdout. Each distinct report message is printed only once,
    even across detectors.
    '''
    detectors = [
        ExtraCUDACopyPattern(prof),
        ForLoopIndexingPattern(prof),
        FP32MatMulPattern(prof),
    ]
    already_printed = set()
    print(f"{'-'*40}TorchTidy Report{'-'*40}")
    for detector in detectors:
        for hit in detector.matched_events():
            message = detector.report(hit)
            if message in already_printed:
                continue
            print(message)
            already_printed.add(message)
|
|
|
|
|
|
def event_type(event: _ProfilerEvent):
    '''
    Return the _EventType member that corresponds to the class of
    ``event.extra_fields``.

    Raises Exception("Unknown event type") when ``extra_fields`` is not an
    instance of any of the known classes.
    '''
    # Checked top to bottom; the first isinstance hit decides the result.
    dispatch = (
        (_ExtraFields_TorchOp, _EventType.TorchOp),
        (_ExtraFields_Backend, _EventType.Backend),
        (_ExtraFields_Allocation, _EventType.Allocation),
        (_ExtraFields_PyCall, _EventType.PyCall),
        (_ExtraFields_PyCCall, _EventType.PyCCall),
    )
    fields = event.extra_fields
    for fields_cls, kind in dispatch:
        if isinstance(fields, fields_cls):
            return kind
    raise Exception("Unknown event type")
|