pytorch/torch/csrc/jit/codegen/cuda
Ryan Spring f499ab9cef Implement Tanh Gelu Approximation (#61439)
Summary:
1. Implements https://github.com/pytorch/pytorch/issues/39853
2. Adds approximate string argument to Gelu ('none' or 'tanh')
3. Enables Tanh Gelu approximation
4. Adds double backward support for Gelu
5. Enables Tanh Gelu in NvFuser
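Double backward (item 4) requires the second derivative of Gelu. Let phi and Phi be the standard normal PDF and CDF. For the exact form x * Phi(x), the first derivative is Phi(x) + x * phi(x). The second derivative is phi(x) * (2 - x^2).

The sketch below checks these closed forms against central finite differences in plain Python. It only illustrates the math and is not how PyTorch's autograd formulas are written. The helper names (`phi`, `Phi`, `gelu_grad`, `gelu_grad2`, `central_diff`) are hypothetical.

```python
import math

def phi(x: float) -> float:
    # Standard normal probability density function
    return math.exp(-0.5 * x * x) / math.sqrt(2.0 * math.pi)

def Phi(x: float) -> float:
    # Standard normal cumulative distribution function, via erf
    return 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu(x: float) -> float:
    # Exact Gelu: x * Phi(x)
    return x * Phi(x)

def gelu_grad(x: float) -> float:
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x)
    return Phi(x) + x * phi(x)

def gelu_grad2(x: float) -> float:
    # d/dx [Phi(x) + x * phi(x)] = phi(x) + phi(x) - x^2 * phi(x), using phi'(x) = -x * phi(x)
    return phi(x) * (2.0 - x * x)

def central_diff(f, x: float, h: float = 1e-5) -> float:
    # Second-order accurate numerical derivative
    return (f(x + h) - f(x - h)) / (2.0 * h)

if __name__ == "__main__":
    for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
        print(x, gelu_grad(x), central_diff(gelu, x), gelu_grad2(x), central_diff(gelu_grad, x))
```

At x = 0 the second derivative is 2 * phi(0) = sqrt(2/pi). This is the same constant that appears in the tanh approximation.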

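```
import math
import torch

def gelu(x: torch.Tensor, approximate: str = 'none') -> torch.Tensor:
    if approximate == 'tanh':
        # sqrt(2/pi) = 0.7978845608028654
        return 0.5 * x * (1.0 + torch.tanh(0.7978845608028654 * (x + 0.044715 * torch.pow(x, 3.0))))
    else:
        # Exact form: x * normcdf(x), where normcdf(x) = 0.5 * (1 + erf(x / sqrt(2)))
        return x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))
```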
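To see how closely the tanh approximation tracks the exact form, here is a dependency-free scalar sketch using only Python's `math` module. It illustrates the two formulas above and is not the ATen or NvFuser implementation. The names `gelu_exact`, `gelu_tanh`, and `max_abs_error`, and the grid bounds, are hypothetical choices for illustration.

```python
import math

SQRT_2_OVER_PI = math.sqrt(2.0 / math.pi)  # 0.7978845608028654

def gelu_exact(x: float) -> float:
    # x * Phi(x), with the standard normal CDF Phi written via erf
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Tanh approximation, using the same constants as the reference code
    return 0.5 * x * (1.0 + math.tanh(SQRT_2_OVER_PI * (x + 0.044715 * x ** 3)))

def max_abs_error(lo: float = -6.0, hi: float = 6.0, steps: int = 12001) -> float:
    # Largest |exact - approx| over a uniform grid on [lo, hi]
    worst = 0.0
    for i in range(steps):
        x = lo + (hi - lo) * i / (steps - 1)
        worst = max(worst, abs(gelu_exact(x) - gelu_tanh(x)))
    return worst

if __name__ == "__main__":
    print(f"max |exact - tanh| on [-6, 6]: {max_abs_error():.2e}")
```

On this grid the worst-case gap stays below 1e-3.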

Linking XLA PR - https://github.com/pytorch/xla/pull/3039

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61439

Reviewed By: mikaylagawarecki

Differential Revision: D33744717

Pulled By: jbschlosser

fbshipit-source-id: d64532a562ed53247bb4fa52bb16722634d5c187
(cherry picked from commit 4713dd9cca)
2022-01-28 16:59:09 +00:00
docs
ops Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
runtime Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
scheduler [warning] Fix TORCH_INTERNAL_ASSERT calls missing condition to check 1/x (#71711) 2022-01-25 15:45:21 +00:00
tools
arith.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
arith.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
codegen.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
codegen.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
compute_at_map.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
compute_at_map.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
compute_at.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
compute_at.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
disjoint_set.h
dispatch.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
dispatch.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
evaluator_common.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
evaluator_common.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
executor_kernel_arg.cpp [pytorch][aten][cuda] move CUDAGeneratorImpl.h to ATen/cuda (#70650) 2022-01-10 22:27:04 -08:00
executor_kernel_arg.h [pytorch][aten][cuda] move CUDAGeneratorImpl.h to ATen/cuda (#70650) 2022-01-10 22:27:04 -08:00
executor_launch_params.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
executor_launch_params.h
executor_utils.cpp [pytorch][aten][cuda] move CUDAGeneratorImpl.h to ATen/cuda (#70650) 2022-01-10 22:27:04 -08:00
executor_utils.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
executor.cpp Codegen: TraceType only includes operators being registered (#68691) 2022-01-02 13:09:19 -08:00
executor.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
expr_evaluator.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
expr_evaluator.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
fusion_segmenter.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
fusion_segmenter.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
fusion.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
fusion.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
graph_fuser.cpp Allow single node fusion for nvfuser (#70000) 2021-12-23 17:07:57 -08:00
index_compute.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
index_compute.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
index_reference_replay.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
index_reference_replay.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
instrumentation.cpp Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
instrumentation.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
interface.cpp Allow single node fusion for nvfuser (#70000) 2021-12-23 17:07:57 -08:00
interface.h Allow single node fusion for nvfuser (#70000) 2021-12-23 17:07:57 -08:00
ir_all_nodes.h
ir_base_nodes.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
ir_base_nodes.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
ir_cloner.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_cloner.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_graphviz.cpp
ir_graphviz.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
ir_interface_nodes.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_internal_nodes.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_iostream.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_iostream.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_nodes.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_printer.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
ir_utils.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
ir_utils.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
iter_visitor.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
iter_visitor.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
kernel_cache.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
kernel_cache.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
kernel_expr_evaluator.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
kernel_expr_evaluator.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
kernel_ir_builder.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
kernel_ir_builder.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
kernel_ir_printer.cpp
kernel_ir_printer.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
kernel_ir.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
kernel_ir.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
kernel.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
kernel.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_alias_memory.cpp Fix usages of TORCH_CHECK/_INTERNAL_ASSERT without condition (#71879) 2022-01-27 04:20:55 +00:00
lower_alias_memory.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_allocation.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_allocation.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_expr_sort.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_expr_sort.h
lower_index.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_index.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_insert_syncs.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_insert_syncs.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_loops.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_loops.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_magic_zero.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_magic_zero.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_misaligned_vectorization.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_misaligned_vectorization.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_predicate.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_predicate.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_shift.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_shift.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_thread_predicate.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_thread_predicate.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_trivial_reductions.cpp
lower_trivial_reductions.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_unroll.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_unroll.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
lower_utils.cpp [warning] Fix TORCH_INTERNAL_ASSERT calls missing condition to check 1/x (#71711) 2022-01-25 15:45:21 +00:00
lower_utils.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_validation.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_validation.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower_warp_reduce.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower_warp_reduce.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
lower2device.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
lower2device.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
manager.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
manager.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
mutator.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
mutator.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
non_divisible_split.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
non_divisible_split.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
parallel_dimension_map.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
parallel_dimension_map.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
parallel_type_bitmap.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
parallel_type_bitmap.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
parser.cpp Implement Tanh Gelu Approximation (#61439) 2022-01-28 16:59:09 +00:00
parser.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
partial_split_map.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
partial_split_map.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
partition.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
partition.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
predicate_compute.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
predicate_compute.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
reference_tensor.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
register_interface.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
root_domain_map.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
root_domain_map.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
tensor_view.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
transform_iter.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
transform_iter.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
transform_replay.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
transform_replay.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
transform_rfactor.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
transform_rfactor.h Remove WindowsTorchApiMacro.h in favor of Export.h (#69585) 2021-12-09 17:30:09 -08:00
transform_view.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
transform_view.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
type_inference.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
type_inference.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
type_promotion.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
type_promotion.h Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
type.cpp Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
type.h Nvfuser code bump 12 5 (#69964) 2021-12-16 08:28:54 -08:00
utils.cpp Nvfuser code bump 11 5 (#67943) 2021-11-17 01:22:17 -08:00
utils.h Codegen: TraceType only includes operators being registered (#68691) 2022-01-02 13:09:19 -08:00