annotate torch.autograd.* modules (#45004)

Summary:
Fixes https://github.com/pytorch/pytorch/issues/44638

Pull Request resolved: https://github.com/pytorch/pytorch/pull/45004

Reviewed By: VitalyFedyunin

Differential Revision: D24113562

Pulled By: ezyang

fbshipit-source-id: a85018b7e08b2fe6cf2bc14a217eb418cb2b9de4
Guilherme Leobas · 2020-10-07 10:50:50 -07:00 · committed by Facebook GitHub Bot
parent 83d2c9a232
commit 9679e1affc
9 changed files with 143 additions and 115 deletions

mypy.ini

@@ -180,27 +180,6 @@ ignore_errors = True
[mypy-torch.utils.hipify.hipify_python]
ignore_errors = True
[mypy-torch.autograd._functions.tensor]
ignore_errors = True
[mypy-torch.autograd.function]
ignore_errors = True
[mypy-torch.autograd.functional]
ignore_errors = True
[mypy-torch.autograd.profiler]
ignore_errors = True
[mypy-torch.autograd.gradcheck]
ignore_errors = True
[mypy-torch.autograd.anomaly_mode]
ignore_errors = True
[mypy-torch.autograd.variable]
ignore_errors = True
[mypy-torch.nn.quantized.modules.batchnorm]
ignore_errors = True
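
With these sections gone, mypy actually type-checks the torch.autograd modules instead of suppressing every error in them; the annotations and `# type: ignore` comments in the rest of this diff are what it took to get them clean. A minimal sketch (hypothetical module, not from this commit) of the kind of error that `ignore_errors = True` used to hide:

# hypothetical_module.py -- illustrative only
def average(total: int, count: int) -> float:
    return total / count     # OK

def label(count: int) -> int:
    return "n=%d" % count    # error: Incompatible return value type (got "str", expected "int")

Under a [mypy-hypothetical_module] section with ignore_errors = True, mypy stays silent about the second function; deleting the section, as above, surfaces the error again.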

torch/_C/__init__.pyi.in

@@ -93,6 +93,7 @@ def DisableTorchFunction(): ...
# Defined in torch/csrc/utils/tensor_layouts.cpp
strided : layout = ...
sparse_coo : layout = ...
_mkldnn : layout = ...
# Defined in torch/csrc/MemoryFormat.cpp
class memory_format: ...
@@ -268,6 +269,10 @@ def import_ir_module_from_buffer(
class Graph:
...
# Defined in torch/csrc/jit/ir/ir.h
class Value:
...
# Defined in torch/aten/src/ATen/core/function_schema.h
class FunctionSchema:
...
@@ -389,6 +394,7 @@ def _is_xnnpack_enabled() -> _bool: ... # THPModule_isEnabledXNNPACK
def _vmapmode_increment_nesting() -> _int: ... # THPModule_vmapmode_increment_nesting
def _vmapmode_decrement_nesting() -> _int: ... # THPModule_vmapmode_decrement_nesting
def _log_api_usage_once(str) -> None: ... # LogAPIUsageOnceFromPython
def _demangle(str) -> str: ... # c10::demangle
# Defined in `valgrind.h` and `callgrind.h` respectively.
def valgrind_supported_platform() -> _bool: ... # NVALGRIND
@@ -497,6 +503,10 @@ ${legacy_storage_base_hints}
# TODO: where
${legacy_class_hints}
# Defined in torch/csrc/autograd/python_engine.cpp
class _ImperativeEngine:
...
# Defined in torch/csrc/autograd/python_variable.cpp
class _TensorBase(object):
requires_grad: _bool

torch/_C/_autograd.pyi

@@ -11,10 +11,14 @@ class ProfilerState(Enum):
class ProfilerConfig:
def __init__(self, state: ProfilerState, report_input_shapes: bool, profile_memory: bool) -> None: ...
def __init__(
self, state: ProfilerState,
report_input_shapes: bool,
profile_memory: bool,
with_stack: bool
) -> None: ...
...
class ProfilerEvent:
def cpu_elapsed_us(self, other: ProfilerEvent) -> float: ...
def cpu_memory_usage(self) -> int: ...
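
The ProfilerConfig stub has to mirror the C++ pybind signature, which gained a with_stack flag; otherwise mypy would reject call sites in torch/autograd/profiler.py that now pass four arguments. A sketch of a call that type-checks against the updated stub, assuming the enum members of this era (Disabled, CPU, CUDA, NVTX):

import torch

cfg = torch._C._autograd.ProfilerConfig(
    torch._C._autograd.ProfilerState.CPU,  # state
    False,                                 # report_input_shapes
    False,                                 # profile_memory
    False,                                 # with_stack -- the newly added parameter
)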

torch/_C/_functions.pyi (new file, 12 additions)

@@ -0,0 +1,12 @@
from torch import Tensor
from typing import AnyStr, List
class UndefinedGrad:
def __init__(self) -> None: ...
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
...
class DelayedError:
def __init__(self, msg: AnyStr, num_inputs: int) -> None: ...
def __call__(self, inputs: List[Tensor]) -> List[Tensor]: ...
...
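
torch._C._functions is a C extension module, so mypy cannot look inside it; this stub declares just the two classes that torch/autograd/function.py uses. A rough sketch of what the checker can now verify at a call site, going purely by the stub's declared signatures (the tensor is only for illustration):

import torch
from torch._C import _functions

t = torch.ones(2, requires_grad=True)
err_fn = _functions.DelayedError(b"differentiated twice", 1)  # (AnyStr, int) per the stub
outs = err_fn([t])  # the stub's __call__ gives List[Tensor] -> List[Tensor]

A call like _functions.DelayedError(b"msg") would now be flagged for the missing num_inputs argument instead of passing silently.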

torch/autograd/function.py

@@ -1,11 +1,12 @@
import torch
import torch._C as _C
from torch._C import _functions
import torch.utils.hooks as hooks
from torch._six import with_metaclass
import functools
import warnings
from collections import OrderedDict
from typing import Any
from typing import Any, List, Optional
class _ContextMethodMixin(object):
@@ -84,7 +85,8 @@ class BackwardCFunction(_C._FunctionBase, _ContextMethodMixin, _HookMixin):
_is_legacy = False
def apply(self, *args):
return self._forward_cls.backward(self, *args)
# _forward_cls is defined by derived class
return self._forward_cls.backward(self, *args) # type: ignore
class FunctionMeta(type):
@@ -115,8 +117,8 @@ class FunctionMeta(type):
return super(FunctionMeta, cls).__init__(name, bases, attrs)
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)):
# mypy doesn't understand `with_metaclass` from torch._six
class Function(with_metaclass(FunctionMeta, _C._FunctionBase, _ContextMethodMixin, _HookMixin)): # type: ignore
r"""Records operation history and defines formulas for differentiating ops.
See the Note on extending the autograd engine for more details on how to use
@@ -227,7 +229,7 @@ def once_differentiable(fn):
if not isinstance(outputs, tuple):
outputs = (outputs,)
err_fn = torch._C._functions.DelayedError(
err_fn = _functions.DelayedError(
b"trying to differentiate twice a function that was marked"
b"with @once_differentiable", len(outputs))
@@ -330,7 +332,7 @@ def _unflatten(input, proto):
# unflatten a list or tuple input into a nested list/tuple structure
# specified by proto
def unflatten_helper(input, proto):
res = []
res: List[Optional[torch.Tensor]] = []
if hasattr(proto, "_jit_wrap"):
return proto._jit_wrap(input)
if not isinstance(proto, (list, tuple)):
@@ -379,16 +381,16 @@ class NestedIOFunction(Function):
del self._to_save_nested
return result
def backward(self, *gradients: Any) -> Any:
def backward(self, *gradients: Any) -> Any: # type: ignore
nested_gradients = _unflatten(gradients, self._nested_output)
result = self.backward_extended(*nested_gradients)
result = self.backward_extended(*nested_gradients) # type: ignore
return tuple(_iter_None_tensors(result))
__call__ = _do_forward
def forward(self, *args: Any) -> Any:
def forward(self, *args: Any) -> Any: # type: ignore
nested_tensors = _map_tensor_data(self._nested_input)
result = self.forward_extended(*nested_tensors)
result = self.forward_extended(*nested_tensors) # type: ignore
del self._nested_input
self._nested_output = result
return tuple(_iter_tensors(result))
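
Most of the ignores in this file trace back to torch._six.with_metaclass, which builds the base class at runtime (six's recipe), so mypy cannot tell what Function and NestedIOFunction really inherit from. Roughly, simplified from six:

def with_metaclass(meta, *bases):
    # a throwaway class whose metaclass re-creates the subclass
    # with the real metaclass and the real bases
    class metaclass(type):
        def __new__(cls, name, this_bases, d):
            return meta(name, bases, d)
    return type.__new__(metaclass, 'temporary_class', (), {})

class Meta(type): ...
class Base: ...

class C(with_metaclass(Meta, Base)):  # mypy: Unsupported dynamic base class "with_metaclass"
    pass

With the base opaque to the checker, overrides such as NestedIOFunction.forward/backward also need ignores where their instance-method signatures differ from the staticmethod Function API that mypy does manage to see.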

torch/autograd/functional.py

@ -1,4 +1,5 @@
import torch
from typing import Tuple, List
# Utility functions
@@ -131,8 +132,8 @@ def _autograd_grad(outputs, inputs, grad_outputs=None, create_graph=False, retai
assert isinstance(grad_outputs, tuple)
assert len(outputs) == len(grad_outputs)
new_outputs = tuple()
new_grad_outputs = tuple()
new_outputs: Tuple[torch.Tensor, ...] = tuple()
new_grad_outputs: Tuple[torch.Tensor, ...] = tuple()
for out, grad_out in zip(outputs, grad_outputs):
if out is not None and out.requires_grad:
new_outputs += (out,)
@@ -153,7 +154,7 @@ def _fill_in_zeros(grads, refs, strict, create_graph, stage):
if stage not in ["back", "back_trick", "double_back", "double_back_trick"]:
raise RuntimeError("Invalid stage argument '{}' to _fill_in_zeros".format(stage))
res = tuple()
res: Tuple[torch.Tensor, ...] = tuple()
for i, grads_i in enumerate(grads):
if grads_i is None:
if strict:
@@ -427,10 +428,11 @@ def jacobian(func, inputs, create_graph=False, strict=False):
"jacobian")
_check_requires_grad(outputs, "outputs", strict=strict)
jacobian = tuple()
jacobian: Tuple[torch.Tensor, ...] = tuple()
for i, out in enumerate(outputs):
jac_i = tuple([] for _ in range(len(inputs)))
# mypy complains that expression and variable have different types due to the empty list
jac_i: Tuple[List[torch.Tensor]] = tuple([] for _ in range(len(inputs))) # type: ignore
for j in range(out.nelement()):
vj = _autograd_grad((out.reshape(-1)[j],), inputs,
retain_graph=True, create_graph=create_graph)
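
The recurring fix in this file: mypy pins a variable to the type of its initializer, and a bare tuple() yields an empty-tuple type that later += of tensors cannot satisfy. A minimal sketch of the failure and of the annotation adopted above (hypothetical helper names):

from typing import Tuple
import torch

def collect(ts):
    out = tuple()          # mypy pins 'out' to the empty tuple's inferred type
    for t in ts:
        out += (t,)        # error: incompatible types in assignment
    return out

def collect_annotated(ts) -> Tuple[torch.Tensor, ...]:
    out: Tuple[torch.Tensor, ...] = tuple()  # the pattern used in this diff
    for t in ts:
        out += (t,)        # OK: stays Tuple[Tensor, ...]
    return out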

torch/autograd/gradcheck.py

@@ -5,7 +5,7 @@ import torch.testing
from torch.overrides import is_tensor_like
from itertools import product
import warnings
from typing import Callable, Union, Optional
from typing import Callable, Union, Optional, Iterable, List
def zero_gradients(x):
if isinstance(x, torch.Tensor):
@@ -29,15 +29,16 @@ def make_jacobian(input, num_out):
lambda x: x is not None, (make_jacobian(elem, num_out) for elem in input)))
if not jacobians:
return None
return type(input)(jacobians)
return type(input)(jacobians) # type: ignore
else:
return None
def iter_tensors(x, only_requiring_grad=False):
def iter_tensors(x: Union[torch.Tensor, Iterable[torch.Tensor]], only_requiring_grad: bool = False) -> Iterable[torch.Tensor]:
if is_tensor_like(x):
if x.requires_grad or not only_requiring_grad:
yield x
# mypy doesn't narrow type of `x` to torch.Tensor
if x.requires_grad or not only_requiring_grad: # type: ignore
yield x # type: ignore
elif isinstance(x, container_abcs.Iterable) and not isinstance(x, str):
for elem in x:
for result in iter_tensors(elem, only_requiring_grad):
@@ -137,7 +138,7 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3, grad_out=1.0):
indices = x_indices[i].tolist() + list(x_idx)
d_idx = sum(indices[k] * x_stride[k] for k in range(len(x_size)))
update_jacobians(x_value, x_idx, d_tensor, d_idx)
elif x_tensor.layout == torch._mkldnn:
elif x_tensor.layout == torch._mkldnn: # type: ignore
# Use .data here to get around the version check
x_tensor = x_tensor.data
if len(input) != 1:
@@ -163,7 +164,7 @@ def get_analytical_jacobian(input, output, nondet_tol=0.0, grad_out=1.0):
if output.is_sparse:
raise ValueError('Sparse output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
if output.layout == torch._mkldnn:
if output.layout == torch._mkldnn: # type: ignore
raise ValueError('MKLDNN output is not supported at gradcheck yet. '
'Please call to_dense() on the output of fn for gradcheck.')
diff_input_list = list(iter_tensors(input, True))
@@ -303,13 +304,13 @@ def gradcheck(
content = inp._values() if inp.is_sparse else inp
# TODO: To cover more problematic cases, replace stride = 0 check with
# "any overlap in memory" once we have a proper function to check it.
if content.layout is not torch._mkldnn and \
not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
raise RuntimeError(
'The {}th input has a dimension with stride 0. gradcheck only '
'supports inputs that are non-overlapping to be able to '
'compute the numerical gradients correctly. You should call '
'.contiguous on the input before passing it to gradcheck.')
if content.layout is not torch._mkldnn: # type: ignore
if not all(st > 0 or sz <= 1 for st, sz in zip(content.stride(), content.size())):
raise RuntimeError(
'The {}th input has a dimension with stride 0. gradcheck only '
'supports inputs that are non-overlapping to be able to '
'compute the numerical gradients correctly. You should call '
'.contiguous on the input before passing it to gradcheck.')
any_input_requiring_grad = True
inp.retain_grad()
if not any_input_requiring_grad:
@@ -403,30 +404,30 @@ def gradcheck(
# check if the backward multiplies by grad_output
output = _differentiable_outputs(func(*tupled_inputs))
if any([o.requires_grad for o in output]):
diff_input_list = list(iter_tensors(tupled_inputs, True))
diff_input_list: List[torch.Tensor] = list(iter_tensors(tupled_inputs, True))
if not diff_input_list:
raise RuntimeError("no Tensors requiring grad found in input")
grads_input = torch.autograd.grad(output, diff_input_list,
[torch.zeros_like(o, memory_format=torch.legacy_contiguous_format) for o in output],
allow_unused=True)
for gi, i in zip(grads_input, diff_input_list):
for gi, di in zip(grads_input, diff_input_list):
if gi is None:
continue
if isinstance(gi, torch.Tensor) and gi.layout != torch.strided:
if gi.layout != i.layout:
return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(i.layout) + ')')
if gi.layout != di.layout:
return fail_test('grad is incorrect layout (' + str(gi.layout) + ' is not ' + str(di.layout) + ')')
if gi.layout == torch.sparse_coo:
if gi.sparse_dim() != i.sparse_dim():
if gi.sparse_dim() != di.sparse_dim():
return fail_test('grad is sparse tensor, but has incorrect sparse_dim')
if gi.dense_dim() != i.dense_dim():
if gi.dense_dim() != di.dense_dim():
return fail_test('grad is sparse tensor, but has incorrect dense_dim')
gi = gi.to_dense()
i = i.to_dense()
di = di.to_dense()
if not gi.eq(0).all():
return fail_test('backward not multiplied by grad_output')
if gi.dtype != i.dtype or gi.device != i.device or gi.is_sparse != i.is_sparse:
if gi.dtype != di.dtype or gi.device != di.device or gi.is_sparse != di.is_sparse:
return fail_test("grad is incorrect type")
if gi.size() != i.size():
if gi.size() != di.size():
return fail_test('grad is incorrect size')
if check_undefined_grad:
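
The ignores around is_tensor_like exist because mypy only narrows through isinstance/issubclass and a few literal checks; an ordinary function returning bool does not narrow its argument. A minimal sketch of the difference:

import torch
from typing import Iterable, Union
from torch.overrides import is_tensor_like

def f(x: Union[torch.Tensor, Iterable[torch.Tensor]]) -> None:
    if isinstance(x, torch.Tensor):
        print(x.requires_grad)   # OK: isinstance narrows x to Tensor
    if is_tensor_like(x):
        print(x.requires_grad)   # error: "Iterable[Tensor]" has no attribute "requires_grad"

PEP 647's TypeGuard later made such predicates narrowing-capable, but it postdates this commit, hence the `# type: ignore` comments here.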

torch/autograd/profiler.py

@@ -6,6 +6,8 @@ from torch.futures import Future
from collections import defaultdict, namedtuple
from operator import attrgetter
from typing import List, Dict, Tuple, Optional
try:
# Available in Python >= 3.2
from contextlib import ContextDecorator
@@ -13,6 +15,13 @@ except ImportError:
import functools
class ContextDecorator(object): # type: ignore[no-redef]
def __enter__(self):
raise NotImplementedError
def __exit__(self, exc_type, exc_val, exc_tb):
raise NotImplementedError
def __call__(self, func):
@functools.wraps(func)
def wrapped(*args, **kwargs):
@@ -78,13 +87,13 @@ class EventList(list):
# Algorithm has O(N * log(N)) complexity where N is number of
# intervals
for thread_id, thread_events in threads:
thread_events = sorted(
thread_events_ = sorted(
thread_events,
key=lambda event: [event.cpu_interval.start, -event.cpu_interval.end],
)
current_events = []
current_events: List[FunctionEvent] = []
cur_end = 0
for event in thread_events:
for event in thread_events_:
while len(current_events) > 0:
parent = current_events[-1]
if event.cpu_interval.start >= parent.cpu_interval.end or \
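
For context, populate_cpu_children nests events by interval containment: sort by (start, -end) so enclosing events come first, then keep a stack of still-open intervals. A minimal sketch of the same idea on bare (start, end) tuples, assuming the intervals nest properly (the real code also guards against partial overlap):

from typing import List, Optional, Tuple

def assign_parents(intervals: List[Tuple[int, int]]) -> List[Optional[int]]:
    order = sorted(range(len(intervals)),
                   key=lambda i: (intervals[i][0], -intervals[i][1]))
    stack: List[int] = []
    parent: List[Optional[int]] = [None] * len(intervals)
    for i in order:
        start, _end = intervals[i]
        while stack and intervals[stack[-1]][1] <= start:
            stack.pop()              # that interval closed before this one starts
        if stack:
            parent[i] = stack[-1]    # innermost still-open interval contains this one
        stack.append(i)
    return parent

For example, assign_parents([(0, 10), (1, 3), (4, 6)]) returns [None, 0, 0].
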
@@ -253,7 +262,7 @@ class EventList(list):
An EventList containing FunctionEventAvg objects.
"""
self.populate_cpu_children()
stats = defaultdict(FunctionEventAvg)
stats: Dict[Tuple[int, Tuple[int, int]], FunctionEventAvg] = defaultdict(FunctionEventAvg)
def get_key(event, group_by_input_shapes, group_by_stack_n):
key = [str(event.key), str(event.node_id)]
@@ -413,6 +422,7 @@ class profile(object):
def table(self, sort_by=None, row_limit=100, header=None, top_level_events_only=False):
self._check_finish()
assert self.function_events is not None
return self.function_events.table(
sort_by=sort_by, row_limit=row_limit, header=header,
top_level_events_only=top_level_events_only
@@ -421,16 +431,19 @@ class profile(object):
def export_chrome_trace(self, path):
self._check_finish()
assert self.function_events is not None
return self.function_events.export_chrome_trace(path)
export_chrome_trace.__doc__ = EventList.export_chrome_trace.__doc__
def key_averages(self, group_by_input_shape=False, group_by_stack_n=0):
self._check_finish()
assert self.function_events is not None
return self.function_events.key_averages(group_by_input_shape, group_by_stack_n)
key_averages.__doc__ = EventList.key_averages.__doc__
def total_average(self):
self._check_finish()
assert self.function_events is not None
return self.function_events.total_average()
total_average.__doc__ = EventList.total_average.__doc__
@@ -440,6 +453,7 @@ class profile(object):
all self times across all the events.
"""
self._check_finish()
assert self.function_events is not None
return self.function_events.self_cpu_time_total
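
self.function_events is Optional (it stays None until the profiler context exits), so each accessor asserts before use: after `assert x is not None`, mypy narrows Optional[T] to T for the rest of the block. A minimal sketch with hypothetical names:

from typing import List, Optional

class Prof:
    def __init__(self) -> None:
        self.events: Optional[List[str]] = None  # populated on exit, as in profile

    def table(self) -> str:
        assert self.events is not None   # narrows Optional[List[str]] to List[str]
        return ", ".join(self.events)    # without the assert, mypy flags the Optional argument
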
@@ -694,11 +708,11 @@ class FormattedTimesMixin(object):
@property
def cpu_time(self):
return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count
return 0.0 if self.count == 0 else 1.0 * self.cpu_time_total / self.count # type: ignore
@property
def cuda_time(self):
return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count
return 0.0 if self.count == 0 else 1.0 * self.cuda_time_total / self.count # type: ignore
class Interval(object):
@@ -719,24 +733,24 @@ class FunctionEvent(FormattedTimesMixin):
self, id, node_id, name, thread, cpu_start, cpu_end, fwd_thread=None, input_shapes=None,
stack=None, scope=0, cpu_memory_usage=0, cuda_memory_usage=0, is_async=False,
is_remote=True, sequence_nr=-1):
self.id = id
self.node_id = node_id
self.name = name
self.cpu_interval = Interval(cpu_start, cpu_end)
self.thread = thread
self.fwd_thread = fwd_thread
self.kernels = []
self.count = 1
self.cpu_children = []
self.cpu_parent = None
self.input_shapes = input_shapes
self.stack = stack
self.scope = scope
self.cpu_memory_usage = cpu_memory_usage
self.cuda_memory_usage = cuda_memory_usage
self.is_async = is_async
self.is_remote = is_remote
self.sequence_nr = sequence_nr
self.id: int = id
self.node_id: int = node_id
self.name: str = name
self.cpu_interval: Interval = Interval(cpu_start, cpu_end)
self.thread: int = thread
self.fwd_thread: Optional[int] = fwd_thread
self.kernels: List[Kernel] = []
self.count: int = 1
self.cpu_children: List[FunctionEvent] = []
self.cpu_parent: Optional[FunctionEvent] = None
self.input_shapes: Tuple[int, ...] = input_shapes
self.stack: List = stack
self.scope: int = scope
self.cpu_memory_usage: int = cpu_memory_usage
self.cuda_memory_usage: int = cuda_memory_usage
self.is_async: bool = is_async
self.is_remote: bool = is_remote
self.sequence_nr: int = sequence_nr
def append_kernel(self, name, device, start, end):
self.kernels.append(Kernel(name, device, Interval(start, end)))
@@ -830,24 +844,24 @@ class FunctionEvent(FormattedTimesMixin):
class FunctionEventAvg(FormattedTimesMixin):
"""Used to average stats over multiple FunctionEvent objects."""
def __init__(self):
self.key = None
self.count = 0
self.node_id = 0
self.is_async = False
self.is_remote = False
self.cpu_time_total = 0
self.cuda_time_total = 0
self.self_cpu_time_total = 0
self.self_cuda_time_total = 0
self.input_shapes = None
self.stack = None
self.scope = None
self.cpu_memory_usage = 0
self.cuda_memory_usage = 0
self.self_cpu_memory_usage = 0
self.self_cuda_memory_usage = 0
self.cpu_children = None
self.cpu_parent = None
self.key: Optional[str] = None
self.count: int = 0
self.node_id: int = 0
self.is_async: bool = False
self.is_remote: bool = False
self.cpu_time_total: int = 0
self.cuda_time_total: int = 0
self.self_cpu_time_total: int = 0
self.self_cuda_time_total: int = 0
self.input_shapes: Optional[List[List[int]]] = None
self.stack: Optional[List] = None
self.scope: Optional[int] = None
self.cpu_memory_usage: int = 0
self.cuda_memory_usage: int = 0
self.self_cpu_memory_usage: int = 0
self.self_cuda_memory_usage: int = 0
self.cpu_children: Optional[List[FunctionEvent]] = None
self.cpu_parent: Optional[FunctionEvent] = None
def add(self, other):
if self.key is None:
@@ -950,6 +964,7 @@ def parse_event_records(thread_records):
# and the CPU time of the cuda start event for the device
def adjusted_time(cuda_record, cuda_records_map):
assert cuda_record.device() != -1
assert start_record is not None
cuda_time_0 = cuda_records_map[(cuda_record.node_id(), cuda_record.device())]
return cuda_time_0.cuda_elapsed_us(cuda_record) + start_record.cpu_elapsed_us(cuda_time_0)
@@ -1102,6 +1117,8 @@ def parse_nvprof_trace(path):
for row in conn.execute(marker_query):
unique.see(row['marker_id'])
evt = FunctionEvent(id=row['marker_id'],
node_id=0, # no node_id is available here; pass 0 so FunctionEvent construction doesn't crash
name=strings[row['name']],
cpu_start=row['start_time'],
cpu_end=row['end_time'],
@@ -1215,15 +1232,15 @@ def build_table(
# Have to use a list because nonlocal is Py3 only...
SPACING_SIZE = 2
row_format = [""]
header_sep = [""]
line_length = [-SPACING_SIZE]
row_format_lst = [""]
header_sep_lst = [""]
line_length_lst = [-SPACING_SIZE]
MAX_STACK_ENTRY = 5
def add_column(padding, text_dir='>'):
row_format[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE)
header_sep[0] += '-' * padding + (' ' * SPACING_SIZE)
line_length[0] += padding + SPACING_SIZE
row_format_lst[0] += '{: ' + text_dir + str(padding) + '}' + (' ' * SPACING_SIZE)
header_sep_lst[0] += '-' * padding + (' ' * SPACING_SIZE)
line_length_lst[0] += padding + SPACING_SIZE
add_column(name_column_width)
for _ in headers[1:]:
@@ -1237,10 +1254,10 @@ def build_table(
headers.append('Source Location')
add_column(src_column_width, text_dir='<')
row_format = row_format[0]
header_sep = header_sep[0]
line_length = line_length[0]
add_column = None
row_format = row_format_lst[0]
header_sep = header_sep_lst[0]
line_length = line_length_lst[0]
add_column = None # type: ignore
# Have to use a list because nonlocal is Py3 only...
result = []
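
The _lst renames here (and thread_events_ earlier) exist because mypy gives each variable a single type for its whole scope: re-binding row_format from the one-element list used as a pre-nonlocal closure cell to the final string is an incompatible assignment. A minimal sketch:

def build() -> str:
    row_format = [""]               # List[str], mutated from a nested helper
    row_format[0] += "{:>10}"
    row_format = row_format[0]      # error: incompatible types ("str" vs "List[str]")
    return row_format

def build_fixed() -> str:
    row_format_lst = [""]           # the list keeps its own name
    row_format_lst[0] += "{:>10}"
    row_format = row_format_lst[0]  # OK: a fresh variable of type str
    return row_format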

torch/autograd/variable.py

@@ -7,9 +7,10 @@ class VariableMeta(type):
return isinstance(other, torch.Tensor)
class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)):
# mypy doesn't understand torch._six.with_metaclass
class Variable(with_metaclass(VariableMeta, torch._C._LegacyVariableBase)): # type: ignore
pass
from torch._C import _ImperativeEngine as ImperativeEngine
Variable._execution_engine = ImperativeEngine()
Variable._execution_engine = ImperativeEngine() # type: ignore
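
The last ignore covers an attribute assignment: mypy rejects setting a name on a class that never declares it, and Variable's dynamic with_metaclass base leaves nowhere clean to declare _execution_engine. The general rule, in a minimal sketch:

class A:
    pass

A.engine = object()   # error: "Type[A]" has no attribute "engine"

class B:
    engine: object    # declared in the class body

B.engine = object()   # OK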