# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import time

import torch
import torch.nn as nn

from functorch import make_functional
from functorch.compile import nnc_jit
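
# Enable TorchScript fusion on CPU so the jit.script baseline below can fuse
# pointwise operations.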
torch._C._jit_override_can_fuse_on_cpu(True)
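

# Timing helper: call f() `warmup` times to warm up caches and compilation,
# then time `iters` calls. Kept for ad-hoc measurements; train() below does
# its own timing.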
def bench(f, iters=100, warmup=10):
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)
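

# Benchmark model: a stack of bias-free Linear layers whose forward returns a
# scalar loss, the sum of squared outputs.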
class Foo(nn.Module):
    def __init__(self, num_layers=3, features=100):
        super().__init__()
        mods = []
        for _ in range(num_layers):
            mods.append(nn.Linear(features, features, bias=False))
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x) ** 2).sum()
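

# Problem size shared by all three training variants below.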
batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)

jit_mod = torch.jit.script(mod)

func_model, weights = make_functional(mod)
lr = 1.0
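

# One SGD step written as a pure function of (input, weights): detach the
# incoming weights, recompute the loss with the stateless model, and return
# the loss plus updated weights. Phrasing the step functionally is what lets
# nnc_jit capture and compile it, including the backward pass.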
def functional_step(x, weights):
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights
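

# Baseline optimizer for the TorchScript module, configured to match
# functional_step's plain-SGD update (lr only, no momentum or weight decay).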
optim = torch.optim.SGD(
    jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)
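

# One training step for the scripted baseline; returns None for the weights
# since the optimizer updates the module's parameters in place.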
def jit_step(x, weights):
    optim.zero_grad()
    loss = jit_mod(x)
    loss.backward()
    optim.step()
    return loss, None
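

# Shared training loop: an initial untimed call serves as warmup (and triggers
# compilation for the NNC variant), then 1000 timed steps on fresh random
# batches, printing the loss every 200 iterations.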
def train(train_step, weights):
    torch.manual_seed(16)
    train_step(inp, weights)
    begin = time.time()
    for itr in range(1000):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if itr % 200 == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.time() - begin)
    print()
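

# Compare three variants: the functional step in eager mode, the same step
# compiled with NNC, and the TorchScript module driven by torch.optim.SGD.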
grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)

print("Starting PT training")
train(grad_pt, weights)

print("Starting NNC training")
train(grad_nnc, weights)

print("Starting JIT training")
train(jit_step, None)