pytorch/c10/util/env.h
Aswin John Mathews 63b180beed ROCm MIOpen NHWC Convolution support (#63617)
Summary:
- Added 2D-Convolution NHWC support
  - on ROCm 4.3, with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` flag
  - May need to force MIOpen to search for solutions (see the examples below for flags)

**PYTORCH_MIOPEN_SUGGEST_NHWC Environment Flag**
MIOpen does not officially support NHWC yet, although convolution support has been added to the tip-of-tree of MIOpen. This flag is intended to be short-lived: it explicitly turns on NHWC support until ROCm officially supports NHWC and performance is verified.

**Examples**
1. Example usage 1: Run the test on ROCm 4.3.
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_MIOPEN_SUGGEST_NHWC=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_FIND_MODE=1 pytest test_nn.py -v -k "test_conv_cudnn_nhwc"`
2. Example usage 2: Run the following with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` on ROCm 4.3.
```
#!/usr/bin/env python3
import torch
model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last)
input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print(input.is_contiguous(memory_format=torch.channels_last), input.shape, input.stride() )

out = model(input)

# should print True for is_contiguous(channels_last), and strides must match NHWC format
print("Contiguous channel last :", out.is_contiguous(memory_format=torch.channels_last), " out shape :",  out.shape, "out stride :", out.stride() )
```

See https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html for more examples.

cc jeffdaily sunway513 jithunnair-amd ROCmSupport

Pull Request resolved: https://github.com/pytorch/pytorch/pull/63617

Reviewed By: saketh-are

Differential Revision: D30730800

Pulled By: ezyang

fbshipit-source-id: 61906a0f30be8299e6547d312ae6ac91cc7c3238
2021-09-10 08:06:32 -07:00


#pragma once
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <cstdlib>
#include <cstring>
namespace c10 {
namespace utils {
// Reads an environment variable and returns
// - optional<true>, if set equal to "1"
// - optional<false>, if set equal to "0"
// - nullopt, otherwise
//
// NB:
// Issues a warning if the value of the environment variable is not 0 or 1.
inline optional<bool> check_env(const char* name) {
  auto envar = std::getenv(name);
  if (envar) {
    if (std::strcmp(envar, "0") == 0) {
      return false;
    }
    if (std::strcmp(envar, "1") == 0) {
      return true;
    }
    TORCH_WARN(
        "Ignoring invalid value for boolean flag ",
        name,
        ": ",
        envar,
        ", valid values are 0 or 1.");
  }
  return c10::nullopt;
}
} // namespace utils
} // namespace c10
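
The tri-state return value (true, false, or unset) is easiest to see from the caller's side. The following is a minimal standalone sketch, not the PyTorch implementation: it assumes C++17 and substitutes `std::optional` for `c10::optional` and `std::cerr` for `TORCH_WARN` so that it builds without the c10 headers. The `sketch` namespace and the `flag_enabled` helper are hypothetical names introduced only for illustration.

```cpp
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <optional>

namespace sketch {

// Stand-in for c10::utils::check_env with the same contract:
// "1" -> true, "0" -> false, anything else (or unset) -> nullopt.
// Assumption: std::cerr replaces TORCH_WARN for the invalid-value warning.
inline std::optional<bool> check_env(const char* name) {
  const char* envar = std::getenv(name);
  if (envar) {
    if (std::strcmp(envar, "0") == 0) {
      return false;
    }
    if (std::strcmp(envar, "1") == 0) {
      return true;
    }
    std::cerr << "Ignoring invalid value for boolean flag " << name << ": "
              << envar << ", valid values are 0 or 1." << std::endl;
  }
  return std::nullopt;
}

// Hypothetical caller: an unset or invalid flag is treated as "off".
inline bool flag_enabled(const char* name) {
  return check_env(name).value_or(false);
}

} // namespace sketch
```

A caller that wants an opt-in flag, such as `PYTORCH_MIOPEN_SUGGEST_NHWC` in the commit message above, could collapse the optional with `value_or(false)` as `flag_enabled` does. A caller that needs to distinguish "explicitly disabled" from "not set" would inspect `has_value()` instead.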