pytorch/torch/csrc/jit/runtime/symbolic_script.cpp
jjsjann123 1ec732bc46 Add fp16/fp32 autocasting to JIT/TorchScript (#63939)
Summary:
Adds mixed precision autocasting support between fp32/fp16 to TorchScript/JIT. A more in-depth description can be found at [torch/csrc/jit/JIT-AUTOCAST.md](https://github.com/pytorch/pytorch/pull/63939/files#diff-1f1772aaa508841c5bb58b74ab98f49a1e577612cd9ea5c386c8714a75db830b)

This PR implements an autocast optimization pass (torch/csrc/jit/passes/autocast.cpp) that inserts casting ops per the AMP rules, mimicking the behavior of eager autocast. The pass also takes the `torch.cuda.amp.autocast` context into account and only inserts casting ops within an enabled context manager, giving feature parity with eager AMP autocast.

We currently provide JIT AMP autocast as a prototype feature; it is off by default and can be turned on via `torch._C._jit_set_autocast_mode(True)`.
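A minimal usage sketch based on the description above (the function name `scaled_mm`, the shapes, and the choice of `torch.mm` as an fp16-list op are illustrative, and a CUDA build is assumed):

```python
import torch
from torch.cuda.amp import autocast

# Prototype flag: JIT autocast is off by default.
torch._C._jit_set_autocast_mode(True)

@torch.jit.script
def scaled_mm(a, b):
    # The JIT autocast pass inserts casts only inside the enabled context.
    with autocast(enabled=True):
        return torch.mm(a, b)

a = torch.rand(8, 8, dtype=torch.float32, device="cuda")
b = torch.rand(8, 8, dtype=torch.float32, device="cuda")
print(scaled_mm(a, b).dtype)  # expected: torch.float16 when the cast is applied
```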

The JIT support for autocast is subject to different constraints than the eager-mode implementation (mostly because TorchScript is statically typed); the restrictions on user-facing Python code are described in torch/csrc/jit/JIT-AUTOCAST.md.

This is a prototype; there are also implementation limitations that were necessary to keep this PR small and get something functioning quickly upstream, so we can iterate on the design.

A few limitations/challenges that are not fully resolved in this PR:
1. Autocast inserts cast operations, which affect the scalar type of output tensors feeding downstream operations. We are not currently propagating the updated scalar types, which can give incorrect results for operations governed by type-promotion rules.

2. The backward pass for JIT autodiff misses the cast of dgrad to the input's scalar type that autograd performs in eager mode. This forces us to explicitly mark the casting operation for certain operations (e.g. binary ops); otherwise we might feed a dgrad whose scalar type does not match the input, which could break gradient functions that consume dgrad (e.g. gemm backward, which assumes grad_output has the same scalar type as the input).

3. The `torch.autocast` API has an optional `dtype` argument that is not currently supported in JIT autocast; we require a static value.

Credit goes mostly to:
tlemo
kevinstephano

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63939

Reviewed By: navahgar

Differential Revision: D31093381

Pulled By: eellison

fbshipit-source-id: da6e26c668c38b01e296f304507048d6c1794314
2021-10-27 12:11:36 -07:00


#include <torch/csrc/jit/runtime/symbolic_script.h>
#include <torch/csrc/jit/frontend/ir_emitter.h>
#include <torch/csrc/jit/runtime/operator.h>
namespace torch {
namespace jit {
namespace {
std::mutex lock;
const std::vector<std::string> functions = {
R"(
#### HELPER FUNCTIONS ###
#### PREFIX: AD_ ###
#### SCHEMA NOT SAVED IN CACHE ###
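# Each non-helper def below returns (forward_result, backward); backward takes
# the output gradient(s) and returns one gradient per forward input (None for
# non-differentiable arguments). loadModule() below splits these closures into
# separate forward and backward graphs.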
def AD_unsqueeze_multiple(t,
dims: List[int],
n_dims: int):
seen = [False] * n_dims
for i in range(len(dims)):
seen[dims[i]] = True
for d in range(n_dims):
if seen[d]:
t = t.unsqueeze(d)
return t
def AD_sum_backward(grad,
sizes: List[int],
dims: List[int],
keepdim: bool):
if not keepdim and len(sizes) > 0:
if len(dims) == 1:
return grad.unsqueeze(dims[0]).expand(sizes)
else:
res = AD_unsqueeze_multiple(grad, dims, len(sizes))
return res.expand(sizes)
else:
return grad.expand(sizes)
def AD_logsumexp_backward(grad, self, result,
dim: List[int],
keepdim: bool):
if not keepdim and self.dim() != 0:
n_dims = len(self.size())
grad = AD_unsqueeze_multiple(grad, dim, n_dims)
result = AD_unsqueeze_multiple(result, dim, n_dims)
return grad * (self - result).exp()
def mean_0(self, *, dtype: Optional[int]):
self_size = self.size()
self_numel = self.numel()
self_scalar_type = self.dtype
def backward(grad_output):
return grad_output.expand(self_size).to(self_scalar_type) / self_numel, None
return torch.mean(self, dtype=dtype), backward
def mean_1(self,
dim: List[int],
keepdim: bool,
*,
dtype: Optional[int]):
self_size = self.size()
self_scalar_type = self.dtype
def backward(grad_output):
grad_self = AD_sum_backward(grad_output, self_size, dim, keepdim).to(self_scalar_type) / AD_safe_size(self_size, dim)
return grad_self, None, None, None
return torch.mean(self, dim, keepdim, dtype=dtype), backward
def logsumexp(self,
dim: List[int],
keepdim: bool):
result = torch.logsumexp(self, dim, keepdim)
self_dim = self.dim()
def backward(grad_output):
grad_self = AD_logsumexp_backward(grad_output, self, result, dim, keepdim)
return grad_self, None, None
return result, backward
def AD_bool_to_int(b: bool):
# FIXME: torchscript: int - bool
if b:
i = 1
else:
i = 0
return i
def AD_var_backward_0(grad, self, correction: int):
# FIXME: torchscript: div(float, float)
return grad * (self - self.mean()) * 2.0 / (self.numel() - correction)
def AD_safe_size(sizes: List[int],
dims: List[int]):
if len(sizes) == 0:
return 1
size = 1
for i in range(len(dims)):
d = dims[i]
size *= sizes[d]
return size
def AD_var_backward_1(grad,
self,
dim: List[int],
correction: int,
keepdim: bool):
if self.dim() == 0:
return AD_var_backward_0(grad, self, correction)
self_size = self.size()
if not keepdim and self.dim() > 1:
grad = AD_unsqueeze_multiple(grad, dim, len(self_size))
# FIXME: torchscript: div(float, float)
return grad * (self - self.mean(dim, True)) * 2.0 / (AD_safe_size(self_size, dim) - correction)
def AD_var_backward_2(grad,
self,
dim: Optional[List[int]],
correction: Optional[int],
keepdim: bool):
if correction is None:
correction = 1
if self.dim() == 0 or dim is None:
return AD_var_backward_0(grad, self, correction)
return AD_var_backward_1(grad, self, dim, correction, keepdim)
def std_0(self,
unbiased: bool=True):
std_out = torch.std(self, unbiased)
def backward(grad_output):
correction = AD_bool_to_int(unbiased)
grad_self = AD_var_backward_0(grad_output / (std_out * 2), self, correction)
return grad_self, None
return std_out, backward
def std_1(self,
dim: List[int],
unbiased: bool,
keepdim: bool):
std_out = torch.std(self, dim, unbiased, keepdim)
def backward(grad_output):
correction = AD_bool_to_int(unbiased)
grad_self = AD_var_backward_1(grad_output / (std_out * 2), self, dim, correction, keepdim)
return grad_self, None, None, None
return std_out, backward
def std_2(self,
dim: Optional[List[int]],
*,
correction: Optional[int],
keepdim: bool):
std_out = torch.std(self, dim, correction=correction, keepdim=keepdim)
def backward(grad_output):
grad_self = AD_var_backward_2(grad_output / (std_out * 2), self, dim, correction, keepdim)
return grad_self, None, None, None
return std_out, backward
def var_0(self,
unbiased: bool=True):
def backward(grad_output):
correction = AD_bool_to_int(unbiased)
grad_self = AD_var_backward_0(grad_output, self, correction)
return grad_self, None
return torch.var(self, unbiased), backward
def var_1(self,
dim: List[int],
unbiased: bool,
keepdim: bool):
def backward(grad_output):
correction = AD_bool_to_int(unbiased)
grad_self = AD_var_backward_1(grad_output, self, dim, correction, keepdim)
return grad_self, None, None, None
return torch.var(self, dim, unbiased, keepdim), backward
def var_2(self,
dim: Optional[List[int]],
*,
correction: Optional[int],
keepdim: bool):
def backward(grad_output):
grad_self = AD_var_backward_2(grad_output, self, dim, correction, keepdim)
return grad_self, None, None, None
return torch.var(self, dim, correction=correction, keepdim=keepdim), backward
def tanh(self):
output = torch.tanh(self)
def backward(grad_output):
return grad_output * (1 - output * output)
return output, backward
def AD_index_select_backward(grad,
dim: int,
indices,
sizes: List[int],
keepdim: bool):
if not keepdim and len(sizes) > 0:
grad = grad.unsqueeze(dim)
indices = indices.unsqueeze(dim)
# FIXME: torchscript: torch.zeros(sizes, grad.options())
return torch.zeros(sizes).to(grad).scatter_(dim, indices, grad)
# def topk(self,
# k: int,
# dim: int = -1,
# largest: bool = True,
# sorted: bool = True):
# result0, result1 = torch.topk(self, k, dim, largest, sorted)
# self_size = self.size()
# def backward(grad_output):
# grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, True)
# return grad_self, None, None, None, None
# return result0, result1, backward
# def kthvalue(self,
# k: int,
# dim: int,
# keepdim: bool):
# result0, result1 = torch.kthvalue(self, k, dim, keepdim)
# self_size = self.size()
# def backward(grad_output):
# grad_self = AD_index_select_backward(grad_output, dim, result1, self_size, keepdim)
# return grad_self, None, None, None
# return result0, result1, backward
def AD_mm_backward_self(grad, mat2):
return grad.mm(mat2.t())
def AD_mm_backward_mat2(grad, self):
return self.t().mm(grad)
def mm(self, mat2):
def backward(grad_output):
grad_self = AD_mm_backward_self(grad_output, mat2)
grad_mat2 = AD_mm_backward_mat2(grad_output, self)
return grad_self, grad_mat2
return torch.mm(self, mat2), backward
def AD_permute_backward(grad,
fwd_dims: List[int]):
ndims = len(fwd_dims)
dims = [0] * ndims
for i in range(ndims):
dims[fwd_dims[i]] = i
return grad.permute(dims)
def permute(self,
dims: List[int]):
def backward(grad_output):
grad_self = AD_permute_backward(grad_output, dims)
return grad_self, None
return torch.permute(self, dims), backward
def AD_select_backward(grad,
input_sizes: List[int],
dim: int,
index: int):
# FIXME: torchscript: torch.zeros(sizes, grad.options())
grad_input = torch.zeros(input_sizes).to(grad)
grad_input.select(dim, index).copy_(grad)
return grad_input
# TODO: fix torch.zeros(sizes, grad.options()) before enabling select, topk, kthvalue
# def select(self,
# dim: int,
# index: int):
# self_size = self.size()
# def backward(grad_output):
# grad_self = AD_select_backward(grad_output, self_size, dim, index)
# return grad_self, None, None
# return torch.select(self, dim, index), backward
def AD_slice_backward(grad,
input_sizes: List[int],
dim: int,
start: int,
end: int,
step: int):
# FIXME: torchscript: torch.zeros(sizes, grad.options())
grad_input = torch.zeros(input_sizes).to(grad)
grad_input.slice(dim, start, end, step).copy_(grad)
return grad_input
# DON'T enable slice unless we can correctly handle view ops in graph executor.
# It triggers failure of TestJit.test_sample in test_distributions.py.
# def slice(self,
# dim: int=0,
# start: int=0,
# end: int=9223372036854775807,
# step: int=1):
# def backward(grad_output):
# grad_self = AD_slice_backward(grad_output, self.size(), dim, start, end, step)
# return grad_self, None, None, None, None
# return torch.slice(self, dim, start, end, step), backward
def AD_unsqueeze_to_0(self,
sizes: List[int]):
ndims = len(sizes)
for i in range(ndims):
if sizes[i] == 1:
self = self.unsqueeze(i)
return self
def AD_unsqueeze_to_1(self,
dim: int,
sizes: List[int]):
if len(sizes) > 0 and sizes[dim] == 1:
return self.unsqueeze(dim)
return self
def squeeze_0(self):
self_size = self.size()
def backward(grad_output):
grad_self = AD_unsqueeze_to_0(grad_output, self_size)
return grad_self
return torch.squeeze(self), backward
def squeeze_1(self,
dim: int):
self_size = self.size()
def backward(grad_output):
grad_self = AD_unsqueeze_to_1(grad_output, dim, self_size)
return grad_self, None
return torch.squeeze(self, dim), backward
def AD_infer_size(a: List[int],
b: List[int]):
dimsA = len(a)
dimsB = len(b)
ndim = dimsA if dimsA > dimsB else dimsB
expand_sizes = [0] * ndim
for i in range(ndim):
idx = - i + ndim - 1
sizeA = a[i] if dimsA + i >= 0 else 1
sizeB = b[i] if dimsB + i >= 0 else 1
# Assert sizeA == sizeB or sizeA == 1 or sizeB == 1
expand_sizes[i] = sizeB if sizeA == 1 else sizeA
return expand_sizes
def AD_bmm_backward_self(grad, mat2):
return grad.bmm(mat2.transpose(1, 2))
def AD_bmm_backward_mat2(grad, self):
return self.transpose(1, 2).bmm(grad)
def bmm(self, mat2):
def backward(grad_output):
grad_self = AD_bmm_backward_self(grad_output, mat2)
grad_mat2 = AD_bmm_backward_mat2(grad_output, self)
return grad_self, grad_mat2
return torch.bmm(self, mat2), backward
)",
R"(
def AD_mat_transpose(mat):
dim = mat.dim()
if dim == 1:
out = mat
elif dim == 2:
out = mat.t()
else:
dims = rangelist(dim)
dims[-1] = dim - 2
dims[-2] = dim - 1
out = mat.permute(dims)
return out
# In the matmul backward case of [b, m, n] * [b, n, p] => [m, p],
# instead of computing [b, m, p] and then reducing to [m, p],
# which potentially uses a large intermediate of size b*m*p,
# we do [m, bn] * [bn, p] to avoid having the large
# intermediate, thus reducing peak memory usage.
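# Concretely, for mat1: [b, m, n] and mat2: [b, n, p]:
#   AD_mat_transpose(mat1) -> [b, n, m], reshaped to [b*n, m]
#   mat2 reshaped to [b*n, p]
#   [b*n, m].t().mm([b*n, p]) -> [m, p]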
def AD_matmul_bw_special_fold(mat1, mat2):
mat1_transpose = AD_mat_transpose(mat1)
mat1_fold = mat1_transpose.reshape(-1, mat1_transpose.size()[-1])
mat2_fold = mat2.reshape(-1, mat2.size()[-1])
return mat1_fold.t().mm(mat2_fold)
def AD_matmul_bw_size(mat1, mat2,
out_size: List[int]):
dim1 = mat1.dim()
dim2 = mat2.dim()
dim_out = len(out_size)
if dim1 == 0 or dim2 == 0:
out = mat1 * mat2
elif dim_out == 2 and dim1 == dim2 and dim1 >=3:
out = AD_matmul_bw_special_fold(mat1, mat2)
elif dim_out == 1 and dim1 - dim2 == 1 and dim1 >= 3:
mat2_unsqueeze = mat2.unsqueeze(-1)
out = AD_matmul_bw_special_fold(mat1, mat2_unsqueeze)
out = out.squeeze(-1)
elif dim1 + dim2 == dim_out:
if dim2 == 1:
target_dim2 = 0
else:
target_dim2 = -2
out = torch.matmul(mat1.unsqueeze(dim1), mat2.unsqueeze(target_dim2))
elif dim_out == dim1 - dim2:
out = torch.matmul(mat1, mat2.unsqueeze(dim2)).squeeze(-1)
elif dim_out == dim2 - dim1:
out = torch.matmul(mat1.unsqueeze(-2), mat2).squeeze(-2)
else:
out = torch.matmul(mat1, mat2)
return out
def matmul(self, other):
def backward(grad_output):
self_size = self.size()
other_size = other.size()
grad_self = AD_matmul_bw_size(grad_output, AD_mat_transpose(other), self_size)._grad_sum_to_size(self_size)
grad_other = AD_matmul_bw_size(AD_mat_transpose(self), grad_output, other_size)._grad_sum_to_size(other_size)
return grad_self, grad_other
return torch.matmul(self, other), backward
def linear(input : Tensor,
weight : Tensor,
bias : Optional[Tensor]):
result = torch.linear(input, weight, bias)
def backward(grad_output):
if bias is not None:
grad_bias = grad_output._grad_sum_to_size(bias.size())
else:
grad_bias = None
weight_size = weight.size()
grad_input = torch.matmul(grad_output, weight)
grad_weight = torch.matmul(grad_output.reshape(-1, weight_size[0]).t(), input.reshape(-1, weight_size[1]))
# Note: calling unchecked_unwrap_optional is only safe when we
# return grad_bias directly back to bias, because in the case
# where `bias is None` the unwrapped grad_bias would just be
# pruned away.
return grad_input, grad_weight, grad_bias.unchecked_unwrap_optional
return result, backward
)",
R"(
def addcmul(self,
tensor1,
tensor2,
*,
value: number):
result = torch.addcmul(self, tensor1, tensor2, value=value)
self_size = torch._size_if_not_equal(self.size(), result.size())
tensor1_size = torch._size_if_not_equal(tensor1.size(), result.size())
tensor2_size = torch._size_if_not_equal(tensor2.size(), result.size())
def backward(grad_output):
grad = grad_output * value
grad_tensor1 = (grad * tensor2)._grad_sum_to_size(tensor1_size)
grad_tensor2 = (grad * tensor1)._grad_sum_to_size(tensor2_size)
return grad_output._grad_sum_to_size(self_size), grad_tensor1, grad_tensor2, None
return result, backward
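# Backward formulas for the casts inserted by the JIT autocast pass
# (torch/csrc/jit/passes/autocast.cpp): the gradient is cast back to the
# input's original dtype, mirroring what autograd does for casts in eager mode.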
def _autocast_to_full_precision(self, cuda_enabled : bool, cpu_enabled : bool):
self_dtype = self.dtype
def backward(grad_output):
return grad_output.to(self_dtype)
return torch._autocast_to_full_precision(self, cuda_enabled, cpu_enabled), backward
def _autocast_to_reduced_precision(self,
cuda_enabled : bool,
cpu_enabled : bool,
cuda_dtype : int,
cpu_dtype : int):
self_dtype = self.dtype
def backward(grad_output):
return grad_output.to(self_dtype)
return torch._autocast_to_reduced_precision(self, cuda_enabled, cpu_enabled, cuda_dtype, cpu_dtype), backward
def _dim_arange(like,
dim: int):
def backward(grad_output):
return None, None
return torch._dim_arange(like, dim), backward
def contiguous(self, *, memory_format: int=0):
def backward(grad_output):
return grad_output, None
return self.contiguous(memory_format=memory_format), backward
def dot(self, tensor):
def backward(grad_output):
return grad_output * tensor, grad_output * self
return torch.dot(self, tensor), backward
def erf(self):
def backward(grad_output):
# Precomputed constant C = 2.0 / math.sqrt(math.pi)
C = 1.1283791670955126
return C * torch.exp(- self * self) * grad_output
return torch.erf(self), backward
def expand(self,
size: List[int],
*,
implicit: bool=False):
result = torch.expand(self, size, implicit=implicit)
self_size = torch._size_if_not_equal(self.size(), result.size())
def backward(grad_output):
return grad_output._grad_sum_to_size(self_size), None, None
return result, backward
def expand_as(self, other):
result = torch.expand_as(self, other)
self_size = torch._size_if_not_equal(self.size(), result.size())
def backward(grad_output):
return grad_output._grad_sum_to_size(self_size), None
return result, backward
def full_like(self,
fill_value: float):
def backward(grad_output):
return None, None
return torch.full_like(self, fill_value, memory_format=1), backward
def lerp_0(self,
end,
weight: number):
result = torch.lerp(self, end, weight)
self_size = torch._size_if_not_equal(self.size(), result.size())
end_size = torch._size_if_not_equal(end.size(), result.size())
def backward(grad_output):
grad_self = (grad_output * (1 - float(weight)))._grad_sum_to_size(self_size)
grad_end = (grad_output * float(weight))._grad_sum_to_size(end_size)
return grad_self, grad_end, None
return result, backward
def lerp_1(self,
end,
weight):
result = torch.lerp(self, end, weight)
self_size = torch._size_if_not_equal(self.size(), result.size())
end_size = torch._size_if_not_equal(end.size(), result.size())
weight_size = torch._size_if_not_equal(weight.size(), result.size())
def backward(grad_output):
grad_self = (grad_output * (1 - weight))._grad_sum_to_size(self_size)
grad_end = (grad_output * weight)._grad_sum_to_size(end_size)
grad_weight = (grad_output * (end - self))._grad_sum_to_size(weight_size)
return grad_self, grad_end, grad_weight
return result, backward
def reshape(self,
shape: List[int]):
self_size = self.size()
def backward(grad_output):
return grad_output.reshape(self_size), None
return torch.reshape(self, shape), backward
def split(self,
split_size: int,
dim: int):
def backward(grad_outputs: List[Tensor]):
grad_self = torch.cat(grad_outputs, dim)
return grad_self, None, None
return torch.split(self, split_size, dim), backward
def split_with_sizes(self,
split_sizes: List[int],
dim: int):
def backward(grad_outputs: List[Tensor]):
size = len(grad_outputs)
grad_self = torch.cat(grad_outputs, dim)
return grad_self, None, None
return torch.split_with_sizes(self, split_sizes, dim), backward
def stack(tensors: List[Tensor],
dim: int=0):
def backward(grad_output):
grad_tensors = torch.unbind(grad_output, dim)
return grad_tensors, None
return torch.stack(tensors, dim), backward
def unbind(self,
dim: int):
def backward(grad_outputs: List[Tensor]):
grad_self = torch.stack(grad_outputs, dim)
return grad_self, None
return torch.unbind(self, dim), backward
def cat(tensors: List[Tensor],
dim: int):
size = len(tensors)
split_sizes = [0] * size
for i in range(size):
if tensors[i].size() != [0]:
split_sizes[i] = tensors[i].size()[dim]
def backward(grad_output):
grad_tensors = torch.split_with_sizes(grad_output, split_sizes, dim)
return grad_tensors, None
return torch.cat(tensors, dim), backward
def index(self,
indices: List[Tensor]):
def backward(grad_output):
grad_self = torch.zeros_like(self, memory_format=1).index_put_(indices, grad_output, True)
return grad_self, None
return torch.index(self, indices), backward
def meshgrid(tensors: List[Tensor]):
size = len(tensors)
sizes = [0] * size
for i in range(size):
if tensors[i].dim() != 0:
sizes[i] = tensors[i].size()[0]
def backward(grad_outputs: List[Tensor]):
grads_tensors = []
for i in range(size):
view_shape = [1] * size
if sizes[i] == 0:
view_shape[i] = 1
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape(()))
else:
view_shape[i] = sizes[i]
grads_tensors.append((grad_outputs[i]._grad_sum_to_size(view_shape)).reshape([sizes[i]]))
return grads_tensors
return torch.meshgrid(tensors), backward
def mv(self, vec):
def backward(grad_output):
return grad_output.ger(vec), self.t().mv(grad_output)
return torch.mv(self, vec), backward
def nonzero(self):
def backward(grad_output):
return None
return torch.nonzero(self), backward
def ones_like(self):
def backward(grad_output):
return None
return torch.ones_like(self, memory_format=1), backward
def pow_0(self,
exponent: number):
def backward(grad_output):
if float(exponent) == 0.0:
grad_self = torch.zeros_like(self, memory_format=1)
else:
grad_self = grad_output * exponent * torch.pow(self, float(exponent) - 1)
return grad_self, None
return torch.pow(self, exponent), backward
def pow_1(self, exponent):
result = torch.pow(self, exponent)
self_size = torch._size_if_not_equal(self.size(), result.size())
exponent_size = torch._size_if_not_equal(exponent.size(), result.size())
def backward(grad_output):
grad_self = torch.where(exponent == 0.0, torch.zeros_like(self, memory_format=1), grad_output * exponent * torch.pow(self, exponent - 1))._grad_sum_to_size(self_size)
grad_exponent = (grad_output * torch.pow(self, exponent) * torch.log(self))._grad_sum_to_size(exponent_size)
return grad_self, grad_exponent
return result, backward
def pow_2(self: number,
exponent):
def backward(grad_output):
grad_exponent = grad_output * torch.pow(self, exponent) * torch.log(float(self))
return None, grad_exponent
return torch.pow(self, exponent), backward
def rsub_0(self,
other,
alpha: number):
result = torch.rsub(self, other, alpha=alpha)
self_size = torch._size_if_not_equal(self.size(), result.size())
other_size = torch._size_if_not_equal(other.size(), result.size())
def backward(grad_output):
grad_self = (- grad_output * alpha)._grad_sum_to_size(self_size)
grad_other = (grad_output)._grad_sum_to_size(other_size)
return grad_self, grad_other, None
return result, backward
def rsub_1(self,
other: number,
alpha: number):
def backward(grad_output):
grad_self = (- grad_output * alpha)
return grad_self, None, None
return torch.rsub(self, other, alpha), backward
def sqrt(self):
result = torch.sqrt(self)
def backward(grad_output):
return grad_output / (2 * result)
return result, backward
def t(self):
def backward(grad_output):
return torch.t(grad_output)
return torch.t(self), backward
def to_0(self,
device: Optional[Device],
dtype: Optional[int],
non_blocking: bool,
copy: bool):
self_device = self.device
self_dtype = self.dtype
if device is not None:
result = self.to(device, dtype=dtype, non_blocking=non_blocking, copy=copy)
else:
result = self.to(dtype, non_blocking=non_blocking, copy=copy)
def backward(grad_output):
grad_self = grad_output.to(self_device, dtype=self_dtype, non_blocking=non_blocking, copy=copy)
return grad_self, None, None, None, None
return result, backward
def to_1(self,
dtype: int,
non_blocking: bool,
copy: bool):
self_dtype = self.dtype
def backward(grad_output):
grad_self = grad_output.to(self_dtype, non_blocking, copy)
return grad_self, None, None, None
return self.to(dtype=dtype, non_blocking=non_blocking, copy=copy), backward
def to_2(self,
other,
non_blocking: bool,
copy: bool):
def backward(grad_output):
grad_self = grad_output.to(self, non_blocking, copy)
return grad_self, None, None, None
return self.to(other, non_blocking=non_blocking, copy=copy), backward
def transpose(self,
dim0: int,
dim1: int):
def backward(grad_output):
return torch.transpose(grad_output, dim0, dim1), None, None
return torch.transpose(self, dim0, dim1), backward
def view(self,
size: List[int]):
self_size = self.size()
def backward(grad_output):
return grad_output.reshape(self_size), None
return torch.view(self, size), backward
)",
R"(
def AD_sizes_if_not_equal_multi_0(t1, t2, res):
return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
def mul_0(self, other):
result = self * other
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
def backward(grad_output):
grad_self = (grad_output * other)._grad_sum_to_size(self_size)
grad_other = (grad_output * self)._grad_sum_to_size(other_size)
return grad_self, grad_other
return result, backward
def mul_1(self, other: number):
def backward(grad_output):
return grad_output * other, None
return self * other, backward
def div_0(self, other):
result = self / other
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
def backward(grad_output):
grad_self = (grad_output / other)._grad_sum_to_size(self_size)
grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
return grad_self, grad_other
return result, backward
def div_1(self, other: number):
def backward(grad_output):
return grad_output / other, None
return self / other, backward
def div_2(self, other, *, rounding_mode: Optional[str]):
result = torch.div(self, other, rounding_mode=rounding_mode)
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
def backward(grad_output):
if rounding_mode is None:
grad_self = (grad_output / other)._grad_sum_to_size(self_size)
grad_other = (-grad_output * self / (other * other))._grad_sum_to_size(other_size)
else:
grad_self = torch.zeros_like(self)
grad_other = torch.zeros_like(other)
return grad_self, grad_other, None
return result, backward
def div_3(self, other: number, *, rounding_mode: Optional[str]):
result = torch.div(self, other, rounding_mode=rounding_mode)
def backward(grad_output):
if rounding_mode is None:
grad_self = (grad_output / other)
else:
grad_self = torch.zeros_like(self, memory_format=1)
return grad_self, None, None
return result, backward
def max(self, other):
result = torch.max(self, other)
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
def backward(grad_output):
grad_self = (grad_output * (self > other).type_as(grad_output))._grad_sum_to_size(self_size)
grad_other = (grad_output * (other > self).type_as(grad_output))._grad_sum_to_size(other_size)
return grad_self, grad_other
return result, backward
def min(self, other):
def backward(grad_output):
grad_self = (grad_output * (self < other).type_as(grad_output))._grad_sum_to_size(self.size())
grad_other = (grad_output * (other < self).type_as(grad_output))._grad_sum_to_size(other.size())
return grad_self, grad_other
return torch.min(self, other), backward
def sigmoid(self):
result = torch.sigmoid(self)
def backward(grad_output):
return (1 - result) * result * grad_output
return result, backward
# Share backward with threshold
def relu(self):
result = torch.relu(self)
def backward(grad_output):
return grad_output * (result > 0).type_as(result)
return result, backward
def relu6(self):
result = torch.relu6(self)
def backward(grad_output):
return grad_output * ((result > 0) & (result < 6.0))
return result, backward
def leaky_relu(self, negative_slope: number):
result = torch.leaky_relu(self, negative_slope)
def backward(grad_output):
return grad_output * torch.where(self > 0, 1.0, negative_slope).type_as(result), None
return result, backward
def gelu(self):
result = torch.gelu(self)
def backward(grad_output):
m_2_sqrtpi = 1.12837916709551257390
m_sqrt1_2 = 0.707106781186547524401
alpha = m_sqrt1_2
beta = m_2_sqrtpi * m_sqrt1_2 * 0.5
cdf = (torch.erf(self * m_sqrt1_2) + 1.0) * 0.5
pdf = beta * torch.exp(self * self * -0.5)
return grad_output * (cdf + self * pdf)
return result, backward
def hardswish(self):
result = torch.hardswish(self)
def backward(grad_output):
m = (self > 3.).type_as(result)
m = torch.where((self >= -3.) & (self <= 3.), self / 3. + .5, m)
return grad_output * m
return result, backward
def hardsigmoid(self):
result = torch.hardsigmoid(self)
def backward(grad_output):
m = (self > -3.) & (self < 3.)
lhs = grad_output * (1.0 / 6.0)
return torch.where(m, lhs, m.type_as(self))
return result, backward
def erfc(self):
def backward(grad_output):
# Precomputed constant C = -2.0 / math.sqrt(math.pi)
C = -1.1283791670955126
return C * torch.exp(-self * self) * grad_output
return torch.erfc(self), backward
def exp(self):
result = torch.exp(self)
def backward(grad_output):
return grad_output * result
return result, backward
def neg(self):
def backward(grad_output):
return grad_output.neg()
return torch.neg(self), backward
def where(condition, self, other):
result = torch.where(condition, self, other)
self_size, other_size = AD_sizes_if_not_equal_multi_0(self, other, result)
def backward(grad_output):
grad_self = (grad_output * condition.type_as(grad_output))._grad_sum_to_size(self_size)
grad_other = (grad_output * (condition.bitwise_not()).type_as(grad_output))._grad_sum_to_size(other_size)
return None, grad_self, grad_other
return result, backward
def type_as(self, other):
def backward(grad_output):
return grad_output.type_as(self), None
return torch.type_as(self, other), backward
def unsqueeze(self, dim: int):
def backward(grad_output):
return grad_output.squeeze(dim), None
return torch.unsqueeze(self, dim), backward
def abs(self):
def backward(grad_output):
return grad_output * self.sign()
return torch.abs(self), backward
def acos(self):
def backward(grad_output):
return grad_output * -((-self * self + 1).rsqrt())
return torch.acos(self), backward
def asin(self):
def backward(grad_output):
return grad_output * (-self * self + 1).rsqrt()
return torch.asin(self), backward
def atan(self):
def backward(grad_output):
return grad_output / (self * self + 1)
return torch.atan(self), backward
def ceil(self):
def backward(grad_output):
return torch.zeros_like(grad_output, memory_format=1)
return torch.ceil(self), backward
def cos(self):
def backward(grad_output):
return grad_output * -self.sin()
return torch.cos(self), backward
def cosh(self):
def backward(grad_output):
return grad_output * self.sinh()
return torch.cosh(self), backward
def expm1(self):
result = torch.expm1(self)
def backward(grad_output):
return grad_output * (result + 1)
return result, backward
def floor(self):
def backward(grad_output):
return torch.zeros_like(grad_output, memory_format=1)
return torch.floor(self), backward
def frac(self):
def backward(grad_output):
return grad_output
return torch.frac(self), backward
def log(self):
def backward(grad_output):
return grad_output.div(self)
return torch.log(self), backward
def log10(self):
def backward(grad_output):
return grad_output / (self * 2.3025850929940456)
return torch.log10(self), backward
def log1p(self):
def backward(grad_output):
return grad_output / (self + 1)
return torch.log1p(self), backward
def log2(self):
def backward(grad_output):
return grad_output / (self * 0.6931471805599453)
return torch.log2(self), backward
def rand_like(self, *, memory_format: Optional[int]):
def backward(grad_output):
return None
return torch.rand_like(self, memory_format=memory_format), backward
def reciprocal(self):
result = torch.reciprocal(self)
def backward(grad_output):
return -grad_output * result * result
return result, backward
def round(self):
def backward(grad_output):
return torch.zeros_like(grad_output, memory_format=1)
return torch.round(self), backward
def rsqrt(self):
result = torch.rsqrt(self)
def backward(grad_output):
return -grad_output * result * result * result / 2
return result, backward
def sin(self):
def backward(grad_output):
return grad_output * self.cos()
return torch.sin(self), backward
def sinh(self):
def backward(grad_output):
return grad_output * self.cosh()
return torch.sinh(self), backward
def tan(self):
result = torch.tan(self)
def backward(grad_output):
return grad_output * (1. + result * result)
return result, backward
def trunc(self):
def backward(grad_output):
return torch.zeros_like(grad_output, memory_format=1)
return torch.trunc(self), backward
def _grad_sum_to_size(self,
size: Optional[List[int]]):
result = torch._grad_sum_to_size(self, size)
self_size = torch._size_if_not_equal(self.size(), result.size())
def backward(grad_output):
if self_size is None:
grad_input = grad_output
else:
grad_input = grad_output.expand(self_size)
return grad_input, None
return result, backward
)",
R"(
def batch_norm(input : Tensor,
weight : Optional[Tensor],
bias : Optional[Tensor],
running_mean : Optional[Tensor],
running_var : Optional[Tensor],
training : bool,
momentum : float,
eps : float,
cudnn_enabled : bool):
output, save1, save2, reserve, impl_idx = torch._batch_norm_impl_index(
input, weight, bias, running_mean, running_var, training,
momentum, eps, cudnn_enabled)
has_weight = weight is not None
has_bias = bias is not None
def backward(grad_output):
dinput, dweight, dbias = torch._batch_norm_impl_index_backward(
impl_idx, input, grad_output, weight, running_mean, running_var,
save1, save2, training, eps, [True, has_weight, has_bias], reserve)
return dinput, dweight, dbias, None, None, None, None, None, None
return output, backward
def layer_norm_disabled(input : Tensor,
normalized_shape : List[int],
weight : Optional[Tensor],
bias : Optional[Tensor],
eps : float,
cudnn_enable : bool):
output, mean, rstd = torch.native_layer_norm(input, normalized_shape, weight, bias, eps)
def backward(grad_output):
output_mask = [True, weight is not None, bias is not None]
grad_input, grad_weight, grad_bias = torch.native_layer_norm_backward(grad_output, input, normalized_shape, mean, rstd, weight, bias, output_mask)
return grad_input, None, grad_weight, grad_bias, None, None
return output, backward
def AD_fused_dropout_backward(grad,
mask,
p1m: float):
p1r = 1. / p1m
grad_input = grad * (mask.type_as(grad) * p1r)
return grad_input
def dropout(input,
p: float,
train: bool):
use_cuda = input.is_cuda
# The lowering is specialized for CUDA because the CUDA fuser can efficiently fuse these operations;
# for the CPU backend, where fusions are disabled, a different lowering that is more efficient
# in the absence of fusion is used.
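# p1m below is the keep probability (1. - p); kept activations are scaled by
# 1. / p1m (inverted dropout), so no rescaling is needed at inference time.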
p1m = 1. - p
if train:
if use_cuda:
mask = torch.rand_like(input, memory_format=1) < p1m
res = mask.type_as(input) * input * (1./p1m)
else:
mask = torch.empty_like(input, memory_format=1)
mask.bernoulli_(p1m)
res = mask * input / p1m
else:
p1m = 1.
res = input
mask = torch.empty_like(input, memory_format=1)
def backward(grad_output):
use_cuda = grad_output.is_cuda
if use_cuda:
grad_input = AD_fused_dropout_backward(grad_output, mask, p1m)
else:
grad_input = grad_output * mask / p1m
return grad_input, None, None
return res, backward
def embedding(weight,
indices,
padding_idx: int,
scale_grad_by_freq: bool,
sparse: bool):
weight_size_0 = weight.size()[0]
def backward(grad_output):
grad_weight = torch.embedding_backward(grad_output, indices, weight_size_0, padding_idx, scale_grad_by_freq, sparse)
return grad_weight, None, None, None, None
return torch.embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse), backward
def log_softmax(self, dim: int, dtype: Optional[int]):
result = torch.log_softmax(self, dim, dtype)
def backward(grad_output):
grad_self = torch._log_softmax_backward_data(grad_output, result, dim, self.dtype)
return grad_self, None, None
return result, backward
def nll_loss(self, target, weight: Optional[Tensor], reduction: int, ignore_index: int):
result, total_weight = torch.nll_loss_forward(self, target, weight, reduction, ignore_index)
def backward(grad):
return torch.nll_loss_backward(grad, self, target, weight, reduction, ignore_index, total_weight), None, None, None, None
return result, backward
def softmax(self, dim: int, dtype: Optional[int]):
result = torch.softmax(self, dim, dtype)
def backward(grad_output):
grad_self = torch._softmax_backward_data(grad_output, result, dim, self.dtype)
return grad_self, None, None
return result, backward
)",
R"(
def AD_adaptive_avg_pool3d_backward(grad,
self,
output_size: List[int]):
if output_size[0] == 1 and output_size[1] == 1 and output_size[2] == 1:
self_size = self.size()
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2] * self_size[-3])
else:
grad_self = torch._adaptive_avg_pool3d_backward(grad, self)
return grad_self
def AD_adaptive_avg_pool2d_backward(grad,
self,
output_size: List[int]):
if output_size[0] == 1 and output_size[1] == 1:
self_size = self.size()
grad_self = grad.expand(self.size()) / (self_size[-1] * self_size[-2])
else:
grad_self = torch._adaptive_avg_pool2d_backward(grad, self)
return grad_self
def AD_adaptive_avg_pool1d_backward(grad,
input,
output_size: List[int]):
output_size_2d = [1, output_size[0]]
grad_input = AD_adaptive_avg_pool2d_backward(grad.unsqueeze(2), input.unsqueeze(2), output_size_2d).squeeze(2)
return grad_input
def adaptive_avg_pool1d(self,
output_size: List[int]):
def backward(grad_output):
grad_self = AD_adaptive_avg_pool1d_backward(grad_output, self, output_size)
return grad_self, None
return torch.adaptive_avg_pool1d(self, output_size), backward
def adaptive_avg_pool2d(self,
output_size: List[int]):
def backward(grad_output):
# self is used in backward, no need to pass in its size explicitly
grad_self = AD_adaptive_avg_pool2d_backward(grad_output, self, output_size)
return grad_self, None
return torch.adaptive_avg_pool2d(self, output_size), backward
def adaptive_avg_pool3d(self,
output_size: List[int]):
def backward(grad_output):
grad_self = AD_adaptive_avg_pool3d_backward(grad_output, self, output_size)
return grad_self, None
return torch.adaptive_avg_pool3d(self, output_size), backward
def avg_pool2d(self,
kernel_size: List[int],
stride: List[int],
padding: List[int],
ceil_mode: bool,
count_include_pad: bool,
divisor_override: Optional[int]):
def backward(grad_output):
grad_self = torch.avg_pool2d_backward(grad_output, self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override)
return grad_self, None, None, None, None, None, None
return torch.avg_pool2d(self, kernel_size, stride, padding, ceil_mode, count_include_pad, divisor_override), backward
def max_pool2d(self,
kernel_size: List[int],
stride: List[int],
padding: List[int],
dilation: List[int],
ceil_mode: bool):
output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
def backward(grad_output):
grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
return grad_self, None, None, None, None, None
return output, backward
def max_pool2d_with_indices(self,
kernel_size: List[int],
stride: List[int],
padding: List[int],
dilation: List[int],
ceil_mode: bool):
output, indices = torch.max_pool2d_with_indices(self, kernel_size, stride, padding, dilation, ceil_mode)
def backward(grad_output):
grad_self = torch.max_pool2d_with_indices_backward(grad_output, self, kernel_size, stride, padding, dilation, ceil_mode, indices)
return grad_self, None, None, None, None, None
return output, indices, backward
)",
R"(
def AD_sizes_if_not_equal_multi_1(t1, t2, res):
return torch._size_if_not_equal(t1.size(), res.size()), torch._size_if_not_equal(t2.size(), res.size())
def add_0(self,
other,
*,
alpha: number):
result = torch.add(self, other, alpha=alpha)
self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
def backward(grad_output):
grad_other = (grad_output * alpha)._grad_sum_to_size(other_size)
grad_self = (grad_output)._grad_sum_to_size(self_size)
return grad_self, grad_other, None
return result, backward
def add_1(self,
other: number,
alpha: number):
def backward(grad_output):
return grad_output, None, None
return torch.add(self, other, alpha=alpha), backward
def sub_0(self,
other,
*,
alpha: number):
result = torch.sub(self, other, alpha=alpha)
self_size, other_size = AD_sizes_if_not_equal_multi_1(self, other, result)
def backward(grad_output):
grad_other = (-grad_output * alpha)._grad_sum_to_size(other_size)
grad_self = (grad_output)._grad_sum_to_size(self_size)
return grad_self, grad_other, None
return result, backward
def sub_1(self,
other: number,
alpha: number):
def backward(grad_output):
return grad_output, None, None
return torch.sub(self, other, alpha=alpha), backward
def threshold(self,
threshold: number,
value: number):
def backward(grad_output):
mask = (self >= threshold).type_as(self)
return grad_output * mask, None, None
return torch.threshold(self, threshold, value), backward
def softplus(self,
beta: number,
threshold: number):
result = torch.softplus(self, beta, threshold)
def backward(grad_output):
z = torch.exp(result * beta)
return torch.where((result * beta) > threshold, grad_output, grad_output * (z - 1.) / z), None, None
return result, backward
def fmod(self,
other: number):
def backward(grad_output):
return grad_output, None
return torch.fmod(self, other), backward
def remainder(self,
other: number):
def backward(grad_output):
return grad_output, None
return torch.remainder(self, other), backward
def addmm(self,
mat1,
mat2,
*,
beta: number,
alpha: number):
result = torch.addmm(self, mat1, mat2, beta=beta, alpha=alpha)
self_size = torch._size_if_not_equal(self.size(), result.size())
def backward(grad_output):
self_grad = (grad_output * beta)._grad_sum_to_size(self_size)
mat1_grad = grad_output.mm(mat2.t()) * alpha
mat2_grad = mat1.t().mm(grad_output) * alpha
return self_grad, mat1_grad, mat2_grad, None, None
return result, backward
# Comparison operators
def lt(self, other: number):
def backward(grad_output):
return None, None
return torch.lt(self, other), backward
def le(self, other: number):
def backward(grad_output):
return None, None
return torch.le(self, other), backward
def gt(self, other: number):
def backward(grad_output):
return None, None
return torch.gt(self, other), backward
def ge(self, other: number):
def backward(grad_output):
return None, None
return torch.ge(self, other), backward
def eq(self, other: number):
def backward(grad_output):
return None, None
return torch.eq(self, other), backward
def ne(self, other: number):
def backward(grad_output):
return None, None
return torch.ne(self, other), backward
def hardshrink(self, lambd: number):
def backward(grad_output):
mask = ((self > lambd) | (self < -lambd))
return grad_output * mask, None
return torch.hardshrink(self, lambd=lambd), backward
def hardtanh(self, min_val: number, max_val: number):
def backward(grad_output):
mask = ((self >= min_val) * (self <= max_val))
return grad_output * mask, None, None
return torch.hardtanh(self, min_val=min_val, max_val=max_val), backward
def clamp_1(self,
min: Optional[number],
max: Optional[number]):
def backward(grad_output):
if min is not None and max is not None:
mask = ((self >= float(min)) * (self <= float(max))).type_as(self)
return grad_output * mask, None, None
elif min is not None:
mask = (self >= float(min)).type_as(self)
return grad_output * mask, None, None
elif max is not None:
mask = (self <= float(max)).type_as(self)
return grad_output * mask, None, None
else: #min is None and max is None
return grad_output, None, None
return torch.clamp(self, min=min, max=max), backward
def clamp_2(self,
min: Optional[Tensor],
max: Optional[Tensor]):
def backward(grad_output):
if min is not None and max is not None:
mask = ((self >= min) * (self <= max)).type_as(self)
return grad_output * mask, None, None
elif min is not None:
mask = (self >= min).type_as(self)
return grad_output * mask, None, None
elif max is not None:
mask = (self <= max).type_as(self)
return grad_output * mask, None, None
else: #min is None and max is None
return grad_output, None, None
return torch.clamp(self, min=min, max=max), backward
)"};
std::unordered_map<std::string, GradientPair> schema_to_graphs;
// This map is a workaround to cache compiled gradient pairs. Ideally each graph
// should be compiled only once and saved in the Operator structure.
// This should be done along with merging into native_functions.yaml.
std::unordered_map<const FunctionSchema*, GradientPair> cached_gradient_pairs;
// CompilationUnit that holds all these Functions and keeps them alive.
CompilationUnit compilation_unit;
} // anonymous namespace
std::pair<std::shared_ptr<Graph>, Value*> extractClosure(Value* closure) {
TORCH_CHECK(
closure->node()->kind() == prim::TupleConstruct,
"closure must be a literal tuple construct");
Value* fn = closure->node()->inputs().at(0);
Value* context = closure->node()->inputs().at(1);
TORCH_CHECK(
fn->node()->kind() == prim::Closure,
"closure tuple must contain a prim::Closure");
return std::make_pair(fn->node()->g(attr::Subgraph), context);
}
Argument originalReturnType(const TupleTypePtr& tup) {
TORCH_CHECK(tup->elements().size() > 1);
if (tup->elements().size() == 2)
return Argument("", tup->elements().at(0));
std::vector<TypePtr> types = tup->elements().vec();
types.pop_back();
return Argument("", TupleType::create(std::move(types)));
}
// In torchscript AD formulas, we define {func_0, func_1, ...} as
// overloaded functions of `func`.
// Remove the suffix before adding the schema string to the map
// schema_to_graphs.
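// For example, the graphs defined above for `mul_0` and `mul_1` are both
// registered under the canonical schema string of `aten::mul` (the numeric
// suffix is stripped by overloadedSchemaString below).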
std::string overloadedSchemaString(const FunctionSchema& schema) {
const auto& schema_name = schema.name();
auto pos = schema_name.find_last_of('_');
auto schema_name_suffix = schema_name.substr(pos + 1);
std::string schema_string = canonicalSchemaString(schema);
if (!schema_name_suffix.empty() &&
schema_name_suffix.find_first_not_of("0123456789") == std::string::npos) {
schema_string.replace(
schema_string.find(schema_name),
schema_name.length(),
schema_name.substr(0, pos));
}
return schema_string;
}
bool isHelperFunction(const std::string& method_name) {
std::string helper_prefix = "AD_";
return method_name.compare(0, helper_prefix.length(), helper_prefix) == 0;
}
void loadModule(const CompilationUnit& module) {
for (const auto& method : module.get_functions()) {
if (isHelperFunction(method->name()))
continue;
GradientPair pair;
pair.forward = toGraphFunction(*method).graph();
// lookup the backward function
Node* forward_tuple = pair.forward->outputs().at(0)->node();
if (forward_tuple->kind() != prim::TupleConstruct) {
throw ErrorReport(forward_tuple->sourceRange())
<< "gradient must return literal a tuple";
}
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
Value* context;
std::tie(pair.backward, context) =
extractClosure(forward_tuple->inputs().back());
// do surgery on the forward function to remove the closure tuple and
// replace it with the context variable:
// backward = (<lambda>, context_tuple)
// return original, backward
// -----
// return original, context_tuple
std::vector<Value*> new_inputs = forward_tuple->inputs().vec();
new_inputs.back() = context;
Value* new_tuple =
pair.forward->appendNode(pair.forward->createTuple(new_inputs))
->output();
pair.forward->eraseOutput(0);
pair.forward->registerOutput(new_tuple);
forward_tuple->destroy();
// derive schema from original function's schema:
const FunctionSchema& loaded_schema = method->getSchema();
FunctionSchema actual_schema(
Symbol::aten(loaded_schema.name()),
loaded_schema.overload_name(),
loaded_schema.arguments(),
{originalReturnType(new_tuple->type()->expect<TupleType>())});
// modify canonical string for function overloading
// prefer not to modify the schema name
auto schema_string = overloadedSchemaString(actual_schema);
schema_to_graphs[schema_string] = std::move(pair);
}
}
void loadFunctions() {
for (const std::string& str : functions) {
compilation_unit.define(c10::nullopt, str, nativeResolver(), nullptr);
}
loadModule(compilation_unit);
}
c10::optional<GradientPair> gradientInfoForSchema(
const FunctionSchema& schema) {
std::lock_guard<std::mutex> guard(lock);
if (schema_to_graphs.size() == 0) {
loadFunctions();
}
auto cache_it = cached_gradient_pairs.find(&schema);
if (cache_it != cached_gradient_pairs.end()) {
return cache_it->second;
} else {
auto schema_str = canonicalSchemaString(schema);
// For debugging AD change:
// std::cout << "Looking for " << schema_str << std::endl;
auto sym_script_it = schema_to_graphs.find(schema_str);
if (sym_script_it != schema_to_graphs.end()) {
cached_gradient_pairs.emplace_hint(
cache_it, &schema, sym_script_it->second);
return sym_script_it->second;
}
}
return c10::nullopt;
}
bool hasGradientInfoForSchema(const FunctionSchema& schema) {
return gradientInfoForSchema(schema).has_value();
}
} // namespace jit
} // namespace torch