# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import time

import torch
import torch.nn as nn
from functorch import make_functional
from functorch.compile import nnc_jit

# Allow the TorchScript fuser to fuse elementwise ops on CPU.
torch._C._jit_override_can_fuse_on_cpu(True)


def bench(f, iters=100, warmup=10):
    # Simple wall-clock timing helper: warm up, then time `iters` calls.
    for _ in range(warmup):
        f()
    begin = time.time()
    for _ in range(iters):
        f()
    print(time.time() - begin)


class Foo(nn.Module):
    # A stack of bias-free linear layers; the loss is the sum of the
    # squared outputs, so forward() returns a scalar.
    def __init__(self, num_layers=3, features=100):
        super().__init__()
        mods = []
        for _ in range(num_layers):
            mods.append(nn.Linear(features, features, bias=False))
        self.mod = nn.Sequential(*mods)

    def forward(self, x):
        return (self.mod(x) ** 2).sum()


batch_size = 16
features = 64
num_layers = 8
inp = torch.randn((batch_size, features))

mod = Foo(num_layers, features)
jit_mod = torch.jit.script(mod)
func_model, weights = make_functional(mod)
lr = 1.0


def functional_step(x, weights):
    # One SGD step written functionally: detach the incoming weights so
    # each step builds a fresh autograd graph, then apply the update by hand.
    weights = [weight.detach().requires_grad_() for weight in weights]
    out = func_model(weights, x)
    out.backward()
    new_weights = [weight - lr * weight.grad for weight in weights]
    return out, new_weights


optim = torch.optim.SGD(
    jit_mod.parameters(), lr=lr, momentum=0, dampening=0, weight_decay=0
)


def jit_step(x, weights):
    # The equivalent step using the scripted module and torch.optim.SGD.
    optim.zero_grad()
    loss = jit_mod(x)
    loss.backward()
    optim.step()
    return loss, None


def train(train_step, weights):
    torch.manual_seed(16)
    # Warmup call (also triggers compilation) before timing begins.
    train_step(inp, weights)
    begin = time.time()
    for itr in range(1000):
        loss, weights = train_step(torch.randn(batch_size, features), weights)
        if itr % 200 == 0:
            print(f"Loss at {itr}: {loss}")
    print("Time taken: ", time.time() - begin)
    print()


grad_pt = functional_step
grad_nnc = nnc_jit(functional_step)

print("Starting PT training")
train(grad_pt, weights)

print("Starting NNC training")
train(grad_nnc, weights)

print("Starting JIT training")
train(jit_step, None)