Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-07 12:21:27 +01:00)
Summary:
- Added 2D convolution NHWC support on ROCm 4.3, enabled with the `PYTORCH_MIOPEN_SUGGEST_NHWC=1` flag
- MIOpen may need to be forced to search for solutions (see the examples below for the flags)
**PYTORCH_MIOPEN_SUGGEST_NHWC Environment Flag**
MIOpen does not officially support NHWC yet, although NHWC convolution support has been added to MIOpen's tip-of-tree. This flag is intended to be short-lived: it explicitly turns on NHWC support until ROCm officially supports NHWC and its performance has been verified.
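A minimal sketch of how an opt-in flag like this can gate the NHWC path. It assumes the flag's value arrives as the tri-state result of a `check_env`-style reader (as in the header below) and that one channels-last operand is enough to request NHWC; `suggest_miopen_nhwc` is a hypothetical name, not PyTorch's actual API.

```cpp
#include <optional>

// Hypothetical gate, NOT PyTorch's actual implementation. `flag` is the
// tri-state value of PYTORCH_MIOPEN_SUGGEST_NHWC as a check_env-style reader
// would return it: true for "1", false for "0", nullopt if unset or invalid.
constexpr bool suggest_miopen_nhwc(std::optional<bool> flag,
                                   bool input_channels_last,
                                   bool weight_channels_last) {
  // Opt-in only: an unset or invalid flag behaves like "0".
  const bool opted_in = flag.value_or(false);
  // Assumption: one channels-last operand is enough to request NHWC.
  return opted_in && (input_channels_last || weight_channels_last);
}
```

Treating an unset flag the same as `0` keeps NHWC strictly opt-in, which fits the flag's stated purpose as a temporary switch.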
**Examples**
1. Example usage 1: Run the `test_conv_cudnn_nhwc` tests on ROCm 4.3
`PYTORCH_TEST_WITH_ROCM=1 PYTORCH_MIOPEN_SUGGEST_NHWC=1 MIOPEN_FIND_ENFORCE=4 MIOPEN_DEBUG_CONV_GEMM=0 MIOPEN_FIND_MODE=1 pytest test_nn.py -v -k "test_conv_cudnn_nhwc"`
2. Example usage 2: Run the following with `PYTORCH_MIOPEN_SUGGEST_NHWC=1` on ROCm 4.3.
```python
#!/usr/bin/env python3
import torch

model = torch.nn.Conv2d(8, 4, 3).cuda().half()
model = model.to(memory_format=torch.channels_last)

input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float32, requires_grad=True)
input = input.to(device="cuda", memory_format=torch.channels_last, dtype=torch.float16)
# should print True for is_contiguous(channels_last), and strides must match NHWC format
print(input.is_contiguous(memory_format=torch.channels_last), input.shape, input.stride())

out = model(input)
# should print True for is_contiguous(channels_last), and strides must match NHWC format
print("Contiguous channel last :", out.is_contiguous(memory_format=torch.channels_last), " out shape :", out.shape, "out stride :", out.stride())
```
See https://pytorch.org/tutorials/intermediate/memory_format_tutorial.html for more examples.
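The comments in the example say the strides must match the NHWC format. As a worked illustration (a sketch, not PyTorch code; `nchw_strides` and `nhwc_strides` are hypothetical helper names), the element strides for a tensor of logical shape (N, C, H, W) can be computed like this:

```cpp
#include <array>
#include <cstdint>

// Element strides of a tensor with logical shape (N, C, H, W), listed in the
// same (N, C, H, W) order that tensor.stride() reports.

// Default contiguous (NCHW) layout: W varies fastest.
constexpr std::array<std::int64_t, 4> nchw_strides(std::int64_t /*n*/,
                                                   std::int64_t c,
                                                   std::int64_t h,
                                                   std::int64_t w) {
  return {c * h * w, h * w, w, 1};
}

// Channels-last (NHWC) layout: C varies fastest, so the channel stride is 1.
// The batch size N never appears in either set of strides.
constexpr std::array<std::int64_t, 4> nhwc_strides(std::int64_t /*n*/,
                                                   std::int64_t c,
                                                   std::int64_t h,
                                                   std::int64_t w) {
  return {h * w * c, 1, w * c, c};
}
```

For the example's input of shape (2, 8, 4, 4), channels-last gives strides (128, 1, 32, 8) versus (128, 16, 4, 1) for the default layout; the (2, 4, 2, 2) output of the 3x3 convolution should report (16, 1, 8, 4) if it stays channels-last.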
cc jeffdaily sunway513 jithunnair-amd ROCmSupport
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63617
Reviewed By: saketh-are
Differential Revision: D30730800
Pulled By: ezyang
fbshipit-source-id: 61906a0f30be8299e6547d312ae6ac91cc7c3238
36 lines
835 B
C++
#pragma once

#include <c10/util/Exception.h>
#include <c10/util/Optional.h>
#include <cstdlib>
#include <cstring>

namespace c10 {
namespace utils {
// Reads an environment variable and returns
// - optional<true>, if set equal to "1"
// - optional<false>, if set equal to "0"
// - nullopt, otherwise
//
// NB:
// Issues a warning if the value of the environment variable is not 0 or 1.
inline optional<bool> check_env(const char* name) {
  auto envar = std::getenv(name);
  if (envar) {
    if (strcmp(envar, "0") == 0) {
      return false;
    }
    if (strcmp(envar, "1") == 0) {
      return true;
    }
    TORCH_WARN(
        "Ignoring invalid value for boolean flag ",
        name,
        ": ",
        envar,
        ". Valid values are 0 or 1.");
  }
  return c10::nullopt;
}
} // namespace utils
} // namespace c10
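The header's tri-state contract can be exercised outside PyTorch with a standalone approximation. This is a sketch, assuming C++17 `std::optional` in place of `c10::optional` and a plain `stderr` message in place of `TORCH_WARN`; `check_env_sketch` and `flag_or_default` are hypothetical names.

```cpp
#include <cassert>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <optional>

// Standalone approximation of c10::utils::check_env (not the PyTorch code):
//   "1" -> true, "0" -> false, unset or any other value -> nullopt.
inline std::optional<bool> check_env_sketch(const char* name) {
  assert(name != nullptr && "getenv requires a non-null name");
  const char* envar = std::getenv(name);
  if (envar == nullptr) {
    return std::nullopt;  // variable not set
  }
  if (std::strcmp(envar, "0") == 0) {
    return false;
  }
  if (std::strcmp(envar, "1") == 0) {
    return true;
  }
  // Stand-in for TORCH_WARN: report the invalid value and ignore it.
  std::fprintf(stderr,
               "Ignoring invalid value for boolean flag %s: %s. "
               "Valid values are 0 or 1.\n",
               name, envar);
  return std::nullopt;
}

// Hypothetical helper: apply a caller-chosen default only when the
// variable is unset or invalid, never when it is explicitly "0" or "1".
inline bool flag_or_default(const char* name, bool default_value) {
  return check_env_sketch(name).value_or(default_value);
}
```

Returning `nullopt` rather than `false` for an unset variable lets a caller such as `flag_or_default` distinguish "explicitly disabled" from "not configured". On POSIX systems this can be checked with `setenv`/`unsetenv`: after `setenv("MY_FLAG", "0", 1)`, `flag_or_default("MY_FLAG", true)` returns `false`, while after `unsetenv("MY_FLAG")` it returns `true`.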