mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
This will be the last disruptive functorch internals change. Why are we moving these files? - As part of rationalizing functorch, we are moving the code in functorch/_src to torch/_functorch. - This lets us offer the functorch APIs as native PyTorch APIs (coming soon) and resolves some internal build issues. Why are we moving all of these files at once? - It's better to break developers once rather than many times. Test Plan: - wait for tests Pull Request resolved: https://github.com/pytorch/pytorch/pull/90091 Approved by: https://github.com/anijain2305, https://github.com/ezyang
104 lines
2.3 KiB
Python
104 lines
2.3 KiB
Python
import torch
|
|
import torch.fx as fx
|
|
from functorch import make_fx
|
|
from torch.profiler import profile, ProfilerActivity
|
|
|
|
from torch._functorch.compile_utils import fx_graph_cse
|
|
|
|
def profile_it(f, inp, warmup=5, itr=5):
    """Return the profiler-reported CUDA time of ``f(inp)``, averaged per call.

    ``f`` is called ``warmup`` times without profiling, then ``itr`` times
    inside ``torch.profiler.profile`` with only CUDA activity enabled. The
    ``cuda_time_total`` of every aggregated event from ``key_averages()`` is
    summed and divided by ``itr``.

    Args:
        f: Callable invoked as ``f(inp)``; its return value is ignored.
        inp: The single argument passed to ``f``.
        warmup: Number of unprofiled calls made first (default 5, as before).
        itr: Number of profiled calls to average over (default 5, as before).

    Returns:
        The summed ``cuda_time_total`` divided by ``itr``, in the profiler's
        time unit (NOTE(review): microseconds in current PyTorch — confirm for
        the installed version).

    Raises:
        ValueError: If ``itr`` is less than 1 (the average would be undefined).
    """
    if itr < 1:
        raise ValueError(f"itr must be >= 1, got {itr}")

    # Unprofiled calls first, presumably so one-time costs (allocator growth,
    # lazy kernel loading) do not land in the measurement.
    for _ in range(warmup):
        f(inp)

    with profile(activities=[ProfilerActivity.CUDA], record_shapes=True) as prof:
        for _ in range(itr):
            f(inp)

    # NOTE(review): this sums over *all* aggregated events; if the profiler
    # records nested events, the same GPU time could be counted more than
    # once. Confirm this matches the intended metric.
    cuda_time_total = sum(e.cuda_time_total for e in prof.key_averages())
    return cuda_time_total / itr
|
|
|
|
def profile_function(name, f, inp):
    """Trace ``f`` on ``inp``, apply ``fx_graph_cse``, and print a comparison row.

    ``f`` is traced with ``make_fx``. The graph returned by ``fx_graph_cse`` is
    wrapped in a new ``fx.GraphModule``. Both modules are timed with
    ``profile_it``, and one comma-separated line is printed:

        name, time (original), time (after CSE), nodes removed, nodes (original)
    """
    traced = make_fx(f)(inp)

    cse_graph = fx_graph_cse(traced.graph)
    deduped = fx.GraphModule(traced, cse_graph)

    # Time the plain FX modules rather than torch.jit.script'ed versions:
    # TorchScript already performs some CSE itself, which would blur the
    # comparison.
    baseline_time = profile_it(traced, inp)
    cse_time = profile_it(deduped, inp)

    nodes_before = len(traced.graph.nodes)
    nodes_removed = nodes_before - len(deduped.graph.nodes)

    print(f"{name}, {baseline_time}, {cse_time}, {nodes_removed}, {nodes_before}")
|
|
|
|
# Fixed-seed CUDA generator so every run benchmarks identical input values.
g_gpu = torch.Generator(device='cuda')
g_gpu.manual_seed(2_147_483_647)

# 1-D tensor of 2**20 (1,048,576) standard-normal samples on the GPU, passed
# to every profile_function call in this script.
inp = torch.randn(1 << 20, device='cuda', generator=g_gpu)
|
|
|
|
def f1(x):
    """Return cos(cos(x)), element-wise."""
    inner = torch.cos(x)
    return torch.cos(inner)
|
|
|
|
# Benchmark f1 (no duplicated ops) on the shared input.
profile_function(name="f1", f=f1, inp=inp)
|
|
|
|
def fsum(x):
    """Return the sum of four separately computed ``x.sum()`` reductions.

    All four reductions are issued before any addition, and they are combined
    left to right: ((s0 + s1) + s2) + s3.
    """
    partials = [x.sum() for _ in range(4)]
    total = partials[0]
    for partial in partials[1:]:
        total = total + partial
    return total
|
|
|
|
# Benchmark fsum (four duplicated reductions) on the shared input.
profile_function(name="fsum", f=fsum, inp=inp)
|
|
|
|
def fconcat(x):
    """Add together two independently built ``torch.cat((x, x))`` results."""
    # Each concatenation is produced separately; neither reuses the other.
    left, right = (torch.cat((x, x)) for _ in range(2))
    return left + right
|
|
# Benchmark fconcat (duplicated torch.cat) on the shared input.
profile_function(name="fconcat", f=fconcat, inp=inp)
|
|
|
|
def fsum2(x):
    """Return ``x.sum()`` added to itself via 31 separate reductions.

    One initial reduction plus 30 more, each recomputed and accumulated
    left to right.
    """
    result = x.sum()
    extra = 30
    while extra > 0:
        result = result + x.sum()
        extra -= 1
    return result
|
|
|
|
# Benchmark fsum2 (31 duplicated reductions) on the shared input.
profile_function(name="fsum2", f=fsum2, inp=inp)
|
|
|
|
def fsummulti(x):
    """Apply ``acc = (acc + x.sum()) * x.sum()`` three times, starting at 0.

    Within each step the operations run in the order sum, add, sum, mul; the
    first add is between the Python int 0 and a tensor.
    """
    acc = 0
    for _ in range(3):
        # Left operand (including its x.sum()) is evaluated before the right.
        acc = (acc + x.sum()) * x.sum()
    return acc
|
|
|
|
# Benchmark fsummulti (interleaved duplicate reductions, 3 rounds).
profile_function(name="fsummulti", f=fsummulti, inp=inp)
|
|
|
|
def fsummulti2(x):
    """Apply ``acc = (acc + x.sum()) * x.sum()`` thirty times, starting at 0.

    Every step recomputes both reductions; nothing is cached between steps.
    """
    acc = 0
    for _ in range(30):
        shifted = acc + x.sum()
        acc = shifted * x.sum()
    return acc
|
|
|
|
# Benchmark fsummulti2 (interleaved duplicate reductions, 30 rounds).
profile_function(name="fsummulti2", f=fsummulti2, inp=inp)
|
|
|
|
def fcos(x):
    """Return ``0 + x.cos() + x.cos() + x.cos()``, recomputing cos each time."""
    remaining = 3
    total = 0
    while remaining:
        total = total + x.cos()
        remaining -= 1
    return total
|
|
|
|
# Benchmark fcos (3 duplicated cos ops) on the shared input.
profile_function(name="fcos", f=fcos, inp=inp)
|
|
|
|
def fcos2(x):
    """Accumulate thirty freshly computed ``x.cos()`` terms onto the int 0."""
    total = 0
    for _iteration in range(30):
        cos_x = x.cos()
        total = total + cos_x
    return total
|
|
|
|
# Benchmark fcos2 (30 duplicated cos ops) on the shared input.
profile_function(name="fcos2", f=fcos2, inp=inp)
|