mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
Fix recompile reason logging (#148200)
for the following test case
```
@torch.compile(dynamic=False, backend=cnts)
def fn(x, y, z):
return x * y * z[0]
fn(1, torch.randn(1), {0: torch.randn(1)})
fn(2, torch.randn(2), {0: torch.randn(2)})
fn(3, torch.randn(3), {0: torch.randn(3)})
fn(4, torch.randn(4), {0: torch.randn(4)})
fn(5, torch.randn(5), {0: torch.randn(5)})
```
previously we would log
```
0/0: L['x'] == 1
0/0: L['x'] == 1
0/0: L['x'] == 1
0/0: L['x'] == 1
```
but after this change we now log
```
0/0: L['x'] == 1
0/1: L['x'] == 2
0/2: L['x'] == 3
0/3: L['x'] == 4
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/148200
Approved by: https://github.com/xmfan
This commit is contained in:
parent
40b3e4a358
commit
83ec7cdcd4
|
|
@ -945,7 +945,7 @@ def _compile(
|
|||
if is_recompilation(cache_size) and frame:
|
||||
reasons = get_and_maybe_log_recompilation_reasons(cache_entry, frame)
|
||||
recompile_reason = (
|
||||
"Unable to find recompilation reasons" if not reasons else reasons[-1]
|
||||
"Unable to find recompilation reasons" if not reasons else reasons[0]
|
||||
)
|
||||
metrics_context.update_outer({"recompile_reason": recompile_reason})
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user