Commit Graph

41 Commits

Author SHA1 Message Date
PyTorch MergeBot
24bdd03d23 Revert "Reify view_func() closures as ViewFuncs (#118404)"
This reverts commit d5a6762263.

Reverted https://github.com/pytorch/pytorch/pull/118404 on behalf of https://github.com/DanilBaibak due to Broken trunk ([comment](https://github.com/pytorch/pytorch/pull/118404#issuecomment-1938600260))
2024-02-12 12:38:51 +00:00
Joel Schlosser
d5a6762263 Reify view_func() closures as ViewFuncs (#118404)
Replaces `view_func()` closures with a reified `ViewFunc` data structure. Codegen generates a `ViewFunc` subclass for each view op (e.g. `NarrowViewFunc`) containing state needed to reconstruct the view. The `ViewFunc` API allows for querying and hot-swapping any `SymInt`s or `Tensors` in the state through `get_symints()` / `get_tensors()` / `clone_and_set()`, which will be essential for fake-ification later on.

```cpp
/// Base class for view functions, providing reapplication of a view on a new base.
/// Each view op should get a codegenerated subclass of this class containing
/// any state needed to reconstruct the view. The class also provides convenience
/// accessors for saved SymInts / tensor state. This is useful for e.g. fake-ification,
/// where we want to use symbolic values or fake tensors instead.
struct TORCH_API ViewFunc {
  virtual ~ViewFunc() {}
  /// Returns any SymInts in the saved state.
  virtual std::vector<c10::SymInt> get_symints() const { return {}; }
  /// Returns the number of SymInts in the saved state.
  virtual size_t num_symints() const { return 0; }
  /// Returns any tensors in the saved state.
  virtual std::vector<at::Tensor> get_tensors() const { return {}; }
  /// Returns the number of tensors in the saved state.
  virtual size_t num_tensors() const { return 0; }
  /// Reapplies the view on the given base using the saved state.
  virtual at::Tensor operator()(const at::Tensor&) const = 0;
  /// Returns a clone of this ViewFunc, optionally with the specified saved state.
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const = 0;

protected:
  /// Sets the values of any SymInts in the saved state. The input vector size must
  /// match the number of SymInts in the saved state (i.e. the size of the list
  /// returned by get_symints()).
  virtual void set_symints(std::vector<c10::SymInt>) {}
  /// Sets the values of any Tensors in the saved state. The input vector size must
  /// match the number of Tensors in the saved state (i.e. the size of the list
  /// returned by get_tensors()).
  virtual void set_tensors(std::vector<at::Tensor>) {}
};
```

New codegen files:
* `torch/csrc/autograd/generated/ViewFunc.h`
* `torch/csrc/autograd/generated/ViewFuncs.cpp`

The templates for these also contains impls for `ChainedViewFunc` and `ErroringViewFunc` which are used in a few places within autograd.

Example codegen for `slice.Tensor`:
```cpp
// torch/csrc/autograd/generated/ViewFuncs.h
#define SLICE_TENSOR_VIEW_FUNC_AVAILABLE
struct SliceTensorViewFunc : public torch::autograd::ViewFunc {
  SliceTensorViewFunc(int64_t dim, c10::optional<c10::SymInt> start, c10::optional<c10::SymInt> end, c10::SymInt step) : dim(dim), start(start), end(end), step(step)
  {};
  virtual ~SliceTensorViewFunc() override {};
  virtual std::vector<c10::SymInt> get_symints() const override;
  virtual size_t num_symints() const override;
  virtual std::vector<at::Tensor> get_tensors() const override;
  virtual size_t num_tensors() const override;
  virtual at::Tensor operator()(const at::Tensor&) const override;
  virtual std::unique_ptr<ViewFunc> clone_and_set(
      std::optional<std::vector<c10::SymInt>> = c10::nullopt,
      std::optional<std::vector<at::Tensor>> = c10::nullopt) const override;

protected:
  virtual void set_symints(std::vector<c10::SymInt>) override;
  virtual void set_tensors(std::vector<at::Tensor>) override;

private:
  int64_t dim;
  c10::optional<c10::SymInt> start;
  c10::optional<c10::SymInt> end;
  c10::SymInt step;
};
...

// torch/csrc/autograd/generated/ViewFuncs.cpp
std::vector<c10::SymInt> SliceTensorViewFunc::get_symints() const {
  ::std::vector<c10::SymInt> symints;
  symints.reserve((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
  if(start.has_value()) symints.insert(symints.end(), *(start));
  if(end.has_value()) symints.insert(symints.end(), *(end));
  symints.push_back(step);
  return symints;
}

size_t SliceTensorViewFunc::num_symints() const {
  return static_cast<size_t>((start.has_value() ? 1 : 0) + (end.has_value() ? 1 : 0) + 1);
}

void SliceTensorViewFunc::set_symints(std::vector<c10::SymInt> symints) {
  TORCH_INTERNAL_ASSERT(symints.size() == num_symints());
  auto i = 0;
  if(start.has_value()) start = symints[i];
  i += (start.has_value() ? 1 : 0);
  if(end.has_value()) end = symints[i];
  i += (end.has_value() ? 1 : 0);
  step = symints[i];
}

std::vector<at::Tensor> SliceTensorViewFunc::get_tensors() const {
  ::std::vector<at::Tensor> tensors;
  return tensors;
}

size_t SliceTensorViewFunc::num_tensors() const {
  return static_cast<size_t>(0);
}

void SliceTensorViewFunc::set_tensors(std::vector<at::Tensor> tensors) {
  TORCH_INTERNAL_ASSERT(tensors.size() == num_tensors());

}

at::Tensor SliceTensorViewFunc::operator()(const at::Tensor& input_base) const {
  return at::_ops::slice_Tensor::call(input_base, dim, start, end, step);
}

std::unique_ptr<ViewFunc> SliceTensorViewFunc::clone_and_set(
    std::optional<std::vector<c10::SymInt>> symints,
    std::optional<std::vector<at::Tensor>> tensors) const {
  auto output = std::make_unique<SliceTensorViewFunc>(dim, start, end, step);
  if (symints.has_value()) {
    output->set_symints(std::move(*(symints)));
  }
  if (tensors.has_value()) {
    output->set_tensors(std::move(*(tensors)));
  }
  return output;
}
```

The `_view_func()` / `_view_func_unsafe()` methods now accept two additional (optional) args for `symint_visitor_fn` / `tensor_visitor_fn`. If these are defined, they are expected to be python callables that operate on a single SymInt / tensor and return a new one. This allows for the hot-swapping needed during fake-ification.

For testing, there are extensive pre-existing tests, and I added a test to ensure that hot-swapping functions correctly.
```sh
python test/test_autograd.py -k test_view_func_replay
python test/test_ops.py -k test_view_replay
```
Pull Request resolved: https://github.com/pytorch/pytorch/pull/118404
Approved by: https://github.com/ezyang
2024-02-09 18:51:36 +00:00
Jason Ansel
26d29d9639 [Compiled Autograd] Support CopySlices and CopyBackwards (#105809)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/105809
Approved by: https://github.com/albanD
2023-07-28 21:42:51 +00:00
albanD
26054c1607 beef up inplace/view note on copy slices (#89856)
Follow up doc update from https://github.com/pytorch/pytorch/pull/89812
Pull Request resolved: https://github.com/pytorch/pytorch/pull/89856
Approved by: https://github.com/ezyang, https://github.com/soulitzer
2022-11-30 18:35:52 +00:00
Michael Suo
30fb2c4aba [lint] autoformat test/cpp and torch/csrc
Let's have some fun.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/78828

Approved by: https://github.com/ezyang
2022-06-11 21:11:16 +00:00
Peter Bell
b2e79ed5ec Remove WindowsTorchApiMacro.h in favor of Export.h (#69585)
Summary:
Follow up to https://github.com/pytorch/pytorch/issues/68095

This also changes the files from the ATen folder to include c10's `Export.h` instead since they can't ever be exporting `TORCH_PYTHON_API`.

cc pietern mrshenli pritamdamania87 zhaojuanmao satgera rohan-varma gqchen aazzolini osalpekar jiayisuse SciPioneer H-Huang

Pull Request resolved: https://github.com/pytorch/pytorch/pull/69585

Reviewed By: mrshenli

Differential Revision: D32958594

Pulled By: albanD

fbshipit-source-id: 1ec7ef63764573fa2b486928955e3a1172150061
2021-12-09 17:30:09 -08:00
Peter Bell
5407108533 CopyBackward: Remove redundant src_device and unnecessary copy=True (#60025)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/60025

`to` already copies unconditionally if `src.device() != options.device()` so
specifying the copy argument is unnecessary.

`src.device()` is also completely equivalent to `src.options().device()` so
storing both is redundant.

Test Plan: Imported from OSS

Reviewed By: zou3519

Differential Revision: D29698627

Pulled By: albanD

fbshipit-source-id: eb091d39b71db688e6bcbb33a227c01b94b432bb
2021-07-15 09:48:03 -07:00
Erjia Guan
00d432a1ed Remove optional for veiw_fn during View Tracking (#50067)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50067

Fixes #49257

Using the `Callgrind` to test the performance.
```python
import torch
import timeit
from torch.utils.benchmark import Timer

timer = Timer("x.view({100, 5, 20});", setup="torch::Tensor x = torch::ones({10, 10, 100});", language="c++", timer=timeit.default_timer)
res = timer.collect_callgrind(number=10)
```
### Nightly
```python
torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f7949138c40>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        42310                      42310
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```
### Current
```python
<torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.CallgrindStats object at 0x7f78f271a580>
x.view({100, 5, 20});
setup: torch::Tensor x = torch::ones({10, 10, 100});
                           All          Noisy symbols removed
    Instructions:        42480                      42480
    Baseline:                0                          0
10 runs per measurement, 1 thread
Warning: PyTorch was not built with debug symbols.
         Source information may be limited. Rebuild with
         REL_WITH_DEB_INFO=1 for more detailed results.
```
### Compare
There are 170 instructions reduced
```python
torch.utils.benchmark.utils.valgrind_wrapper.timer_interface.FunctionCounts object at 0x7f7941b7a7c0>
    970  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, std::function<at::Tensor (at::Tensor const&)>, torch::autograd::CreationMeta, bool)
    240  ???:torch::autograd::ViewInfo::~ViewInfo()
    180  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, std::function<at::Tensor (at::Tensor const&)>)
    130  ???:torch::autograd::make_variable_differentiable_view(at::Tensor const&, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta, bool)
    105  /tmp/benchmark_utils_jit_build_69e2f1710544485588feeca0719a3a57/timer_cpp_4435526292782672407/timer_src.cpp:main
    100  ???:std::function<at::Tensor (at::Tensor const&)>::function(std::function<at::Tensor (at::Tensor const&)> const&)
     70  ???:torch::autograd::DifferentiableViewMeta::~DifferentiableViewMeta()
     70  ???:torch::autograd::DifferentiableViewMeta::DifferentiableViewMeta(c10::TensorImpl*, c10::optional<torch::autograd::ViewInfo>, c10::optional<torch::autograd::ViewInfo>, torch::autograd::CreationMeta)
   -100  ???:c10::optional_base<torch::autograd::ViewInfo>::optional_base(c10::optional_base<torch::autograd::ViewInfo>&&)
   -105  /tmp/benchmark_utils_jit_build_2e75f38b553e42eba00523a86ad9aa05/timer_cpp_3360771523810516633/timer_src.cpp:main
   -120  ???:torch::autograd::ViewInfo::ViewInfo(at::Tensor, c10::optional<std::function<at::Tensor (at::Tensor const&)> >)
   -210  ???:c10::optional_base<std::function<at::Tensor (at::Tensor const&)> >::~optional_base()
   -240  ???:c10::optional_base<torch::autograd::ViewInfo>::~optional_base()
   -920  ???:torch::autograd::as_view(at::Tensor const&, at::Tensor const&, bool, bool, c10::optional<std::function<at::Tensor (at::Tensor const&)> >, torch::autograd::CreationMeta, bool)
```

Test Plan: Imported from OSS

Reviewed By: albanD

Differential Revision: D25900495

Pulled By: ejguan

fbshipit-source-id: dedd30e69db6b48601a18ae98d6b28faeae30d90
2021-01-15 08:29:28 -08:00
Ailing Zhang
30dd0b74fd Save view_fn for inplace update on view tensors (#36073)
Summary:
This PR enables inplace updates on view Tensors for tensor types(XLA) that doesn't support as_strided.
(See Notes inside PR)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36073

Reviewed By: yf225

Differential Revision: D20994282

Pulled By: ailzhang

fbshipit-source-id: 83eeccb297b242f9822f08ad110a7045d7055639
2020-04-15 20:11:27 -07:00
Edward Yang
7b4eddede9 Delete toType(const DeprecatedTypeProperties&, ...) (#25332)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/25332

This method makes reference to a deprecated class, we now delete it.
This deletion was somewhat involved.  Pre-existing use sites of
toType:

- Tensor::cpu()/cuda()/hip()
- native::type_as
- SummaryOps: toType(CPU(kDouble)) translated into to(kDouble) as weights
  is an input argument and therefore assumed to be on CPU already.  Similar
  for CUDA.
- TensorTransformations: toType(CUDA(kLong)) translated into cuda(), as
  the inputs are actually already the correct dtype, and this translation is just to move them to CUDA
- Adjusted native_test to take TensorOptions instead of
  DeprecatedTypeProperties, killing toType along the way in favor of to
- Some tests for toType with UndefinedType which I just deleted
- CopyBackwards stores TensorOptions now instead of
  DeprecatedTypeProperties
ghstack-source-id: 89177526

Test Plan: sandcastle and ossci

Differential Revision: D17096824

fbshipit-source-id: 964e5a073b9d37594e911d8bca98c9eab5766826
2019-08-29 16:20:18 -07:00
mal
e7a9b0d62f Rename torch::autograd::Function to torch::autograd::Node
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/23269

Test Plan: Imported from OSS

Differential Revision: D16454878

fbshipit-source-id: b1e840fc2d3901955280d141e5ad6efd5e9d66af
2019-07-23 20:52:22 -07:00
Karl Ostmo
8f0603b128 C++ changes toward libtorch and libcaffe2 unification (#19554)
Summary:
* adds TORCH_API and AT_CUDA_API in places
* refactor code generation Python logic to separate
  caffe2/torch outputs
* fix hip and asan
* remove profiler_cuda from hip
* fix gcc warnings for enums
* Fix PythonOp::Kind
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19554

Differential Revision: D15082727

Pulled By: kostmo

fbshipit-source-id: 83a8a99717f025ab44b29608848928d76b3147a4
2019-04-26 01:38:10 -07:00
Roy Li
422b01e788 Replace more usages of Type with DeprecatedTypeProperties (#19093)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/19093
ghimport-source-id: a82e3dce912a173b42a6a7e35eb1302d9f334e03

Differential Revision: D14865520

Pulled By: li-roy

fbshipit-source-id: b1a8bf32f87920ce8d82f990d670477bc79d0ca7
2019-04-11 17:02:05 -07:00
Edward Yang
517c7c9861 Canonicalize all includes in PyTorch. (#14849)
Summary:
Anywhere we used #include "foo.h", we now say #include <foo.h>
Paths are adjusted to be rooted out of aten/src, torch/lib, or
the root level directory.

I modified CMakeLists.txt by hand to remove TH and THC from
the include paths.

I used the following script to do the canonicalization:

```
  import subprocess
  import re
  import os.path

  files = subprocess.check_output(['git', 'ls-files']).decode('utf-8').rstrip().split('\n')
  for fn in files:
      if not any(fn.endswith(suff) for suff in ['.cu', '.cpp', '.in', '.h', '.hpp', '.cu', '.cuh', '.cc']):
          continue
      if not any(fn.startswith(pref) for pref in ["aten/", "torch/"]):
          continue
      with open(fn, 'r') as f:
          c = f.read()
      def fmt(p):
          return "#include <{}>".format(p)
      def repl(m):
          p = m.group(1)
          if p in ["dlfcn.h", "unistd.h", "nvrtc.h", "cuda.h", "cuda_runtime.h", "cstdint", "cudnn.h", "Python.h", "cusparse.h", "cuda_runtime_api.h", "cuda_fp16.h", "cublas_v2.h", "stdint.h", "curand_kernel.h"]:
              return fmt(p)
          if any(p.startswith(pref) for pref in ["torch/csrc", "c10/", "ATen/", "caffe2/", "TH/", "THC/", "Eigen/", "gtest/", "zdl/", "gloo/", "onnx/", "miopen/"]):
              return fmt(p)
          for root in ["aten/src", "torch/lib", ""]:
              for bad_root in [os.path.dirname(fn), "aten/src/TH", "aten/src/THC", "torch/csrc"]:
                  new_p = os.path.relpath(os.path.join(bad_root, p), root)
                  if not new_p.startswith("../") and (os.path.exists(os.path.join(root, new_p)) or os.path.exists(os.path.join(root, new_p + ".in"))):
                      return fmt(new_p)
          print("ERROR: ", fn, p)
          return m.group(0)
      new_c = re.sub(r'#include "([^"]+)"', repl, c)
      if new_c != c:
          print(fn)
          with open(fn, 'w') as f:
              f.write(new_c)
```

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/14849

Reviewed By: dzhulgakov

Differential Revision: D13363445

Pulled By: ezyang

fbshipit-source-id: 52361f878a672785f9306c9e9ab2513128092b68
2018-12-08 19:38:30 -08:00
Edward Yang
e5d56659ec Delete DeviceGuard(int64_t) constructor. (#13232)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13232

DeviceGuard should be device agnostic, which means that it shouldn't
assume that int64_t means select the CUDA device.

Reviewed By: gchanan

Differential Revision: D10858024

fbshipit-source-id: b40e8337e4046906fd8f83a95e6206367fb29dbe
2018-10-31 07:55:11 -07:00
Peter Goldsborough
6071389a90 Enable cppcoreguidelines checks in clang-tidy (#12959)
Summary:
Enables most of `cppcoreguidelines-*` checks for clang-tidy. Major fixes included:

- Uninitialized members,
- Use of `const_cast`,
- Use of raw `new`

ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/12959

Differential Revision: D11349285

Pulled By: goldsborough

fbshipit-source-id: 9e24d643787dfe7ede69f96223c8c0179bd1b2d6
2018-10-29 18:23:35 -07:00
Tongzhou Wang
46162ccdb9 Autograd indices/values and sparse_coo ctor (#13001)
Summary:
Reopen of #11253 after fixing bug in index_select
Pull Request resolved: https://github.com/pytorch/pytorch/pull/13001

Differential Revision: D10514987

Pulled By: SsnL

fbshipit-source-id: 399a83a1d3246877a3523baf99aaf1ce8066f33f
2018-10-24 10:00:22 -07:00
Yangqing Jia
713e706618 Move exception to C10 (#12354)
Summary:
There are still a few work to be done:

- Move logging and unify AT_WARN with LOG(ERROR).
- A few header files are still being plumbed through, need cleaning.
- caffe2::EnforceNotMet aliasing is not done yet.
- need to unify the macros. See c10/util/Exception.h

This is mainly a codemod and not causing functional changes. If you find your job failing and trace back to this diff, usually it can be fixed by the following approaches:

(1) add //caffe2/c10:c10 to your dependency (or transitive dependency).
(2) change objects such as at::Error, at::Optional to the c10 namespace.
(3) change functions to the c10 namespace. Especially, caffe2::MakeString is not overridden by the unified c10::str function. Nothing else changes.

Please kindly consider not reverting this diff - it involves multiple rounds of rebasing and the fix is usually simple. Contact jiayq@ or AI Platform Dev for details.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/12354

Reviewed By: orionr

Differential Revision: D10238910

Pulled By: Yangqing

fbshipit-source-id: 7794d5bf2797ab0ca6ebaccaa2f7ebbd50ff8f32
2018-10-15 13:33:18 -07:00
Sebastian Messmer
f51f15bb27 Update include paths for ATen/core (#10130)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/10130

Update some include paths to make them internally consistent

Reviewed By: ezyang

Differential Revision: D9119906

fbshipit-source-id: b44e5cab8e8e795ee18afe9ffc6caf1f2b413467
2018-08-03 11:57:02 -07:00
Peter Goldsborough
04939a4745 Match parameter names and = default (#9737)
Summary:
More clang tidy cleanups in `torch/csrc`. This time:

1. `hicpp-use-equals-default` recommends `= default` instead of `{}` for constructors/destructors. This is better practice because it expresses the intent better (https://stackoverflow.com/questions/6502828/what-does-default-mean-after-a-class-function-declaration)
2. `readability-inconsistent-declaration-parameter-name` enforces that parameter names in the declaration match parameter names in the definition. This is just generally useful and can prevent confusion and bugs.

Also updated my script a little bit.

apaszke ezyang
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9737

Differential Revision: D9069069

Pulled By: goldsborough

fbshipit-source-id: f7b3f3a4eb4c9fadc30425a153566d3b613a41ae
2018-07-30 14:10:00 -07:00
Peter Goldsborough
fd25a2a86c Remove virtual+override anti-pattern (#9335)
Summary:
I'm cramming through clang tidy emitted warnings. This PR addresses the `hi-cpp-override` check which warns that `virtual` + `override` is redundant, since `override` already  signifies that a function is overriding and thus virtual.

Where there was `virtual` + `override` I removed the `virtual`, where there was `virtual` and no `override` I removed `virtual` and added `override`.

ezyang apaszke
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9335

Differential Revision: D8807082

Pulled By: goldsborough

fbshipit-source-id: e0a261053f6540a22cc56ec160a24aa285af6319
2018-07-13 11:25:01 -07:00
Mary McBreen
483ae8cb5d Replaces const ref with && for apply (#9175)
Summary:
Addresses https://github.com/pytorch/pytorch/issues/5011
Tested with python test/test_autograd.py
Pull Request resolved: https://github.com/pytorch/pytorch/pull/9175

Reviewed By: zdevito

Differential Revision: D8736377

Pulled By: marymcbreen

fbshipit-source-id: ff86f427f7b2cf0cab5912e7f32812bd0f49a712
2018-07-12 08:31:59 -07:00
Peter Goldsborough
372d1d6735
Create ATen tensors via TensorOptions (#7869)
* Created TensorOptions

Storing the type in TensorOptions to solve the Variable problem

Created convenience creation functions for TensorOptions and added tests

Converted zeros to TensorOptions

Converted rand to TensorOptions

Fix codegen for TensorOptions and multiple arguments

Put TensorOptions convenience functions into torch namespace too

All factory functions except *_like support TensorOptions

Integrated with recent JIT changes

Support *_like functions

Fix in place modification

Some cleanups and fixes

Support sparse_coo_tensor

Fix bug in Type.cpp

Fix .empty calls in C++ API

Fix bug in Type.cpp

Trying to fix device placement

Make AutoGPU CPU compatible

Remove some auto_gpu.h uses

Fixing some headers

Fix some remaining CUDA/AutoGPU issues

Fix some AutoGPU uses

Fixes to dispatch_tensor_conversion

Reset version of new variables to zero

Implemented parsing device strings

Random fixes to tests

Self review cleanups

flake8

Undo changes to variable.{h,cpp} because they fail on gcc7.2

Add [cuda] tag to tensor_options_cuda.cpp

Move AutoGPU::set_index_from into .cpp file because Windows is stupid and sucks

Fix linker error in AutoGPU.cpp

Fix bad merge conflict in native_functions.yaml

Fixed caffe2/contrib/aten

Fix new window functions added to TensorFactories.cpp

* Removed torch::TensorOptions

Added code to generate wrapper functions for factory methods

Add implicit constructor from Backend to TensorOptions

Remove Var() from C++ API and use torch:: functions

Use torch:: functions more subtly in C++ API

Make AutoGPU::set_device more exception safe

Check status directly in DynamicCUDAHooksInterface

Rename AutoGPU to DeviceGuard

Removed set_requires_grad from python_variables.h and warn appropriately in Variable::set_requires_grad

remove python_default_init: self.type()

Add back original factory functions, but with deprecation warnings

Disable DeviceGuard for a couple functions in ATen

Remove print statement

Fix DeviceGuard construction from undefined tensor

Fixing CUDA device compiler issues

Moved as many methods as possible into header files

Dont generate python functions for deprecated factories

Remove merge conflict artefact

Fix tensor_options_cuda.cpp

Fix set_requires_grad not being checked

Fix tensor_new.h

TEMPORARILY put some methods in .cpp files to see if it solves issues on windows and mac

Fix bug in DeviceGuard.h

Missing includes

TEMPORARILY moving a few more methods into .cpp to see if it fixes windows

Fixing linker errors

* Fix up SummaryOps to use new factories

Undo device agnostic behavior of DeviceGuard

Use -1 instead of optional for default device index

Also move DeviceGuard methods into header

Fixes around device index after optional -> int32_t switch

Fix use of DeviceGuard in new_with_tensor_copy

Fix tensor_options.cpp

* Fix Type::copy(

* Remove test_non_float_params from ONNX tests

* Set requires_grad=False in ONNX tests that use ints

* Put layout/dtype/device on Tensor

* Post merge fixes

* Change behavior of DeviceGuard to match AutoGPU

* Fix C++ API integration tests

* Fix flip functions
2018-06-16 00:40:35 -07:00
Luca Antiga
396637cdd6 Python-free build of autograd + jit (#5356)
This PR adds the possibility to build the C++ parts of autograd and jit, with no dependency on Python.
The goal is to allow taking a PyTorch IR representation (a tree s-expr) and running it with provided inputs.

Prerequisite: build PyTorch so that codegen runs once.
Instructions:

cd tools/cpp_build
bash build_all.sh
This will build libtorchjit and torchjit_test in tools/cpp_build/build/torchjit-build. The latter basically runs the code in test_jit.cpp for now.

While writing the PR, it turned out that a few of Python.h includes were redundant. They were removed here (PyTorch tests still pass on my machine, we'll see CI).

* Introduce Python-free builds of autograd and jit

* Remove NO_PYTHON ifdef in functions/special
2018-03-08 15:13:10 -05:00
Peter Goldsborough
702a7f3864 Improve Function interface (#5221)
* Improve Function interface

* Undo tracer changes

* Fix bug in VariableType.set_history

* Rename function_counter and sequence_number to sequence_nr

* Clarify Function documentation

* Replace swap_next_edges with next_edges() getter

* Bring back set_gradient_edge

* Simplify special.cpp

* add_gradient_edge -> create_gradient_edge

* Add mutable getters for pre/post hooks

* Use make_variable with Edge

* Remove remove_gradient_edge in favor of detach_

* Fix documentation and remove create_gradient_edge friend method

* Canonicalize some includes
2018-02-21 16:37:52 -05:00
Edward Z. Yang
dc76db349e Delete a pile of dead code (#4295)
* Delete obsolete basic ops.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* More deletion.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete some unused utilities.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete dead apply_fn

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete CppFunction symbolic support.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Delete ForwardFunction

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

* Batchnorm is 'working'

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2018-01-04 09:21:54 -05:00
Edward Z. Yang
1c0fbd27a1
CuDNN bindings rewrite (into ATen) (#3666)
* Comprehensive rewrite of Torch CuDNN bindings / a bit of ATen infra

The executive summary is that this moves the torch/csrc/cudnn
library into ATen, adding a number of new cudnn_ methods to ATen
for batchnorm, convolution, affine grid generator and grid sampler.

ATen infra changes:

- TensorGeometry was moved to ATen
- TensorGeometry was modified to make its interface resemble that of
  Tensor; in particular, sizes is no longer a field, it's a method.
- AT_CUDA_ENABLED macro is set via ATen/Config.h header which is
  generated at cmake configure time.
  Fixes https://github.com/zdevito/ATen/issues/168
- Change AT_CUDA_ENABLED macro to be a function macro, so that we
  error if it is not defined
- Introduce a new TensorArg class, which is a Tensor plus a little
  metadata.  This helps us give good error messages when checking
  dimensions/shapes of tensors.
  Fixes https://github.com/zdevito/ATen/issues/169
- Also introduce a TensorGeometryArg class, for when you don't
  need the actual tensor data (which is most of the time.)
- Add ATen/Check.h, which contains a number of utility functions
  for testing shapes, types and devices of input tensors.  This
  will be particulary useful for native methods, which don't get
  code generated input testing code.  These functions take a
  'CheckedFrom' argument, at the moment just a string, which
  specifies some extra information about what function was
  doing the actual checking; this greatly improves error messages.
    - Many check functions take initializer lists, which let you
      test that all tensors have some property.  This API is
      peculiar, in that we IGNORE undefined tensors in this case.
      This is handled by filterDefined.
- Add AT_CUDNN_ENABLED macro
- CuDNN linking from ATen was improved; for example, we now actually
  add the CuDNN headers to our include path.
- Add some missing override specifiers to some methods
- We now actually build tests with CUDA functionality accessible
  (previously, AT_CUDA_ENABLED was not defined, meaning that
  the headers were missing all CUDA-only functionality.)
- Native functions now support giving explicit names to return
  outputs in yaml.  This makes it possible to hook into the NN
  autogenerated derivatives codepath using native functions.

CuDNN rewrite changes:

- torch/csrc/cudnn now uses ATen (rather than passing around
  THVoidTensor) and lives in ATen.  This lets us remove tensorPointer
  shenanigans.  The functions are exposed to ATen as native functions
  described in aten/src/ATen/cudnn/cuDNN.yaml
- ATen now builds and links against CuDNN when enabled.  The cmake
  package script was taken from Caffe2.
- Some header reorganization was done to help reduce dependencies
  on headers (this reorg is no longer used but I've kept it)
- Rename CHECK to CUDNN_CHECK
- Rip out old shape/type testing code in favor of modern ATen/Check.h
  interface using TensorArg.  In many cases, increase the robustness of
  the checking code.
- Change the inputs of the public facing functions, so that they can
  be bound by ATen
  - Delete THCState*; this is retrieved from the global ATen context
  - Delete cudnnHandle_t, this is retrieved from the global Handles.h
  - Delete cudnnDataType_t, this is retrieved from the Tensor type
  - Delete Convolution class, instead its constituent arguments are
    passed individually
- Change functions to return tensors, rather than take an appropriately
  sized output tensor as an input.
- Redo how transposed convolution / backward convolution is implemented
  (knock on effect of returning tensors).  Previously it was assumed
  that you would always pass an appropriately sized output tensor, but
  we don't want to do this anymore.  For backwards, we instead give
  the desired output tensor (input, really) size, because that is
  readily available.  For *transposed* convolution, however, we take
  output_padding, and otherwise do the shape calculation.
- Redo how legacy group convolution is implemented (knock on effect from
  porting cudnn to ATen.)  Previously, group convolution was implemented
  by manually constructing sizes and strides and then outputting
  appropriate, with macros switching between individual groups and
  all-at-once based on CuDNN version.  Now, the code looks exactly what
  you'd expect: there's a top-level wrapping function that supports
  group convolution no matter the version of CuDNN, and a low-level
  wrapper which supports only what CuDNN supports.  The top-level
  function conditions on CuDNN version, and invokes the low-level
  interface 1 or n times.
- There is now a debugging printer for tensor descriptors.
- Convolution struct is replaced with ConvolutionArgs, which is not
  part of the public API but is used internally to conveniently
  pass around all of the arguments needed for Convolution.
- Add some constexprs for well-known dimensions, reduce amount of
  magic numbers in code.
- Put 'deterministic' in to ConvParams.  Fixes #3659
- Lots more comments.
- Some pessimizations, in the name of code clarity:
  - The descriptors are initialized on every invocation of convolution
    forward/backward.  Previously, the descriptors were cached, so that
    you didn't have to initialize them again on backwards.  This is
    difficult to support in the ATen interface so I didn't support it.
  - Legacy group convolution initializes its workspace for *every* group
    it performs.  I did not feel motivated to fix this because the
    legacy codepath is already quite slow.
- Affine grid generator and grid sampler automatically call contiguous
  on their arguments as necessary.
- Batchnorm input checking is greatly beefed up, it now checks for
  the following input characteristics:
    - Definedness
    - GPU location
    - Type
    - Contiguity
    - Size

PyTorch binding code changes

- batchnorm now uses consistent var/data naming
- batchnorm and convolution make use of new ATen bindings
- Affine grid generator and grid sampler make use of ATen CuDNN
  bindings via derivatives.yaml.  This means I had to restructure
  the code a little, since the THNN bindings still go through
  a legacy Python class.
- I fixed some warnings:
  - s/friend class/friend struct/ on InterpreterStateImpl
  - Removed pessimizing move 'detached' in torch/csrc/autograd/variable.cpp
  - Removed unused pack_list on Scalar

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

GCC 4.8 buildfix

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Add TensorGeometry to ATen.h

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

CUDNN_CHECK

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Update TODO comment

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Delete return in cudnn_grid_sampler

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

s/cudnnSetStreamToCurrent/setCuDNNStreamToCurrent/g

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Don't allocate a new vector when filtering defined.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Remove Check overloads, convert to pass references.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>

Some more microbenchmarking.

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-11-30 23:06:58 -05:00
gchanan
9c498aa523
Implement Variable cpu() as an ATen method. (#3802) 2017-11-22 11:25:52 -05:00
gchanan
ee08120b46
Move Variable conversion methods to ATen. (#3762)
* Move Variable conversion methods to ATen.

* Add a test to ensure type conversions work through backwards.

* Fix VariableType copy for type conversions.

* Add comment about needing to handle device movement.

* Move back to opposite order for copy function params -- inplace views depend on it.

* Use is_available() rather than is_available.
2017-11-20 13:28:08 -05:00
Sam Gross
b09d66e60d
Fix a reference cycle when in-place ops on views save the output (#3679)
Previously, an in-place operation that saves its output (such as
relu/threshold) would create a reference cycle when applied to the a
view. There were two cycles created:

1) The cycle base.grad_fn.fn.input_.base
   base.grad_fn is a CopySlices
   base.grad_fn.fn is ThresholdBackward
   base.grad_fn.fn.input_ is a SavedVariable with base pointing to base

2) The cycle base.grad_fn.fn.input_.grad_fn.next_functions[0]
   base.grad_fn.fn.input_.grad_fn is AsStridedBackward
   and next_functions[0] points to base.grad_fn

Generally, we avoid cycles because the AD graph is mostly immutable. Two
notable exceptions are:

a) Variable.grad_fn can change to point to a new grad_fn
b) SavedVariables in a function can be set after the function is created

The first case is not a problem if grad_fns do not hold strong references
to Variables. Removing "base" from SavedVariable removes the strong ref.

For the second case, we need to avoid saving the grad_fn of outputs. We
were incorrectly saving the grad_fns of outputs when they were the
result of in-place ops on views.
2017-11-15 15:19:41 -05:00
Sam Gross
fde355f7d4
Allow in-place operations on views (#3384)
Allow in-place operations on views

Adds VariableViewImpl, a subclass of VariableImpl which has a pointer to
the base Variable on which it is a view. In-place operations on views
change the grad_fn of the base.

Note that in-place operations only work on views that are the first output of the function that created them. All C++/ATen implemented functions have this behavior, but it's possible to write Python-implemented autograd functions that do not. In-place operations on these view will raise an exception.

Fixes #3313
2017-11-06 18:19:56 -05:00
Adam Paszke
594f98ce16 Support multi-stage AutogradClosures 2017-09-05 17:48:55 -04:00
Adam Paszke
fa308b3183 Improve backward tracing 2017-09-05 17:48:55 -04:00
Zachary DeVito
43c944acbd Remove dead THPP code that has been replaced with ATen objects. (#2235)
THPP usage is now isolated in THD.
2017-07-29 08:07:41 +05:30
Trevor Killeen
c304d04fc6 Replace thpp::Tensor with ATen Tensor in autograd csrc (#2170) 2017-07-28 10:18:37 -04:00
albanD
c888857461 Conv double backward groups (#1993)
* add support for groups in double backward

* add tests for group in double backward

* fix lint

* separate some tests to reduce number of test cases

* remove redundant testing for different number of output channels
2017-07-13 00:41:14 -04:00
albanD
6cdcd9c603 Add Narrow function
clean error message and support non perfectly sized inputs
2017-06-17 11:11:48 -04:00
albanD
2f8d21a7f2 add contiguous function 2017-06-17 11:11:48 -04:00
albanD
462ab8a644 add Transpose View Expand C functions 2017-06-17 11:11:48 -04:00
Edward Z. Yang
3ada9da808 Make csrc -Werror clean. (#1795)
Primary things I had to fix:

- Suppress _XOPEN_SOURCE warnings by ensuring that Python.h is included
  first, because it always unconditionally defines this macro.

- Turn off strict aliasing, because Python 2 doesn't work with strict
  aliasing.

- Workaround setuptools bug, where it's incorrectly passing
  -Wstrict-prototypes to C++ compilers (where this doesn't make
  any sense)

To compile csrc with -Werror, run `CFLAGS="-Werror" python setup.py build_ext`

Signed-off-by: Edward Z. Yang <ezyang@fb.com>
2017-06-13 20:18:09 -04:00
Adam Paszke
702a2e3bc5 Make Variables not subclass Function anymore
Because of this Variables can no longer appear in the graph.
Every usage of a leaf Variable will leave an AccumulateGrad
function that has no outputs, but modifies var.grad as a side
effect.
2017-05-01 16:44:56 -04:00