pytorch/torch/csrc/jit/codegen/onednn/interface.cpp
sanchitintel 974ad8fa6c Add BFloat16 dtype support for oneDNN Graph JIT fuser (#85591)
## BFloat16 dtype support for faster inference with TorchScript using oneDNN Graph

Intel Xeon Cooper Lake platform & beyond support the `AVX512_BF16` ISA, which is essentially native BFloat16 support.
oneDNN Graph delivers high inference performance with BFloat16 on such machines.

oneDNN Graph can still be used with BFloat16 on older machines that lack the `avx512_bf16` ISA but support the `avx512bw`, `avx512vl` & `avx512dq` ISAs. However, BF16 performance on such machines will be significantly worse (possibly even worse than Float32), as they lack native BF16 support.

Currently, [AMP support for eager mode & JIT mode is divergent in PyTorch](https://github.com/pytorch/pytorch/issues/75956).
So, to use oneDNN Graph with BFloat16, eager-mode AMP should be leveraged, and AMP for JIT mode should be turned off with `torch._C._jit_set_autocast_mode(False)` in Python code, so as to avoid conflicts (see the example below).

Please set the following environment variable to view JIT logs:
`PYTORCH_JIT_LOG_LEVEL=">>graph_helper:>>graph_fuser:>>kernel:>>interface"`
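
For example, a minimal sketch of setting it from Python before any TorchScript compilation is triggered (exporting the variable in the shell before launching Python is the more conventional and most reliable approach):

```python
import os

# A minimal sketch: set the JIT logging level before any TorchScript
# compilation happens. Setting the variable in the shell before launching
# the Python process is the more reliable option.
os.environ["PYTORCH_JIT_LOG_LEVEL"] = ">>graph_helper:>>graph_fuser:>>kernel:>>interface"

import torch
```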

## Changes being made in this PR
1. This PR does NOT change the `oneDNN` commit or any `ideep` source files. Although the `ideep` commit is being updated, only the files pertaining to oneDNN Graph are changed: oneDNN Graph is being upgraded to version 0.5.2 (alpha patch release 2).
To put things into perspective, `ideep` is a git submodule of PyTorch, `oneDNN Graph` is a git submodule of `ideep` (`ideep/mkl-dnn`), and oneDNN is a git submodule of oneDNN Graph (`ideep/mkl-dnn/third_party/oneDNN`).
2. Unit-tests are being updated. We now use the [existing dtypes decorator](https://github.com/pytorch/pytorch/blob/master/torch/testing/_internal/common_device_type.py#L123-L131).
3. Suggestions made by @eellison in the [FP32 PR](https://github.com/pytorch/pytorch/pull/68111#pullrequestreview-896719477) are being incorporated/addressed:

| Action-item | Status |
| :---                                             |          ---: |
|checkInputCompatibility follow up | Fixed |
|the mayConvertScalarInputToTensor logic we can consider | Added type promotion code (an illustrative scalar-input example follows this list) |
|fix up fixConvOptionalBias| The current approach seems correct |
|Use opinfo tests| using dtypes decorator. Will use `OpInfo` in a subsequent PR, if that'd be possible. Should we create a list of ops from opDB that are supported by oneDNN Graph, and add it to `common_methods_invocations.py`? |
|inferDevice torch_check call | Not necessary now, perhaps, as only CPU is supported for now? We'd add it by the beta release of oneDNN Graph, though, so that by then users might be able to use other fusers alongside oneDNN Graph (NNC/TensorExpr are already compatible with the oneDNN Graph fuser). We can still add it now, if you'd insist. |
|not checking shapes of input mkldnn tensor to llga guard | Those checks should not be present because oneDNN Graph may use blocked or channels-last layout, so those strides would be different. They're only skipped if an LLGA subgraph's output is input to another LLGA subgraph, which enables LLGA to choose an optimal layout between them. |
|fix test failures with respect to unsupported inputs | We'll address them with the upcoming release of oneDNN Graph beta version|

4. More PyTorch ops are being mapped to oneDNN Graph.
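
As a hedged illustration of the scalar-input case mentioned in the table above (the function name `scaled_relu` below is purely illustrative, not something added by this PR): when the second operand of a binary op is a Python scalar, the fuser may convert it to a tensor (see the comment above `createLlgaGuardKernel` in the source below), so its dtype needs to be promoted consistently with the tensor operand, e.g. a Python float scalar combined with a BFloat16 tensor.

```python
import torch

# Purely illustrative pattern: the second operands of the binary ops below
# are Python scalars rather than tensors. When such a graph is fused by
# oneDNN Graph, a scalar operand may be converted to a tensor internally,
# so its dtype has to be promoted to match the (possibly BFloat16) tensor.
def scaled_relu(x):
    return torch.relu(x * 0.5 + 3.0)

x = torch.rand(8, 8, dtype=torch.bfloat16)
print(scaled_relu(x).dtype)  # torch.bfloat16 after promotion
```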

## Example of using oneDNN Graph with BFloat16

```python
import torch

# Assuming we already have a model named 'model'
example_input = torch.rand(1, 3, 224, 224)

# enable oneDNN Graph
torch.jit.enable_onednn_fusion(True)
# Disable AMP for JIT
torch._C._jit_set_autocast_mode(False)
with torch.no_grad(), torch.cpu.amp.autocast():
    model = torch.jit.trace(model, (example_input))
    model = torch.jit.freeze(model)
    # 2 warm-ups (2 for tracing/scripting with an example, 3 without an example)
    model(example_input)
    model(example_input)

    # speedup would be observed in subsequent runs.
    model(example_input)
```
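
The warm-up runs matter because this fuser only works with the profiling executor (note the `getProfilingMode()` check in the source below): the initial runs collect the shape profiles that the shape guards rely on, and the fused oneDNN Graph partitions are then compiled and cached against them, which is why the speedup is only observed from subsequent runs onward.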

## TorchBench based Benchmarks
**URL:** https://github.com/sanchitintel/benchmark/tree/onednn_graph_benchmark (instructions are available at that URL)
**Batch-size(s):** TorchBench-default for each model
**Baseline:** PyTorch JIT OFI FP32
**Machine:** Intel(R) Xeon(R) Platinum 8371HC (Cooper Lake)
**Sockets used:** 1
**Number of cores on one socket:** 26
Intel OpenMP & tcmalloc were preloaded.

#### Benchmark results with single thread
| name                                             | Latency of PyTorch JIT OFI FP32 (s) |   Latency of oneDNN Graph BF16 (s) |   % change |
| :---                                             |          ---: |            ---: |       ---: |
| test_eval[alexnet-cpu-jit]                       |      1.063851 |        0.509820 |     -52.1% |
| test_eval[mnasnet1_0-cpu-jit]                    |      0.218435 |        0.107100 |     -51.0% |
| test_eval[mobilenet_v2-cpu-jit]                  |      0.114467 |        0.058359 |     -49.0% |
| test_eval[mobilenet_v3_large-cpu-jit]            |      0.233873 |        0.117614 |     -49.7% |
| test_eval[resnet18-cpu-jit]                      |      0.160584 |        0.075854 |     -52.8% |
| test_eval[resnet50-cpu-jit]                      |      1.652846 |        0.713373 |     -56.8% |
| test_eval[resnext50_32x4d-cpu-jit]               |      0.471174 |        0.209431 |     -55.6% |
|test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.310306 | 0.167090 | -46.2% |
| test_eval[squeezenet1_1-cpu-jit]                 |      0.161247 |        0.045684 |     -71.7% |
| test_eval[timm_efficientnet-cpu-jit]             |      1.643772 |        0.800099 |     -51.3% |
| test_eval[timm_regnet-cpu-jit]                   |      5.732272 |        2.333417 |     -59.3% |
| test_eval[timm_resnest-cpu-jit]                  |      1.366464 |        0.715252 |     -47.7% |
| test_eval[timm_vision_transformer-cpu-jit]       |      0.508521 |        0.271598 |     -46.6% |
| test_eval[timm_vovnet-cpu-jit]                   |      2.756692 |        1.125033 |     -59.2% |
| test_eval[vgg16-cpu-jit]                         |      0.711533 |        0.312344 |     -56.1% |

#### Benchmark results with 26 threads:
| name                                             | Latency of PyTorch JIT OFI FP32 (s) |   Latency of oneDNN Graph BF16 (s) |   % change |
| :---                                             |          ---: |            ---: |       ---: |
| test_eval[alexnet-cpu-jit]                       |      0.062871 |        0.034198 |     -45.6% |
| test_eval[mnasnet1_0-cpu-jit]                    |      0.022490 |        0.008172 |     -63.7% |
| test_eval[mobilenet_v2-cpu-jit]                  |      0.012730 |        0.005866 |     -53.9% |
| test_eval[mobilenet_v3_large-cpu-jit]            |      0.025948 |        0.010346 |     -60.1% |
| test_eval[resnet18-cpu-jit]                      |      0.011194 |        0.005726 |     -48.9% |
| test_eval[resnet50-cpu-jit]                      |      0.124662 |        0.045599 |     -63.4% |
| test_eval[resnext50_32x4d-cpu-jit]               |      0.034737 |        0.015214 |     -56.2% |
|test_eval[shufflenet_v2_x1_0-cpu-jit] | 0.028820 | 0.012517 | -56.6% |
| test_eval[squeezenet1_1-cpu-jit]                 |      0.012557 |        0.003876 |     -69.1% |
| test_eval[timm_efficientnet-cpu-jit]             |      0.203177 |        0.051879 |     -74.5% |
| test_eval[timm_regnet-cpu-jit]                   |      0.452050 |        0.151113 |     -66.6% |
| test_eval[timm_resnest-cpu-jit]                  |      0.117072 |        0.052848 |     -54.9% |
| test_eval[timm_vision_transformer-cpu-jit]       |      0.046048 |        0.023275 |     -49.5% |
| test_eval[timm_vovnet-cpu-jit]                   |      0.213187 |        0.077482 |     -63.7% |
| test_eval[vgg16-cpu-jit]                         |      0.044726 |        0.021998 |     -50.8% |

Pull Request resolved: https://github.com/pytorch/pytorch/pull/85591
Approved by: https://github.com/jgong5, https://github.com/frank-wei, https://github.com/chunyuan-w
2022-10-13 20:36:59 +00:00


#include <oneapi/dnnl/dnnl_graph.hpp>
#include <torch/csrc/jit/codegen/onednn/decompose_silu.h>
#include <torch/csrc/jit/codegen/onednn/defer_size_check.h>
#include <torch/csrc/jit/codegen/onednn/graph_fuser.h>
#include <torch/csrc/jit/codegen/onednn/guard_shape.h>
#include <torch/csrc/jit/codegen/onednn/interface.h>
#include <torch/csrc/jit/codegen/onednn/kernel.h>
#include <torch/csrc/jit/codegen/onednn/layout_propagation.h>
#include <torch/csrc/jit/codegen/onednn/prepare_binary.h>
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/decompose_ops.h>
#include <torch/csrc/jit/passes/pass_manager.h>
#include <torch/csrc/jit/passes/remove_mutation.h>
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/graph_executor.h>
#include <torch/csrc/jit/runtime/operator_options.h>

namespace torch {
namespace jit {
namespace fuser {
namespace onednn {

void fuseGraph(std::shared_ptr<Graph>& g) {
  // Follow the process of the tensorexpr_fuser in profiling mode:
  // Remove prim::profile nodes and embed the profile info directly in the
  // IR in value types to avoid breaking the fusion patterns.
  // Will add shape guard after LLGA optimization passes and
  // wipe the tensor type information from the IR, so that it's not
  // accidentally used by any other pass.
  // We rely on the shape specialization and shape guard to ensure the validity
  // of the cached compilation in the kernel, thus only support profiling mode.
  // TODO: add check on oneDNNFusionGroup to ensure allShapesAreKnown on nodes
  // to fuse: torch/csrc/jit/passes/tensorexpr_fuser.cpp: allShapesAreKnown
  if (getProfilingMode()) {
    GRAPH_DUMP(
        "Before RemoveProfileNodesAndSpecializeTypes. Beginning of LLGA "
        "optimization pass",
        g);
    RemoveProfileNodesAndSpecializeTypes(g);
    GRAPH_DUMP(
        "After RemoveProfileNodesAndSpecializeTypes. Before mutation removal",
        g);
    RemoveTensorMutation(g, [](Node* nodeToFunctionalize) {
      static std::unordered_set<Symbol> supportedOps = {
          aten::add_,
          aten::mul_,
          aten::tanh_,
          aten::elu_,
          aten::relu_,
          aten::relu6_,
          aten::gelu_,
          aten::sqrt_,
          aten::sigmoid_,
          aten::hardtanh_,
          aten::abs_,
          aten::square_,
          aten::pow_,
          aten::leaky_relu_,
          aten::round_,
          aten::exp_,
          aten::abs_,
          aten::hardswish_,
          aten::silu_};
      return supportedOps.count(nodeToFunctionalize->kind()) != 0;
    });
    RemoveListMutation(g);
    GRAPH_DUMP("After mutation removal. Before DecomposeSiluForLlga", g);
    DecomposeSiluForLLGA(g);
    GRAPH_DUMP("After DecomposeSiluForLlga. Before PrepareBinaryForLLGA", g);
    PrepareBinaryForLLGA(g);
    GRAPH_DUMP("After PrepareBinaryForLLGA. Before DeferSizeCheck", g);
    DeferSizeCheck(g);
    GRAPH_DUMP("After DeferSizeCheck. Before CreateLlgaSubgraphs", g);
    CreateLlgaSubgraphs(g);
    GRAPH_DUMP("After CreateLlgaSubgraphs. Before PropagateLayout", g);
    PropagateLayout(g);
    GRAPH_DUMP(
        "After PropagateLayout. Before prepareFusionGroupAndGuardOutputs", g);
    // Add shape guard for profiling mode and wipe the tensor type information
    // from the IR
    prepareFusionGroupAndGuardOutputs(g->block());
    GRAPH_DUMP(
        "After prepareFusionGroupAndGuardOutputs. Before "
        "RemoveTensorTypeSpecializations",
        g);
    RemoveTensorTypeSpecializations(g);
    GRAPH_DUMP(
        "After RemoveTensorTypeSpecializations. End of LLGA optimization pass",
        g);
  }
}

} // namespace onednn
} // namespace fuser

Operation createLlgaKernel(const Node* node) {
  auto kernel = std::make_shared<fuser::onednn::LlgaKernel>(node);
  return [kernel](Stack* stack) {
    RECORD_FUNCTION(kernel->debugName(), std::vector<c10::IValue>());
    kernel->run(*stack);
    return 0;
  };
}

RegisterOperators oneDNNFusionGroupOp({
    torch::jit::Operator(
        prim::oneDNNFusionGroup,
        createLlgaKernel,
        AliasAnalysisKind::INTERNAL_SPECIAL_CASE),
});

// Currently, we convert some scalar inputs, such as the second argument of
// binary ops to a 1D tensor. Other scalar inputs are prim::Constant nodes.
// But if we have any scalar inputs to guard in the future, some logic here
// would have to be changed.
Operation createLlgaGuardKernel(const Node* node) {
  return [node](Stack* stack) {
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("Guarding node: ", node->kind().toQualString());
#endif
    std::vector<TypePtr> types = node->tys(attr::types);
    const auto num_inputs = types.size();
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("num_inputs to guard: ", num_inputs);
#endif
    for (size_t i = 0; i < num_inputs; i++) {
#ifdef GRAPH_DEBUG_ENABLED
      GRAPH_DEBUG("checking input ", i);
#endif
      auto& input = peek(stack, i, num_inputs);
      const c10::TensorTypePtr& guard_tensor_type =
          types[i]->cast<TensorType>();
      if (!input.isTensor()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is not a tensor, return false");
#endif
        push(stack, IValue(false));
        return;
      }
      const at::Tensor& tensor = input.toTensor();
      // If input tensor is of mkldnn, it's originated from an upstream
      // LLGA partition that has passed the check on input shapes.
      // It is valid to continue here as long as the output shapes from
      // oneDNN graph partitions are determined by the input shapes.
      if (tensor.is_mkldnn()) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " is_mkldnn, continue");
#endif
        continue;
      }
      if (!guard_tensor_type->matchTensor(tensor)) {
#ifdef GRAPH_DEBUG_ENABLED
        GRAPH_DEBUG("input ", i, " check failed, return false");
#endif
        push(stack, IValue(false));
        return;
      }
    }
#ifdef GRAPH_DEBUG_ENABLED
    GRAPH_DEBUG("all check done, return true");
#endif
    push(stack, IValue(true));
    return;
  };
}

RegisterOperators oneDNNGuardOp({
    torch::jit::Operator(
        prim::oneDNNFusionGuard,
        createLlgaGuardKernel,
        AliasAnalysisKind::FROM_SCHEMA),
});
} // namespace jit
} // namespace torch