Remove hard-coded NVRTC specific constant from fuser header

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/22699

Test Plan: Imported from OSS

Differential Revision: D16192290

Pulled By: bwasti

fbshipit-source-id: 4dccaf3e6e0151e86d35474c36e1ddb7f2afb5cf
This commit is contained in:
Bram Wasti 2019-07-11 13:26:39 -07:00 committed by Facebook GitHub Bot
parent 513b7a7a06
commit 05d56bd1b6
3 changed files with 35 additions and 21 deletions

View File

@ -332,10 +332,6 @@ std::shared_ptr<FusedKernel> compileKernel(
}
}
// Have checked the limit at graph_fuser. Assert nothing else changing that.
AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
fusion_kernel_args_limit);
const bool use_cuda = device.is_cuda();
const std::string name = "kernel_" + std::to_string(next_kernel_id++);
std::string code =

View File

@ -138,6 +138,14 @@ struct GraphFuser {
FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); };
Symbol kind_ = prim::FusionGroup;
// nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
// The specific limit is a function of constant memory size, amount available
// to pass arguments, and some implementation dependence. Select a safe
// limit here.
// This limit is also applied to other devices in the fuser by default.
// Change with setInputArgLimit
size_t subgraph_arg_limit_ = 128;
GraphFuser(Block* block, std::shared_ptr<Graph> graph)
: block_(block), graph_(std::move(graph)) {}
@ -152,6 +160,10 @@ struct GraphFuser {
callback_(callback),
kind_(kind) {}
void setInputArgLimit(size_t limit) {
subgraph_arg_limit_ = limit;
}
value_list tensorInputs(Node* node) {
return filter(node->inputs(), [](Value* v) {
return v->type()->isSubtypeOf(TensorType::get());
@ -219,7 +231,7 @@ struct GraphFuser {
auto tensors_node = node->namedInput(attr::tensors)->node();
if ((tensors_node->inputs().size() + node->outputs().size()) >
fusion_kernel_args_limit) {
subgraph_arg_limit_) {
return false;
}
if (tensors_node->kind() != prim::ListConstruct)
@ -428,7 +440,7 @@ struct GraphFuser {
if ((consumer->inputs().size() + consumer->outputs().size() +
producer->node()->inputs().size() +
producer->node()->outputs().size()) > fusion_kernel_args_limit) {
producer->node()->outputs().size()) > subgraph_arg_limit_) {
return at::nullopt;
}
@ -1022,7 +1034,7 @@ struct GraphFuser {
// If the number of kernel args could exceed the limit, skip.
if ((before_check->inputs().size() + before_check->outputs().size() +
producer->node()->inputs().size() +
producer->node()->outputs().size()) > fusion_kernel_args_limit) {
producer->node()->outputs().size()) > subgraph_arg_limit_) {
return false;
}
@ -1306,13 +1318,15 @@ void FuseGraph(std::shared_ptr<Graph>& graph) {
void CustomFuseGraph(
std::shared_ptr<Graph>& graph,
std::function<bool(Node*)> fn,
Symbol kind) {
GraphFuser(
Symbol kind,
size_t arg_limit) {
auto g = GraphFuser(
graph->block(),
graph,
[=](Node* n) { return fn(n) || n->kind() == kind; },
kind)
.run();
kind);
g.setInputArgLimit(arg_limit);
g.run();
}
} // namespace jit

View File

@ -5,24 +5,28 @@
namespace torch {
namespace jit {
// nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
// The specific limit is a function of constant memory size, amount available
// to pass arguments, and some implementation dependence. Select a safe
// limit here.
// This limit is also applied to other devices in the fuser, because we
// don't expect a kernel with such a large number of arguments to be
// profitable.
constexpr size_t fusion_kernel_args_limit = 128;
// NB: Be sure to run DCE before fusion, because dead instructions
// can prevent fusion opportunities from being exploited.
// On Windows this pass is a no-op (not yet implemented).
TORCH_API void FuseGraph(std::shared_ptr<Graph>& graph);
// \brief Custom fusion pass using a node-level callback to
// determine the inclusion of nodes in a subgraph.
//
// This helper omits aliased inputs and fusion across control flow
// boundaries.
//
// \arg graph The graph to be modified in-place
// \arg is_fusable A callback run on each fusable node in the graph.
// \arg kind The label given to the resultant fused subgraph
// \arg arg_limit The maximum number of args the resultant fused subgraph
//                 should have. Note: this will likely develop into a general
//                 post-condition on the fused subgraph.
TORCH_API void CustomFuseGraph(
std::shared_ptr<Graph>& graph,
std::function<bool(Node*)> is_fusable,
Symbol kind);
Symbol kind,
size_t arg_limit=std::numeric_limits<size_t>::max());
TORCH_API bool trackSingleGradSumToSizeToOutputs(
Value* gradSumToSizeOutput,