pytorch/torch/csrc/jit/codegen/cuda/interface.cpp
jjsjann123 7b419e8513 [NVFuser] Upstream push 1026 (#87779)
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

* codegen improvement:
    i. allow non-root trivial reductions, allow empty/no-op fusion
    ii. fixes vectorization checks and size calculation
    iii. bank conflict handling improvement
    iv. enables transpose scheduler

* misc:
    i. CI tests failure fixes
    ii. cpp tests file clean up
    iii. trivial forwarding support added in codegen runtime
    iv. added factory methods support in codegen

Commits in this PR from the devel branch:

```
7117a7e37ebec372d9e802fdfb8abb7786960f4a patching nvfuser conv cudnn test numerics mismatch (#2048)
65af1a4e7013f070df1ba33701f2d524de79d096 Inserting sync for redundant parallel types is already done at the (#2023)
6ac74d181689c8f135f60bfc1ec139d88941c98c Fix sync map (#2047)
f5bca333355e2c0033523f3402de5b8aac602c00 Bank conflict checker improvements (#2032)
d2ca7e3fd203537946be3f7b435303c60fa7f51e Minor update on cp.async code generation. (#1901)
d36cf61f5570c9c992a748126287c4e7432228e0 Test file cleanup (#2040)
0b8e83f49c2ea9f04a4aad5061c1e7f4268474c6 Allow non-root trivial reductions (#2037)
a2dfe40b27cd3f5c04207596f0a1818fbd5e5439 Fix vectorize size calculation (#2035)
e040676a317fe34ea5875276270c7be88f6eaa56 Use withPredicate to replace setPredicate to maintain Exprs immutable (#2025)
197221b847ad5eb347d7ec1cf2706733aacbf97c removing ci workflow (#2034)
40e2703d00795526e7855860aa00b9ab7160755f Reduction rand like patch (#2031)
bc772661cbdb3b711d8e9854ae9b8b7052e3e4a3 Add utility for checking bank conflict of shared memory (#2029)
ddd1cf7695f3fb172a0e4bcb8e4004573617a037 Add back FusionReductionWithTrivialReduction_CUDA (#2030)
fbd97e5ef15fa0f7573800e6fbb5743463fd9e57 Revert "Cleanup trivial reduction workarounds (#2006)" (#2024)
bca20c1dfb8aa8d881fc7973e7579ce82bc6a894 Cleanup trivial reduction workarounds (#2006)
e4b65850eee1d70084105bb6e1f290651adde23e Trivial forwarding (#1995)
1a0e355b5027ed0df501989194ee8f2be3fdd37a Fix contiguity analysis of predicates to match updated contiguity. (#1991)
a4effa6a5f7066647519dc56e854f4c8a2efd2a7 Enable output allocation cache (#2010)
35440b7953ed8da164a5fb28f87d7fd760ac5e00 Patching bn inference (#2016)
0f9f0b4060dc8ca18dc65779cfd7e0776b6b38e8 Add matmul benchmark (#2007)
45045cd05ea268f510587321dbcc8d7c2977cdab Enable tests previously disabled due to an aliasing bug (#2005)
967aa77d2c8e360c7c01587522eec1c1d377c87e Contiguous indexing for View operations (#1990)
a43cb20f48943595894e345865bc1eabf58a5b48 Make inlining even more modular (#2004)
dc458358c0ac91dfaf4e6655a9b3fc206fc0c897 Test util cleanup (#2003)
3ca21ebe4d213f0070ffdfa4ae5d7f6cb0b8e870 More strict validation (#2000)
a7a7d573310c4707a9f381831d3114210461af01 Fix build problem (#1999)
fc235b064e27921fa9d6dbb9dc7055e5bae1c222 Just fixes comments (#1998)
482386c0509fee6edb2964c5ae72074791f3e43a cleanup (#1997)
4cbe0db6558a82c3097d281eec9c85ad2ea0893a Improve divisible split detection (#1970)
42ccc52bdc18bab0330f4b93ed1399164e2980c9 Minor build fix. (#1996)
fcf8c091f72d46f3055975a35afd06263324ede6 Cleanup of lower_utils.cpp: Isolate out GpuLower usage (#1989)
15f2f6dba8cbf408ec93c344767c1862c30f7ecc Move ConcretizedBroadcastDomains to shared_ptr in GpuLower. (#1988)
8f1c7f52679a3ad6acfd419d28a2f4be4a7d89e2 Minor cleanup lower_unroll.cpp (#1994)
1d9858c80319ca7f0037db7de5f04e47f540d76c Minor cleanup (#1992)
f262d9cab59f41c669f53799c6d4a6b9fc4267eb Add support for uniform RNG (#1986)
eb1dad10c73f855eb1ecb20a8b1f7b6edb0c9ea3 Remove non-const functions, remove GpuLower instance on build, pass in ca_map. (#1987)
634820c5e3586c0fe44132c51179b3155be18072 Add support for some empty fusion (#1981)
eabe8d844ad765ee4973faa4821d451ef71b83c3 Segment self mapping fusions (#1954)
e96aacfd9cf9b3c6d08f120282762489bdf540c8 Enable Transpose operation (#1882)
425dce2777420248e9f08893765b5402644f4161 Add a null scheduler that helps segmenting away no-op schedules (#1835)
306d4a68f127dd1b854b749855e48ba23444ba60 Fix canScheduleCompileTime check of transpose scheduler (#1969)
b1bd32cc1b2ae7bbd44701477bddbcfa6642a9be Minor fix (#1967)
bd93578143c1763c1e00ba613a017f8130a6b989 Enable transpose scheduler (#1927)
b7a206e93b4ac823c791c87f12859cf7af264a4c Move scheduler vectorize utilities into their own file (#1959)
d9420e4ca090489bf210e68e9912bb059b895baf View scheduling (#1928)
c668e13aea0cf21d40f95b48e0163b812712cdf2 Upstream push ci fixes (#1965)
c40202bb40ce955955bb97b12762ef3b6b612997 Fix dump effective bandwidth (#1962)
93505bcbb90a7849bd67090fe5708d867e8909e4 WAR on index mapping when exact and permissive maps differ (#1960)
45e95fd1d3c773ee9b2a21d79624c279d269da9f Allow splitting inner-most ID to create virtual innermost ID in transpose scheduler (#1930)
a3ecb339442131f87842eb56955e4f17c544e99f Improve the comments at the beginning of index_compute.h (#1946)
f7bc3417cc2923a635042cc6cc361b2f344248d6 Remove unused variables (#1955)
df3393adbb5cb0309d091f358cfa98706bd4d313 Some cleanup (#1957)
7d1d7c8724ab5a226fad0f5a80feeac04975a496 TVDomainGuard factory (#1953)
357ba224c0fb41ed3e4e8594d95599c973f4a0ca Fill allocation with nan on tests (#1956)
8eafc54685d406f5ac527bcbacc475fda4492d7a Fix detection of unmappable root domains (#1952)
90a51f282601ba8ebd4c84b9334efd7762a234bc Some indexing cleanups, Add eye support (#1940)
ddc01e4e16428aec92f9c84d698f959b6436a971 Exclude unsupported data types (#1951)
992e17c0688fe690c51b50e81a75803621b7e6aa test the groups the same order as they are merged (#1949)
208262b75d1fed0597a0329d61d57bc8bcd7ff14 Move detection of self mapping IDs to IterDomainGraph from (#1941)
ac4de38c6ee53b366e85fdfe408c3642d32b57df Merge pull request #1945 from csarofeen/master_merge_0828
631094891a96f715d8c9925fb73d41013ca7f2e3 Add full, full_like, zeros, zeros_like, ones, ones_like (#1943)
aab10bce4541204c46b91ff0f0ed9878aec1bfc4 Merge remote-tracking branch 'upstream/viable/strict' into HEAD
4c254c063bb55887b45677e3812357556a7aa80d Fix arange when step is negative (#1942)
89330aa23aa804340b2406ab58899d816e3dc3d2 Tensor factories must set the output shape as its input (#1939)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D40869846](https://our.internmc.facebook.com/intern/diff/D40869846)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/87779
Approved by: https://github.com/davidberard98
2022-11-04 20:04:34 +00:00

930 lines
32 KiB
C++

#include <torch/csrc/jit/codegen/cuda/interface.h>
#include <ATen/core/dispatch/OperatorOptions.h>
#include <ATen/native/NonSymbolicBC.h>
#include <ATen/native/TensorShape.h>
#include <c10/util/CallOnce.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/codegen/cuda/transform_view.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
// Global flag: enable single-node fusion for nvfuser (default: off).
// Read and written through getSingletonFusion() / setSingletonFusion().
// NOLINTNEXTLINE
C10_DEFINE_bool(
    torch_jit_nvfuser_singleton_fusion,
    false,
    "enable single node fusion for nvfuser");
// Global flag: enable horizontal fusion for nvfuser (default: on).
// Read and written through getHorizontalFusion() / setHorizontalFusion().
// NOLINTNEXTLINE
C10_DEFINE_bool(
    torch_jit_nvfuser_horizontal_fusion,
    true,
    "enable horizontal fusion for nvfuser");
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// When false, the prim::CudaFusion* guard operators registered in this file
// skip their checks and unconditionally push `true`. Exposed (by reference)
// through getCudaFusionGuardMode().
static std::atomic<bool> cuda_fusion_guard_mode{true};
// There are 3 sources of information on whether to enable nvfuser:
// 1. assigned value from setEnabled() - takes precedence if it has been set
// 2. value from environment variable - only used if setEnabled() is unset
// 3. default value - used if both 1 and 2 are unset.
//
// If 1 or 2 tries to enable nvfuser when it cannot be enabled (e.g. cuda not
// available), then an error will be thrown. The default will not error.
// Decides whether nvfuser is enabled. Precedence, highest first:
//   0. PYTORCH_JIT_USE_NNC_NOT_NVFUSER="1"/"ON" force-disables nvfuser;
//   1. a value assigned through setEnabled();
//   2. the PYTORCH_JIT_ENABLE_NVFUSER environment variable;
//   3. a build-dependent default.
// Both public member functions take `mutex_` before touching state.
class NVFuserEnabler {
 private:
  // Value assigned through setEnabled(); nullopt until the first call.
  c10::optional<bool> runtime_assigned_fuser_enabled_ = c10::nullopt;
  // Ensures the environment-variable validation in isEnabledImpl() runs once.
  c10::once_flag enabled_check_flag_;
  std::mutex mutex_;

 public:
  // True only when CUDA is available, the nvfuser pass manager has been
  // registered, and getExecutorMode() is true.
  static bool nvfuserCanBeEnabled() {
    return at::globalContext().hasCUDA() &&
        NVFuserPassManager::isRegistered() && getExecutorMode();
  }

 private:
  // Throws if asked to enable nvfuser when it cannot be enabled. Requests to
  // disable are always accepted.
  static void assertFuserCanBeEnabled(bool is_enabled) {
    if (!is_enabled) {
      return;
    }
    TORCH_CHECK(
        nvfuserCanBeEnabled(),
        "Running CUDA fuser is only supported on CUDA builds.");
  }

  // Parses PYTORCH_JIT_ENABLE_NVFUSER: unset -> nullopt, "0" or "OFF" ->
  // false, any other value -> true.
  static c10::optional<bool> getFuserEnabledEnvVar() {
    static const char* enable_c_str = std::getenv("PYTORCH_JIT_ENABLE_NVFUSER");
    if (!enable_c_str) {
      return c10::nullopt;
    }
    std::string enable(enable_c_str);
    if (enable == "0" || enable == "OFF") {
      return false;
    }
    return true;
  }

  // The environment variable is parsed once and the result cached in a
  // function-local static.
  static c10::optional<bool> getCachedFuserEnabledEnvVar() {
    static c10::optional<bool> default_enabled = getFuserEnabledEnvVar();
    return default_enabled;
  }

  // Parses PYTORCH_JIT_USE_NNC_NOT_NVFUSER: only "1" or "ON" yield true.
  static bool getNNCNotNVFuser() {
    static const char* env_c_str =
        std::getenv("PYTORCH_JIT_USE_NNC_NOT_NVFUSER");
    if (!env_c_str) {
      return false;
    }
    std::string env(env_c_str);
    if (env == "1" || env == "ON") {
      return true;
    }
    return false;
  }

  // Cached once per process, like getCachedFuserEnabledEnvVar().
  static bool getCachedNNCNotNVFuser() {
    static bool force_disable = getNNCNotNVFuser();
    return force_disable;
  }

  // Unlocked implementation; both callers (setEnabled / isEnabled) hold
  // mutex_ while calling it.
  bool isEnabledImpl() {
    // 0. opportunity to force disable NVFuser
    if (getCachedNNCNotNVFuser()) {
      return false;
    }
    c10::call_once(enabled_check_flag_, [&]() {
      // If the environment variable is setting the value (and setEnabled()
      // has not assigned one yet), we must verify once that nvfuser can
      // actually be enabled; an impossible enable request throws here.
      if (!runtime_assigned_fuser_enabled_.has_value() &&
          getCachedFuserEnabledEnvVar().has_value()) {
        assertFuserCanBeEnabled(*getCachedFuserEnabledEnvVar());
      }
    });
    // 1. if user has explicitly assigned fuser value, that value takes
    // precedence.
    if (runtime_assigned_fuser_enabled_.has_value()) {
      return *runtime_assigned_fuser_enabled_;
    }
    // 2. next precedence is any value assigned by the
    // PYTORCH_JIT_ENABLE_NVFUSER environment variable.
    if (getCachedFuserEnabledEnvVar().has_value()) {
      return *getCachedFuserEnabledEnvVar();
    }
    // 3. default value: off for ROCm and fbcode builds, otherwise on whenever
    // nvfuser can be enabled.
#if defined(USE_ROCM) || defined(FBCODE_CAFFE2)
    return false;
#else
    return nvfuserCanBeEnabled();
#endif
  }

 public:
  // Assigns the runtime value and returns the setting that was in effect
  // before the call. Throws if `is_enabled` is true but nvfuser cannot be
  // enabled.
  bool setEnabled(bool is_enabled) {
    std::lock_guard<std::mutex> lock(mutex_);
    assertFuserCanBeEnabled(is_enabled);
    bool old_value = isEnabledImpl();
    runtime_assigned_fuser_enabled_ = is_enabled;
    return old_value;
  }

  bool isEnabled() {
    std::lock_guard<std::mutex> lock(mutex_);
    return isEnabledImpl();
  }
};
static NVFuserEnabler nvfuser_enabler;
bool isEnabled() {
return nvfuser_enabler.isEnabled();
}
bool setEnabled(bool is_enabled) {
return nvfuser_enabler.setEnabled(is_enabled);
}
bool canBeEnabled() {
return nvfuser_enabler.nvfuserCanBeEnabled();
}
//! Current value of the single-node-fusion flag.
bool getSingletonFusion() {
  const bool enabled = FLAGS_torch_jit_nvfuser_singleton_fusion;
  return enabled;
}

//! Overwrites the single-node-fusion flag; returns the value it replaced.
bool setSingletonFusion(bool value) {
  const bool previous = FLAGS_torch_jit_nvfuser_singleton_fusion;
  FLAGS_torch_jit_nvfuser_singleton_fusion = value;
  return previous;
}

//! Current value of the horizontal-fusion flag.
bool getHorizontalFusion() {
  const bool enabled = FLAGS_torch_jit_nvfuser_horizontal_fusion;
  return enabled;
}

//! Overwrites the horizontal-fusion flag; returns the value it replaced.
bool setHorizontalFusion(bool value) {
  const bool previous = FLAGS_torch_jit_nvfuser_horizontal_fusion;
  FLAGS_torch_jit_nvfuser_horizontal_fusion = value;
  return previous;
}
//! Mutable reference to the global guard-mode toggle consulted by the
//! prim::CudaFusion* guard operators.
std::atomic<bool>& getCudaFusionGuardMode() {
  std::atomic<bool>& mode = cuda_fusion_guard_mode;
  return mode;
}

//! Returns the process-wide CudaFuserInterface. It is a function-local
//! static, constructed on first use and shared by every caller.
CudaFuserInterface* getFuserInterface() {
  static CudaFuserInterface instance;
  return &instance;
}
void compileFusionGroup(Node* fusion_node) {
TORCH_CHECK(
getFuserInterface()->fn_compile_n != nullptr,
"Running the CUDA fuser requires a CUDA build.");
getFuserInterface()->fn_compile_n(fusion_node);
}
void runFusionGroup(const Node* fusion_node, Stack& stack) {
TORCH_CHECK(
getFuserInterface()->fn_run_n_s != nullptr,
"Running the CUDA fuser requires a CUDA build.");
getFuserInterface()->fn_run_n_s(fusion_node, stack);
}
void fuseGraph(std::shared_ptr<Graph>& graph) {
if (!isEnabled()) {
return;
}
TORCH_CHECK(
getFuserInterface()->fn_fuse_graph != nullptr,
"Running the CUDA fuser requires a CUDA build.");
getFuserInterface()->fn_fuse_graph(graph);
}
bool canFuseNode(const Node* node) {
return getFuserInterface()->fn_can_fuse_n != nullptr &&
getFuserInterface()->fn_can_fuse_n(node);
}
void InsertProfileNodesForCUDAFuser(ProfilingRecord* pr) {
if (getFuserInterface()->fn_insert_profile_inodes) {
getFuserInterface()->fn_insert_profile_inodes(pr);
}
}
bool profileNode(const Node* node) {
return getFuserInterface()->fn_profile_n != nullptr &&
getFuserInterface()->fn_profile_n(node);
}
bool skipNode(const std::string& symbol_str, bool flip) {
return getFuserInterface()->fn_skip_n != nullptr &&
getFuserInterface()->fn_skip_n(symbol_str, flip);
}
//! Computes the view constraint for reshaping `original_sizes` into
//! `new_sizes` via the registered backend; internal-asserts when no backend
//! is registered.
AnalyzeViewConstraint getViewConstraint(
    const std::vector<int64_t>& original_sizes,
    const std::vector<int64_t>& new_sizes) {
  CudaFuserInterface* const iface = getFuserInterface();
  if (iface->fn_analyze_view == nullptr) {
    TORCH_INTERNAL_ASSERT(false, "Requires nvFuser which requires CUDA build.");
  }
  return iface->fn_analyze_view(original_sizes, new_sizes);
}
//! [ Note -- type guard logic in CudaFusionGuard ]
//!
//! CudaFusionGuard is used to Guard input tensor to `CudaFusionGroup` so that
//! we would not feed inputs that violate the graph defined in `GraphCache`.
//!
//! see [ Note -- 2 level cache implementation ] for definition of unique
//! computational graph.
//! see [ Note -- CudaFusionGuard implementation] for details on how guard works
//! in profiling executor
//!
//! Type guard logic is used to query whether a runtime input `tensor` complies
//! with profiled `guard_tensor_type`. `guard_tensor_type` is the observed
//! tensor type during profiling runs.
//!
//! At this moment, we only do single profiling run, so `guard_tensor_type` has
//! static shape / stride / scalarType. *This might be a little confusing as our
//! implementation is actually more relaxed.
//!
//! Things that we check:
//!   a. identical rank & scalar type; when the profiled type records them,
//!      also identical device and identical requires_grad (taking the current
//!      grad mode into account)
//!   b. stride check:
//!        b.1. identical stride order
//!        b.2. identical contiguity
//!             note that contiguity here is used for tensor collapsing. So
//!             extra attention should be paid to contiguity across size-1
//!             dimensions.
//!   c. size check:
//!        c.1 broadcast check:
//!            making sure that broadcast semantics are identical. So we want
//!            to make sure a given dimension either are both size-1 for
//!            `tensor` & `guard_tensor_type`, or are both non-size-1.
//!            This is due to the fact that we specialize size-1 dimension as
//!            broadcasted dimension while translating PyTorch tensor to
//!            Fusion IR.
//!        c.2 size-0 check:
//!            we don't specialize this on codegen, but we do specialize
//!            fusion logic for size-0 on reductions, hence the check
//!
//! Returns true iff `tensor` passes every check above. Internal-asserts if
//! `guard_tensor_type` is null or has unknown rank.
//! NOTE(review): `stride_properties[j]` is dereferenced with `->` without a
//! has_value() check, so this assumes the profiled type carries a stride entry
//! for every dimension — confirm against the profiling pass.
bool complyWith(
    const at::Tensor& tensor,
    const c10::TensorTypePtr& guard_tensor_type) {
  // guard broadcast semantics, contiguity & stride order;
  TORCH_INTERNAL_ASSERT(
      guard_tensor_type && guard_tensor_type->dim().has_value());
  // check a. if num_dimension check fails or scalar type check fails
  if (*guard_tensor_type->dim() != static_cast<size_t>(tensor.ndimension()) ||
      (guard_tensor_type->scalarType().has_value() &&
       (guard_tensor_type->scalarType().value() != tensor.scalar_type())) ||
      (guard_tensor_type->device().has_value() &&
       (guard_tensor_type->device().value() != tensor.device())) ||
      (guard_tensor_type->requiresGrad().has_value() &&
       guard_tensor_type->requiresGrad().value() !=
           (tensor.requires_grad() && at::GradMode::is_enabled()))) {
    return false;
  }
  // TODO: should we get symbolic_size instead and check for size
  // consistency across tensors as well?
  const auto& sizes = guard_tensor_type->sizes();
  // see [ Note -- stirde_properties in tensor type ]
  const auto& stride_properties = guard_tensor_type->stride_properties();
  const auto& t_sizes = tensor.sizes();
  const auto& t_strides = tensor.strides();
  // Semantic index of the reference "inner" dimension for checks b.1/b.2:
  // the first stride-indexed dimension visited, later replaced by each
  // stride-indexed dimension whose runtime size is not 1. -1 until set.
  int inner_dim = -1;
  for (const auto j : c10::irange(*guard_tensor_type->dim())) {
    // check b. for stride check, we go along dimensions from fastest stride to
    // slowest stride
    // `sorted_index` is the semantic dimension that is j-th in stride order,
    // or -1 when the profiled type has no stride index for this slot.
    int sorted_index = stride_properties[j]->stride_index_
        ? static_cast<int>(*stride_properties[j]->stride_index_)
        : -1;
    // only apply stride check when we have stride_properties
    if (sorted_index != -1) {
      // check b.1. stride order [current dimension has stride larger
      // than its inner dimension(s)], check only applies when both:
      //     i. already encountered an inner dimension
      //    ii. not at the fastest dimension
      if (j != 0 && inner_dim != -1) {
        // we are not looking at dim-j, but dim-sorted_index, which
        // is the j-th fastest dim;
        // Note: we ignore 0-stride dimension, since eager logic on stride
        // indices is ambiguous
        if (t_strides[sorted_index] != 0 && t_strides[inner_dim] != 0 &&
            t_strides[sorted_index] < t_strides[inner_dim]) {
          return false;
        }
      }
      // check b.2. contiguity, we only check when it's marked as
      // contiguous.
      if (stride_properties[j]->contiguous_ &&
          *stride_properties[j]->contiguous_) {
        if (j != 0) {
          // we use contiguity to collapse dimension, if size == 1, it is
          // always collapsible
          // computeStrideProps also default to contiguous when stride == 1
          if (t_sizes[sorted_index] != 1 && t_strides[sorted_index] != 1) {
            // If slot j - 1 has a stride index, inner_dim was set on the
            // previous iteration, so it is a valid index below.
            TORCH_INTERNAL_ASSERT(
                stride_properties[j - 1]->stride_index_.has_value(),
                "Counknown index is meaningless");
            // TODO: merge this check up
            // Contiguous means this stride equals the inner dimension's
            // stride times its size.
            if (t_strides[sorted_index] !=
                t_strides[inner_dim] * t_sizes[inner_dim]) {
              return false;
            }
          }
        } else {
          // TODO: merge this check up
          // The fastest dimension marked contiguous must have unit stride.
          if (t_strides[sorted_index] != 1) {
            return false;
          }
        }
      }
      // update inner_dim to be current dim. Note that we try to skip update
      // when current `t_size[sorted_index] == 1`, because:
      // 1. stride comparison on a size-1 dimension is meaningless
      //    [check b.1]
      // 2. contiguity on a size-1 dimension is misleading. For collapsing,
      //    we should actually look at the next non-size-1 dimension
      //    [check b.2]
      if (inner_dim == -1 || t_sizes[sorted_index] != 1) {
        inner_dim = sorted_index;
      }
    }
    // check c.1, we go along semantic ordered dimensions
    // check broadcast / size-1:
    bool guard_bcast = sizes[j].has_value() && sizes[j].value() == 1;
    if (guard_bcast != (t_sizes[j] == 1)) {
      return false;
    }
    // check c.2, check for size-0
    bool guard_size_0 = sizes[j].has_value() && sizes[j].value() == 0;
    if (guard_size_0 != (t_sizes[j] == 0)) {
      return false;
    }
  }
  return true;
}
} // namespace cuda
} // namespace fuser
namespace {
// prim::CudaFusionSizeEq(size, ref) -> bool
//
// Pops two inputs and pushes whether `size` has the same broadcast pattern as
// the reference `ref`: equal length, with size-1 extents in exactly the same
// positions. An empty `ref` matches only a None `size`, and a non-list `size`
// never matches a non-empty `ref`. Pushes true unconditionally when the CUDA
// fusion guard mode is off.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators size_eq_guard({
    Operator(
        //"prim::CudaFusionSizeEq(int[] size, int[] ref) -> bool",
        "prim::CudaFusionSizeEq(...) -> bool",
        // prim::CudaFusionSizeEq returns a fresh Boolean type without
        // aliasing. if we would ever return refined tensor, which would change
        // aliasing analysis, we should update aliasdb pass.
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            bool ret = true;
            if (fuser::cuda::getCudaFusionGuardMode()) {
              // `inputs` is a view into the stack, and drop() destroys the
              // IValues it refers to, so every read happens before drop().
              at::ArrayRef<IValue> inputs = last(stack, 2);
              TORCH_INTERNAL_ASSERT(
                  inputs[1].isIntList(), "reference needs to be of int list");
              auto ref = inputs[1].toIntList();
              if (ref.empty()) {
                ret = inputs[0].isNone();
              } else if (!inputs[0].isIntList()) {
                ret = false;
              } else {
                auto inp = inputs[0].toIntList();
                if (inp.size() != ref.size()) {
                  ret = false;
                } else {
                  for (const auto i : c10::irange(inp.size())) {
                    if ((inp[i] == 1) != (ref[i] == 1)) {
                      ret = false;
                      break;
                    }
                  }
                }
              }
            }
            drop(stack, 2);
            push(stack, IValue(ret));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::CudaFusionGroup: each execution of a fusion-group node is handed to
// the registered CUDA fuser runtime via fuser::cuda::runFusionGroup.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_fusion({
    Operator(
        prim::CudaFusionGroup,
        [](const Node* node) -> Operation {
          const Node* const fusion_node = node;
          return [fusion_node](Stack& stack) {
            fuser::cuda::runFusionGroup(fusion_node, stack);
          };
        },
        aliasAnalysisSpecialCase()),
});
// prim::CudaFusionGuard(...) -> bool
//
// Pops one tensor per entry of the node's `types` attribute and pushes true
// iff every tensor complies with its profiled TensorType (see
// fuser::cuda::complyWith). Pushes true unconditionally when the CUDA fusion
// guard mode is off.
RegisterOperators reg_guard({
    Operator(
        "prim::CudaFusionGuard(...) -> bool",
        // prim::CudaFusionGuard returns a fresh Boolean type without aliasing.
        // if we would ever return refined tensor, which would change aliasing
        // analysis, we should update aliasdb pass.
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            // Bind by reference: copying this vector of type pointers on every
            // guard evaluation would bump a refcount per input.
            const auto& types = node->tys(attr::types);
            const auto num_inputs = types.size();
            bool guard_passed = true;
            if (fuser::cuda::getCudaFusionGuardMode()) {
              // `inputs` is a view into the stack, and drop() destroys the
              // IValues it refers to, so all checks run before drop().
              at::ArrayRef<IValue> inputs = last(stack, num_inputs);
              for (const auto i : c10::irange(num_inputs)) {
                const c10::TensorTypePtr& guard_tensor_type =
                    types[i]->cast<TensorType>();
                // TODO: maybe we should just push false and fallback
                TORCH_INTERNAL_ASSERT(inputs[i].isTensor());
                const at::Tensor& tensor = inputs[i].toTensor();
                if (!fuser::cuda::complyWith(tensor, guard_tensor_type)) {
                  guard_passed = false;
                  break;
                }
              }
            }
            drop(stack, num_inputs);
            push(stack, IValue(guard_passed));
          };
        },
        aliasAnalysisFromSchema()),
});
// Infers the dynamic axis (-1) in `view_sizes` from `tensor_sizes`.
//
// Returns false when the elements described by `tensor_sizes` cannot be
// viewed as `view_sizes`. This covers two cases: the element count is not
// divisible by the known view extents, or there is no -1 and the counts
// differ. On success with a -1 present, the inferred extent is written back
// into `view_sizes`. c10::List has reference semantics, so the caller's list
// observes the update. Internal-asserts if more than one -1 is present or if
// any other view extent is not positive.
bool inferViewShape(
    c10::List<int64_t> tensor_sizes,
    c10::List<int64_t> view_sizes) {
  int64_t dynamic_index = -1;
  int64_t view_size_num_elements = 1;
  for (const auto idx : c10::irange(view_sizes.size())) {
    if (view_sizes[idx] == -1) {
      TORCH_INTERNAL_ASSERT(
          dynamic_index == -1, "Only one dimension can by inferred.");
      dynamic_index = static_cast<int64_t>(idx);
    } else {
      TORCH_INTERNAL_ASSERT(view_sizes[idx] > 0);
      view_size_num_elements *= view_sizes[idx];
    }
  }
  // Seed with an int64_t: std::accumulate accumulates in the type of its
  // initial value, and an `int` seed overflows for large tensors.
  const int64_t kNumElements = std::accumulate(
      tensor_sizes.begin(),
      tensor_sizes.end(),
      static_cast<int64_t>(1),
      std::multiplies<>());
  if (kNumElements % view_size_num_elements != 0) {
    return false;
  }
  if (dynamic_index == -1) {
    // Nothing to infer: the view must cover exactly the same elements.
    return kNumElements == view_size_num_elements;
  }
  view_sizes[dynamic_index] = kNumElements / view_size_num_elements;
  return true;
}
//!
//! CudaFusionViewGuard Example Graph:
//!
//! graph(%self : __torch__.BiasViewRelu,
//! %inputs.1 : Tensor):
//! %2 : int = prim::Constant[value=-1]() # dynamic_bvg.py:50:40
//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25
//! %4 : NoneType = prim::Constant()
//! %5 : int[] = prim::Constant[value=[2, 3]]()
//! %6 : int[] = aten::size(%inputs.1) # dynamic_bvg.py:50:25
//! %7 : int[] = aten::slice(%6, %4, %2, %3) # dynamic_bvg.py:50:25
//! %view_shape.1 : int[] = aten::add(%7, %5) # dynamic_bvg.py:50:25
//! %bias : Tensor = prim::GetAttr[name="bias"](%self)
//! %10 : int[] = aten::size(%bias)
//! %11 : int[] = prim::BroadcastSizes(%6, %10)
//! %12 : bool = prim::CudaFusionGuard[types=[...]](%inputs.1, %bias)
//! %13 : int[] = prim::Constant[value=[-1, -1, -1, 6]]()
//! %14 : int[] = prim::Constant[value=[-1, -1, -1, 2, 3]]()
//! %15 : bool = prim::CudaFusionViewGuard(%11, %view_shape.1, %13, %14)
//! %16 : bool[] = prim::ListConstruct(%15, %12)
//! %17 : bool = aten::all(%16)
//! %18 : Tensor = prim::If(%17)
//! block0():
//! %19 : Tensor = prim::CudaFusionGroup_0[cache_id=0](%inputs.1, %bias)
//! -> (%19)
//! block1():
//! %20 : Function = prim::Constant[name="fallback_fn", fallback=1]()
//! %21 : (...) = prim::CallFunction(%20, %inputs.1, %bias, %view_shape.1)
//! %22 : Float(...) = prim::TupleUnpack(%21)
//! -> (%22)
//! return (%18)
//! with prim::CudaFusionGroup_0 = graph(%0 : Float(...),
//! %1 : Float(...)):
//! %2 : int[] = prim::Constant[value=[2, 3, 4, 2, 3]]()
//! %3 : int = prim::Constant[value=1]() # dynamic_bvg.py:50:25
//! %o.1 : Float(...) = aten::add(%0, %1, %3) # dynamic_bvg.py:51:16
//! %5 : Float(...) = prim::view_copy(%o.1, %2)
//! %6 : Float(...) = aten::relu(%5) # dynamic_bvg.py:53:19
//! return (%6)
//!
// prim::CudaFusionViewGuard(...) -> bool
//
// Consumes exactly three inputs from the stack:
//   inputs[0] tensor_sizes        - runtime sizes of the tensor being viewed
//   inputs[1] profiled_view_sizes - requested view sizes (may contain one -1)
//   inputs[2] tensor_constraints  - constant List[Int] guarding tensor_sizes
// Pushes false if the view cannot be inferred. Otherwise pushes true when the
// CUDA fusion guard mode is off, or whether the view constraint recomputed
// for (tensor_sizes -> inferred view sizes) equals `tensor_constraints`.
RegisterOperators view_guard({
    Operator(
        "prim::CudaFusionViewGuard(...) -> bool",
        // prim::CudaFusionViewGuard returns a fresh Boolean type without
        // aliasing. if we would ever return refined tensor, which would change
        // aliasing analysis, we should update aliasdb pass.
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            at::ArrayRef<IValue> inputs = last(stack, 3);
            // tensor_sizes is the runtime size for the self tensor
            // tensor_sizes - dynamic size List[Int]
            TORCH_INTERNAL_ASSERT(
                inputs[0].isIntList(), "tensor_sizes needs to be Int List");
            auto tensor_sizes = inputs[0].toIntList();
            // profiled_view_sizes is the runtime view size
            // profiled_view_sizes - profile_ivalue List[Int]
            TORCH_INTERNAL_ASSERT(
                inputs[1].isIntList(),
                "profiled_view_sizes needs to be Int list");
            auto profiled_view_sizes = inputs[1].toIntList();
            // tensor_constraints is a constant List[Int]
            // used to guard tensor_sizes
            TORCH_INTERNAL_ASSERT(
                inputs[2].isIntList(),
                "tensor constraint needs to be Int List");
            auto tensor_constraints = inputs[2].toIntList();
            // Drop after gather all input arguments
            // If an argument is moved, it is destroyed when dropped from stack
            drop(stack, 3);
            // On success this fills the -1 entry of profiled_view_sizes in
            // place (the c10::List argument shares storage with this one).
            auto status = inferViewShape(tensor_sizes, profiled_view_sizes);
            if (!status) {
              push(stack, IValue(false));
              return;
            }
            if (!fuser::cuda::getCudaFusionGuardMode()) {
              push(stack, IValue(true));
              return;
            }
            std::vector<int64_t> tensor_sizes_int_vec = tensor_sizes.vec();
            // The constraint describes tensor_sizes -> view sizes, so the
            // second operand must be the (inferred) view sizes, not
            // tensor_sizes again.
            std::vector<int64_t> view_sizes_int_vec =
                profiled_view_sizes.vec();
            std::vector<int64_t> previous_constraints =
                tensor_constraints.vec();
            auto new_constraints = fuser::cuda::getViewConstraint(
                tensor_sizes_int_vec, view_sizes_int_vec);
            bool guard_status =
                (new_constraints.conglomerateString() == previous_constraints);
            push(stack, IValue(guard_status));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::CudaFusionIvalGuard(a, b) -> bool
//
// Pops two IValues and pushes the result of `a.equals(b)`. Pushes true
// unconditionally when the CUDA fusion guard mode is off.
RegisterOperators ivalue_guard({
    Operator(
        "prim::CudaFusionIvalGuard(...) -> bool",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            IValue result(true);
            if (fuser::cuda::getCudaFusionGuardMode()) {
              // Compare while the operands are still alive: `inputs` is a
              // view into the stack, and drop() destroys those IValues.
              at::ArrayRef<IValue> inputs = last(stack, 2);
              result = inputs[0].equals(inputs[1]);
            }
            drop(stack, 2);
            push(stack, std::move(result));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::add_optional(input, bias): returns `input` itself when `bias` is
// None, otherwise `input + bias`.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_add_optional({
    Operator(
        "prim::add_optional(Tensor(a) input, Tensor? bias) -> Tensor(a)",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            // `bias` is on top of the stack, `input` just below it.
            IValue bias = pop(stack);
            IValue input = pop(stack);
            if (!bias.isNone()) {
              push(stack, at::add(input.toTensor(), bias.toTensor(), 1.0));
              return;
            }
            push(stack, std::move(input));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::permute_copy(self, dims): stand-in for aten::permute that nvfuser
// uses to mark a non-mutating alias op. It refuses to run unless the node's
// `name` attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_permute_copy({
    Operator(
        "prim::permute_copy(Tensor(a) self, int[] dims) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            TORCH_CHECK(
                node->s(attr::name) == "CudaFusionGroup",
                "permute_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            IValue self, dims;
            pop(stack, self, dims);
            // `dims` is a permutation of the input's dimensions, not a target
            // shape, so this must be a permute rather than a view.
            push(stack, at::permute(self.toTensor(), dims.toIntVector()));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::transpose_copy.int(self, dim0, dim1): stand-in for aten::transpose
// that nvfuser uses to mark a non-mutating alias op; only runs on nodes whose
// `name` attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_transpose_copy({
    Operator(
        "prim::transpose_copy.int(Tensor(a) self, int dim0, int dim1) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "transpose_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            // Arguments come off the stack in reverse order.
            const int64_t second = pop(stack).toInt();
            const int64_t first = pop(stack).toInt();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::transpose(input, first, second));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::t_copy(self): stand-in for aten::t that nvfuser uses to mark a
// non-mutating alias op; only runs on nodes whose `name` attribute is
// "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_t_copy({
    Operator(
        "prim::t_copy(Tensor(a) self) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "t_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::t(input));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::view_copy(self, size): stand-in for aten::view that nvfuser uses to
// mark a non-mutating alias op; only runs on nodes whose `name` attribute is
// "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_view_copy({
    Operator(
        "prim::view_copy(Tensor self, int[] size) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "view_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const std::vector<int64_t> new_size = pop(stack).toIntVector();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::native::view(input, new_size));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::flatten_copy(self, start_dim, end_dim): stand-in for aten::flatten
// that nvfuser uses to mark a non-mutating alias op; only runs on nodes whose
// `name` attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_flatten_copy({
    Operator(
        "prim::flatten_copy(Tensor self, int start_dim, int end_dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "flatten_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            // Arguments come off the stack in reverse order.
            const int64_t last_dim = pop(stack).toInt();
            const int64_t first_dim = pop(stack).toInt();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::native::flatten(input, first_dim, last_dim));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::reshape_copy(self, shape): stand-in for aten::reshape that nvfuser
// uses to mark a non-mutating alias op; only runs on nodes whose `name`
// attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_reshape_copy({
    Operator(
        "prim::reshape_copy(Tensor self, int[] shape) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "reshape_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const std::vector<int64_t> new_shape = pop(stack).toIntVector();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::native::reshape(input, new_shape));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::squeeze_copy(self): stand-in for aten::squeeze that nvfuser uses to
// mark a non-mutating alias op; only runs on nodes whose `name` attribute is
// "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_squeeze_copy({
    Operator(
        "prim::squeeze_copy(Tensor self) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "squeeze_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::squeeze(input));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::squeeze_copy.dim(self, dim): stand-in for aten::squeeze.dim that
// nvfuser uses to mark a non-mutating alias op; only runs on nodes whose
// `name` attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_squeeze_dim_copy({
    Operator(
        "prim::squeeze_copy.dim(Tensor self, int dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "squeeze_dim_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const int64_t axis = pop(stack).toInt();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::squeeze(input, axis));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::unsqueeze_copy(self, dim): stand-in for aten::unsqueeze that nvfuser
// uses to mark a non-mutating alias op; only runs on nodes whose `name`
// attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_unsqueeze_copy({
    Operator(
        "prim::unsqueeze_copy(Tensor self, int dim) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "unsqueeze_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const int64_t axis = pop(stack).toInt();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::unsqueeze(input, axis));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::infer_unsqueeze_size(a, dim): pops a size list `a` and a dim, and
// pushes `a` with a 1 inserted at `dim`. Negative dims count from the end of
// the output shape, so the valid range is [-(rank + 1), rank]. Out-of-range
// dims raise instead of inserting at an invalid iterator.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_unsqueeze_size({
    Operator(
        "prim::infer_unsqueeze_size(int[] a, int dim) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            const auto requested_dim = pop(stack).toInt();
            auto size = pop(stack).toIntVector();
            const auto rank = static_cast<int64_t>(size.size());
            auto dim = requested_dim;
            if (dim < 0) {
              dim += rank + 1;
            }
            TORCH_CHECK(
                dim >= 0 && dim <= rank,
                "infer_unsqueeze_size: dim ",
                requested_dim,
                " is out of range for a size list of rank ",
                rank);
            size.insert(size.begin() + dim, 1);
            push(stack, IValue(size));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::infer_squeeze_size.dim(a, dim): pops a size list `a` and a dim, and
// pushes `a` with entry `dim` removed if that entry is 1 (otherwise `a`
// unchanged). As in eager squeeze, a 0-d input (empty list) accepts dim 0 or
// -1 and is returned unchanged. Any other out-of-range dim raises instead of
// dereferencing an invalid iterator.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_squeeze_dim_size({
    Operator(
        "prim::infer_squeeze_size.dim(int[] a, int dim) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            const auto requested_dim = pop(stack).toInt();
            auto size = pop(stack).toIntVector();
            const auto rank = static_cast<int64_t>(size.size());
            // A 0-d list is wrapped as if it had rank 1.
            const int64_t wrap_rank = rank > 0 ? rank : 1;
            TORCH_CHECK(
                requested_dim >= -wrap_rank && requested_dim < wrap_rank,
                "infer_squeeze_size: dim ",
                requested_dim,
                " is out of range for a size list of rank ",
                rank);
            auto dim = requested_dim;
            if (dim < 0) {
              dim += wrap_rank;
            }
            if (rank > 0 && size[dim] == 1) {
              size.erase(size.begin() + dim);
            }
            push(stack, IValue(size));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::infer_squeeze_size(a): pops a size list and pushes it with every
// size-1 entry removed.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_infer_squeeze_size({
    Operator(
        "prim::infer_squeeze_size(int[] a) -> int[]",
        [](const Node* node) -> Operation {
          return [](Stack& stack) {
            const auto size = pop(stack).toIntVector();
            // Filter into a fresh vector instead of erasing while iterating;
            // stepping an iterator back before begin() is undefined behavior.
            std::vector<int64_t> squeezed;
            squeezed.reserve(size.size());
            for (const auto extent : size) {
              if (extent != 1) {
                squeezed.push_back(extent);
              }
            }
            push(stack, IValue(squeezed));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::expand_copy(self, size, implicit): stand-in for aten::expand that
// nvfuser uses to mark a non-mutating alias op; only runs on nodes whose
// `name` attribute is "CudaFusionGroup". The `implicit` argument is accepted
// but not forwarded.
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_expand_copy({
    Operator(
        "prim::expand_copy(Tensor self, int[] size, *, bool implicit=False) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "expand_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            // Discard `implicit` (top of stack), then take size and self.
            drop(stack, 1);
            const std::vector<int64_t> target_size = pop(stack).toIntVector();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, input.expand(target_size));
          };
        },
        aliasAnalysisFromSchema()),
});
// prim::expand_as_copy(self, other): stand-in for aten::expand_as that
// nvfuser uses to mark a non-mutating alias op; only runs on nodes whose
// `name` attribute is "CudaFusionGroup".
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
RegisterOperators reg_expand_as_copy({
    Operator(
        "prim::expand_as_copy(Tensor self, Tensor other) -> Tensor",
        [](const Node* node) -> Operation {
          return [node](Stack& stack) {
            const bool inside_fusion =
                node->s(attr::name) == "CudaFusionGroup";
            TORCH_CHECK(
                inside_fusion,
                "expand_as_copy is only used by nvfuser to identify non-mutating ",
                "alias ops, should be restored after fusion pass!");
            const at::Tensor target = pop(stack).toTensor();
            const at::Tensor input = pop(stack).toTensor();
            push(stack, at::native::expand_as(input, target));
          };
        },
        aliasAnalysisFromSchema()),
});
} // namespace
} // namespace jit
} // namespace torch