pytorch/docs/source/torch.rst
xiaobing.zhang 9ba6a768de Add op bitwise_or (#31559)
Summary:
ezyang, this PR adds the bitwise_or operator, following the same approach as https://github.com/pytorch/pytorch/pull/31104.
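For context, a quick usage sketch of the new op (output values in the comments are computed by hand, not captured from a run):
```
import torch

a = torch.tensor([1, 2, 4], dtype=torch.int8)
b = torch.tensor([3, 2, 8], dtype=torch.int8)

print(torch.bitwise_or(a, b))  # tensor([ 3,  2, 12], dtype=torch.int8)
print(a | b)                   # __or__, same result
a |= b                         # __ior__, the in-place variant
print(a)                       # tensor([ 3,  2, 12], dtype=torch.int8)
```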
Benchmark script:
```
import timeit
import torch
torch.manual_seed(1)

for n, t in [(10, 100000), (1000, 10000)]:
    print('__or__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            stmt = f'a | b\nif "{device}" == "cuda": torch.cuda.synchronize()'
            setup = (f'import torch; '
                     f'a = torch.randint(0, 10, ({n},), dtype={dtype}, device="{device}"); '
                     f'b = torch.randint(0, 10, ({n},), dtype={dtype}, device="{device}")')
            print(timeit.timeit(stmt, setup=setup, number=t))

for n, t in [(10, 100000), (1000, 10000)]:
    print('__ior__ (a.numel() == {}) for {} times'.format(n, t))
    for device in ('cpu', 'cuda'):
        for dtype in ('torch.int8', 'torch.uint8', 'torch.int16', 'torch.int32', 'torch.int64'):
            print(f'device: {device}, dtype: {dtype}, {t} times', end='\t\t')
            # note `a |= b` (not `a | b`) so the in-place __ior__ path is measured
            stmt = f'a |= b\nif "{device}" == "cuda": torch.cuda.synchronize()'
            setup = (f'import torch; '
                     f'a = torch.randint(0, 10, ({n},), dtype={dtype}, device="{device}"); '
                     f'b = torch.tensor(5, dtype={dtype}, device="{device}")')
            print(timeit.timeit(stmt, setup=setup, number=t))
```
Device: **Tesla P100, skx-8180**
CUDA version: **9.0.176**

Before:
```
__or__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.17616272252053022
device: cpu, dtype: torch.uint8, 100000 times           0.17148233391344547
device: cpu, dtype: torch.int16, 100000 times           0.17616403382271528
device: cpu, dtype: torch.int32, 100000 times           0.17717823758721352
device: cpu, dtype: torch.int64, 100000 times           0.1801931718364358
device: cuda, dtype: torch.int8, 100000 times           1.270583058707416
device: cuda, dtype: torch.uint8, 100000 times          1.2636413089931011
device: cuda, dtype: torch.int16, 100000 times          1.2839747751131654
device: cuda, dtype: torch.int32, 100000 times          1.2548385225236416
device: cuda, dtype: torch.int64, 100000 times          1.2650810535997152
__or__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.031136621721088886
device: cpu, dtype: torch.uint8, 10000 times            0.030786747112870216
device: cpu, dtype: torch.int16, 10000 times            0.02391665056347847
device: cpu, dtype: torch.int32, 10000 times            0.024147341027855873
device: cpu, dtype: torch.int64, 10000 times            0.024414129555225372
device: cuda, dtype: torch.int8, 10000 times            0.12741921469569206
device: cuda, dtype: torch.uint8, 10000 times           0.1249831635504961
device: cuda, dtype: torch.int16, 10000 times           0.1283819805830717
device: cuda, dtype: torch.int32, 10000 times           0.12591975275427103
device: cuda, dtype: torch.int64, 10000 times           0.12655890546739101
__ior__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.3908365070819855
device: cpu, dtype: torch.uint8, 100000 times           0.38267823681235313
device: cpu, dtype: torch.int16, 100000 times           0.38239253498613834
device: cpu, dtype: torch.int32, 100000 times           0.3817988149821758
device: cpu, dtype: torch.int64, 100000 times           0.3901665909215808
device: cuda, dtype: torch.int8, 100000 times           1.4211318120360374
device: cuda, dtype: torch.uint8, 100000 times          1.4215159295126796
device: cuda, dtype: torch.int16, 100000 times          1.4307750314474106
device: cuda, dtype: torch.int32, 100000 times          1.4123614141717553
device: cuda, dtype: torch.int64, 100000 times          1.4480243818834424
__ior__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.06468924414366484
device: cpu, dtype: torch.uint8, 10000 times            0.06442475505173206
device: cpu, dtype: torch.int16, 10000 times            0.05267547257244587
device: cpu, dtype: torch.int32, 10000 times            0.05286940559744835
device: cpu, dtype: torch.int64, 10000 times            0.06211103219538927
device: cuda, dtype: torch.int8, 10000 times            0.15332304500043392
device: cuda, dtype: torch.uint8, 10000 times           0.15353196952492
device: cuda, dtype: torch.int16, 10000 times           0.15300503931939602
device: cuda, dtype: torch.int32, 10000 times           0.15274472255259752
device: cuda, dtype: torch.int64, 10000 times           0.1512152962386608
```
After:
```
__or__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.2465507509186864
device: cpu, dtype: torch.uint8, 100000 times           0.2472386620938778
device: cpu, dtype: torch.int16, 100000 times           0.2469814233481884
device: cpu, dtype: torch.int32, 100000 times           0.2535214088857174
device: cpu, dtype: torch.int64, 100000 times           0.24855613708496094
device: cuda, dtype: torch.int8, 100000 times           1.4351346511393785
device: cuda, dtype: torch.uint8, 100000 times          1.4434308474883437
device: cuda, dtype: torch.int16, 100000 times          1.4520929995924234
device: cuda, dtype: torch.int32, 100000 times          1.4456610176712275
device: cuda, dtype: torch.int64, 100000 times          1.4580101007595658
__or__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.029985425993800163
device: cpu, dtype: torch.uint8, 10000 times            0.03024935908615589
device: cpu, dtype: torch.int16, 10000 times            0.026356655173003674
device: cpu, dtype: torch.int32, 10000 times            0.027377349324524403
device: cpu, dtype: torch.int64, 10000 times            0.029163731262087822
device: cuda, dtype: torch.int8, 10000 times            0.14540370367467403
device: cuda, dtype: torch.uint8, 10000 times           0.1456305105239153
device: cuda, dtype: torch.int16, 10000 times           0.1450125053524971
device: cuda, dtype: torch.int32, 10000 times           0.1472016740590334
device: cuda, dtype: torch.int64, 10000 times           0.14709716010838747
__ior__ (a.numel() == 10) for 100000 times
device: cpu, dtype: torch.int8, 100000 times            0.27195510920137167
device: cpu, dtype: torch.uint8, 100000 times           0.2692424338310957
device: cpu, dtype: torch.int16, 100000 times           0.27726674638688564
device: cpu, dtype: torch.int32, 100000 times           0.2815811652690172
device: cpu, dtype: torch.int64, 100000 times           0.2852728571742773
device: cuda, dtype: torch.int8, 100000 times           1.4743850827217102
device: cuda, dtype: torch.uint8, 100000 times          1.4766502184793353
device: cuda, dtype: torch.int16, 100000 times          1.4774163831025362
device: cuda, dtype: torch.int32, 100000 times          1.4749693805351853
device: cuda, dtype: torch.int64, 100000 times          1.5772947426885366
__ior__ (a.numel() == 1000) for 10000 times
device: cpu, dtype: torch.int8, 10000 times             0.03614502027630806
device: cpu, dtype: torch.uint8, 10000 times            0.03619729354977608
device: cpu, dtype: torch.int16, 10000 times            0.0319912089034915
device: cpu, dtype: torch.int32, 10000 times            0.03319283854216337
device: cpu, dtype: torch.int64, 10000 times            0.0343862259760499
device: cuda, dtype: torch.int8, 10000 times            0.1581476852297783
device: cuda, dtype: torch.uint8, 10000 times           0.15974601730704308
device: cuda, dtype: torch.int16, 10000 times           0.15957212820649147
device: cuda, dtype: torch.int32, 10000 times           0.16002820804715157
device: cuda, dtype: torch.int64, 10000 times           0.16129320487380028
```

Fixes https://github.com/pytorch/pytorch/issues/24511, https://github.com/pytorch/pytorch/issues/24515, https://github.com/pytorch/pytorch/issues/24658, and https://github.com/pytorch/pytorch/issues/24662.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31559

Differential Revision: D19315875

Pulled By: ezyang

fbshipit-source-id: 4a3ca88fdafbeb796079687e676228111eb44aad
2020-01-08 15:06:30 -08:00


torch
===================================
.. automodule:: torch

Tensors
----------------------------------
.. autofunction:: is_tensor
.. autofunction:: is_storage
.. autofunction:: is_floating_point
.. autofunction:: set_default_dtype
.. autofunction:: get_default_dtype
.. autofunction:: set_default_tensor_type
.. autofunction:: numel
.. autofunction:: set_printoptions
.. autofunction:: set_flush_denormal
.. _tensor-creation-ops:

Creation Ops
~~~~~~~~~~~~~~~~~~~~~~

.. note::
    Random sampling creation ops are listed under :ref:`random-sampling` and
    include:
    :func:`torch.rand`
    :func:`torch.rand_like`
    :func:`torch.randn`
    :func:`torch.randn_like`
    :func:`torch.randint`
    :func:`torch.randint_like`
    :func:`torch.randperm`
    You may also use :func:`torch.empty` with the :ref:`inplace-random-sampling`
    methods to create :class:`torch.Tensor` s with values sampled from a broader
    range of distributions.
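    For example, a minimal sketch of this pattern::

        >>> t = torch.empty(3).uniform_(0, 1)  # uninitialized storage, then filled with uniform draws from [0, 1)
        >>> c = torch.empty(3).cauchy_()       # likewise, filled with draws from a Cauchy distribution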
.. autofunction:: tensor
.. autofunction:: sparse_coo_tensor
.. autofunction:: as_tensor
.. autofunction:: as_strided
.. autofunction:: from_numpy
.. autofunction:: zeros
.. autofunction:: zeros_like
.. autofunction:: ones
.. autofunction:: ones_like
.. autofunction:: arange
.. autofunction:: range
.. autofunction:: linspace
.. autofunction:: logspace
.. autofunction:: eye
.. autofunction:: empty
.. autofunction:: empty_like
.. autofunction:: empty_strided
.. autofunction:: full
.. autofunction:: full_like
.. autofunction:: quantize_per_tensor
.. autofunction:: quantize_per_channel
Indexing, Slicing, Joining, Mutating Ops
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: cat
.. autofunction:: chunk
.. autofunction:: gather
.. autofunction:: index_select
.. autofunction:: masked_select
.. autofunction:: narrow
.. autofunction:: nonzero
.. autofunction:: reshape
.. autofunction:: split
.. autofunction:: squeeze
.. autofunction:: stack
.. autofunction:: t
.. autofunction:: take
.. autofunction:: transpose
.. autofunction:: unbind
.. autofunction:: unsqueeze
.. autofunction:: where
.. _generators:

Generators
----------------------------------
.. autoclass:: torch._C.Generator
    :members:
.. _random-sampling:

Random sampling
----------------------------------
.. autofunction:: seed
.. autofunction:: manual_seed
.. autofunction:: initial_seed
.. autofunction:: get_rng_state
.. autofunction:: set_rng_state
.. autoattribute:: torch.default_generator
    :annotation: Returns the default CPU torch.Generator

.. The following doesn't actually seem to exist.
   https://github.com/pytorch/pytorch/issues/27780
   .. autoattribute:: torch.cuda.default_generators
      :annotation: If cuda is available, returns a tuple of default CUDA torch.Generator-s.
                   The number of CUDA torch.Generator-s returned is equal to the number of
                   GPUs available in the system.
.. autofunction:: bernoulli
.. autofunction:: multinomial
.. autofunction:: normal
.. autofunction:: poisson
.. autofunction:: rand
.. autofunction:: rand_like
.. autofunction:: randint
.. autofunction:: randint_like
.. autofunction:: randn
.. autofunction:: randn_like
.. autofunction:: randperm
.. _inplace-random-sampling:

In-place random sampling
~~~~~~~~~~~~~~~~~~~~~~~~
There are a few more in-place random sampling functions defined on Tensors as well. Click through to their documentation, and see the short sketch after this list:
- :func:`torch.Tensor.bernoulli_` - in-place version of :func:`torch.bernoulli`
- :func:`torch.Tensor.cauchy_` - numbers drawn from the Cauchy distribution
- :func:`torch.Tensor.exponential_` - numbers drawn from the exponential distribution
- :func:`torch.Tensor.geometric_` - elements drawn from the geometric distribution
- :func:`torch.Tensor.log_normal_` - samples from the log-normal distribution
- :func:`torch.Tensor.normal_` - in-place version of :func:`torch.normal`
- :func:`torch.Tensor.random_` - numbers sampled from the discrete uniform distribution
- :func:`torch.Tensor.uniform_` - numbers sampled from the continuous uniform distribution
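For example, a short sketch overwriting a tensor in place with draws from two of
these distributions::

    >>> t = torch.zeros(2, 3)
    >>> t = t.exponential_(1.0)  # draws from Exp(lambda=1)
    >>> t = t.geometric_(0.5)    # draws from a geometric distribution with p=0.5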
Quasi-random sampling
~~~~~~~~~~~~~~~~~~~~~
.. autoclass:: torch.quasirandom.SobolEngine
    :members:
    :exclude-members: MAXBIT, MAXDIM
    :undoc-members:
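A brief usage sketch of the Sobol engine (drawing low-discrepancy points)::

    >>> soboleng = torch.quasirandom.SobolEngine(dimension=3)
    >>> points = soboleng.draw(4)  # 4 quasi-random points in [0, 1)^3, shape (4, 3)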
Serialization
----------------------------------
.. autofunction:: save
.. autofunction:: load
Parallelism
----------------------------------
.. autofunction:: get_num_threads
.. autofunction:: set_num_threads
.. autofunction:: get_num_interop_threads
.. autofunction:: set_num_interop_threads
Locally disabling gradient computation
--------------------------------------
The context managers :func:`torch.no_grad`, :func:`torch.enable_grad`, and
:func:`torch.set_grad_enabled` are helpful for locally disabling and enabling
gradient computation. See :ref:`locally-disable-grad` for more details on
their usage. These context managers are thread local, so they will not apply
to work sent to another thread via the ``threading`` module and the like; a
sketch illustrating this follows the examples below.
Examples::

    >>> x = torch.zeros(1, requires_grad=True)
    >>> with torch.no_grad():
    ...     y = x * 2
    >>> y.requires_grad
    False

    >>> is_train = False
    >>> with torch.set_grad_enabled(is_train):
    ...     y = x * 2
    >>> y.requires_grad
    False

    >>> torch.set_grad_enabled(True)  # this can also be used as a function
    >>> y = x * 2
    >>> y.requires_grad
    True

    >>> torch.set_grad_enabled(False)
    >>> y = x * 2
    >>> y.requires_grad
    False
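As a sketch of the thread-local caveat above (illustrative only): gradient mode
disabled in the main thread does not carry over to a freshly spawned thread,
which starts with gradients enabled::

    >>> import threading
    >>> x = torch.zeros(1, requires_grad=True)
    >>> def worker():
    ...     y = x * 2
    ...     print(y.requires_grad)  # grad mode was not disabled in this thread
    >>> with torch.no_grad():
    ...     t = threading.Thread(target=worker)
    ...     t.start()
    ...     t.join()
    True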
Math operations
----------------------------------
Pointwise Ops
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: abs
.. autofunction:: acos
.. autofunction:: add
.. autofunction:: addcdiv
.. autofunction:: addcmul
.. autofunction:: angle
.. autofunction:: asin
.. autofunction:: atan
.. autofunction:: atan2
.. autofunction:: bitwise_not
.. autofunction:: bitwise_and
.. autofunction:: bitwise_or
.. autofunction:: bitwise_xor
.. autofunction:: ceil
.. autofunction:: clamp
.. autofunction:: conj
.. autofunction:: cos
.. autofunction:: cosh
.. autofunction:: div
.. autofunction:: digamma
.. autofunction:: erf
.. autofunction:: erfc
.. autofunction:: erfinv
.. autofunction:: exp
.. autofunction:: expm1
.. autofunction:: floor
.. autofunction:: floor_divide
.. autofunction:: fmod
.. autofunction:: frac
.. autofunction:: imag
.. autofunction:: lerp
.. autofunction:: lgamma
.. autofunction:: log
.. autofunction:: log10
.. autofunction:: log1p
.. autofunction:: log2
.. autofunction:: logical_and
.. autofunction:: logical_not
.. autofunction:: logical_or
.. autofunction:: logical_xor
.. autofunction:: mul
.. autofunction:: mvlgamma
.. autofunction:: neg
.. autofunction:: polygamma
.. autofunction:: pow
.. autofunction:: real
.. autofunction:: reciprocal
.. autofunction:: remainder
.. autofunction:: round
.. autofunction:: rsqrt
.. autofunction:: sigmoid
.. autofunction:: sign
.. autofunction:: sin
.. autofunction:: sinh
.. autofunction:: sqrt
.. autofunction:: square
.. autofunction:: tan
.. autofunction:: tanh
.. autofunction:: trunc
Reduction Ops
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: argmax
.. autofunction:: argmin
.. autofunction:: dist
.. autofunction:: logsumexp
.. autofunction:: mean
.. autofunction:: median
.. autofunction:: mode
.. autofunction:: norm
.. autofunction:: prod
.. autofunction:: std
.. autofunction:: std_mean
.. autofunction:: sum
.. autofunction:: unique
.. autofunction:: unique_consecutive
.. autofunction:: var
.. autofunction:: var_mean
Comparison Ops
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: allclose
.. autofunction:: argsort
.. autofunction:: eq
.. autofunction:: equal
.. autofunction:: ge
.. autofunction:: gt
.. autofunction:: isfinite
.. autofunction:: isinf
.. autofunction:: isnan
.. autofunction:: kthvalue
.. autofunction:: le
.. autofunction:: lt
.. autofunction:: max
.. autofunction:: min
.. autofunction:: ne
.. autofunction:: sort
.. autofunction:: topk
Spectral Ops
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: fft
.. autofunction:: ifft
.. autofunction:: rfft
.. autofunction:: irfft
.. autofunction:: stft
.. autofunction:: bartlett_window
.. autofunction:: blackman_window
.. autofunction:: hamming_window
.. autofunction:: hann_window
Other Operations
~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: bincount
.. autofunction:: broadcast_tensors
.. autofunction:: cartesian_prod
.. autofunction:: cdist
.. autofunction:: combinations
.. autofunction:: cross
.. autofunction:: cumprod
.. autofunction:: cumsum
.. autofunction:: diag
.. autofunction:: diag_embed
.. autofunction:: diagflat
.. autofunction:: diagonal
.. autofunction:: einsum
.. autofunction:: flatten
.. autofunction:: flip
.. autofunction:: rot90
.. autofunction:: histc
.. autofunction:: meshgrid
.. autofunction:: renorm
.. autofunction:: repeat_interleave
.. autofunction:: roll
.. autofunction:: tensordot
.. autofunction:: trace
.. autofunction:: tril
.. autofunction:: tril_indices
.. autofunction:: triu
.. autofunction:: triu_indices
BLAS and LAPACK Operations
~~~~~~~~~~~~~~~~~~~~~~~~~~~
.. autofunction:: addbmm
.. autofunction:: addmm
.. autofunction:: addmv
.. autofunction:: addr
.. autofunction:: baddbmm
.. autofunction:: bmm
.. autofunction:: chain_matmul
.. autofunction:: cholesky
.. autofunction:: cholesky_inverse
.. autofunction:: cholesky_solve
.. autofunction:: dot
.. autofunction:: eig
.. autofunction:: geqrf
.. autofunction:: ger
.. autofunction:: inverse
.. autofunction:: det
.. autofunction:: logdet
.. autofunction:: slogdet
.. autofunction:: lstsq
.. autofunction:: lu
.. autofunction:: lu_solve
.. autofunction:: lu_unpack
.. autofunction:: matmul
.. autofunction:: matrix_power
.. autofunction:: matrix_rank
.. autofunction:: mm
.. autofunction:: mv
.. autofunction:: orgqr
.. autofunction:: ormqr
.. autofunction:: pinverse
.. autofunction:: qr
.. autofunction:: solve
.. autofunction:: svd
.. autofunction:: symeig
.. autofunction:: trapz
.. autofunction:: triangular_solve
Utilities
----------------------------------
.. autofunction:: compiled_with_cxx11_abi
.. autofunction:: result_type
.. autofunction:: can_cast
.. autofunction:: promote_types