Mirror of https://github.com/zebrajr/pytorch.git, synced 2025-12-07 12:21:27 +01:00
Remove hard-coded NVRTC specific constant from fuser header
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/22699

Test Plan: Imported from OSS

Differential Revision: D16192290

Pulled By: bwasti

fbshipit-source-id: 4dccaf3e6e0151e86d35474c36e1ddb7f2afb5cf
This commit is contained in:
parent 513b7a7a06
commit 05d56bd1b6
@@ -332,10 +332,6 @@ std::shared_ptr<FusedKernel> compileKernel(
     }
   }
 
-  // Have checked the limit at graph_fuser. Assert nothing else changing that.
-  AT_ASSERT((flat_inputs.size() + flat_outputs.size()) <=
-            fusion_kernel_args_limit);
-
   const bool use_cuda = device.is_cuda();
   const std::string name = "kernel_" + std::to_string(next_kernel_id++);
   std::string code =
@@ -138,6 +138,14 @@ struct GraphFuser {
   FusionCallback callback_ = [&](Node* n) { return isFusableDefault(n); };
   Symbol kind_ = prim::FusionGroup;
 
+  // nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
+  // The specific limit is a function of constant memory size, amount available
+  // to pass arguments, and some implementation dependence. Select a safe
+  // limit here.
+  // This limit is also applied to other devices in the fuser by default.
+  // Change with setInputArgLimit
+  size_t subgraph_arg_limit_ = 128;
+
   GraphFuser(Block* block, std::shared_ptr<Graph> graph)
       : block_(block), graph_(std::move(graph)) {}
 
@@ -152,6 +160,10 @@ struct GraphFuser {
         callback_(callback),
         kind_(kind) {}
 
+  void setInputArgLimit(size_t limit) {
+    subgraph_arg_limit_ = limit;
+  }
+
   value_list tensorInputs(Node* node) {
     return filter(node->inputs(), [](Value* v) {
       return v->type()->isSubtypeOf(TensorType::get());
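For orientation, a minimal sketch of how the new setter composes with the two-argument GraphFuser constructor shown above. The wrapper function and the concrete limit are invented for illustration; the updated CustomFuseGraph later in this diff follows the same construct / setInputArgLimit / run pattern.

// Hypothetical helper, not part of this commit: run fusion over a graph with a
// caller-chosen cap on fused-subgraph arguments instead of relying on the old
// fusion_kernel_args_limit header constant (the member defaults to 128).
void fuseWithArgLimit(std::shared_ptr<Graph>& graph, size_t limit) {
  GraphFuser fuser(graph->block(), graph);
  fuser.setInputArgLimit(limit);
  fuser.run();
}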
@@ -219,7 +231,7 @@ struct GraphFuser {
 
     auto tensors_node = node->namedInput(attr::tensors)->node();
     if ((tensors_node->inputs().size() + node->outputs().size()) >
-        fusion_kernel_args_limit) {
+        subgraph_arg_limit_) {
       return false;
     }
     if (tensors_node->kind() != prim::ListConstruct)
@@ -428,7 +440,7 @@ struct GraphFuser {
 
     if ((consumer->inputs().size() + consumer->outputs().size() +
          producer->node()->inputs().size() +
-         producer->node()->outputs().size()) > fusion_kernel_args_limit) {
+         producer->node()->outputs().size()) > subgraph_arg_limit_) {
       return at::nullopt;
     }
 
@@ -1022,7 +1034,7 @@ struct GraphFuser {
     // If the number of kernel args could exceed the limit, skip.
     if ((before_check->inputs().size() + before_check->outputs().size() +
          producer->node()->inputs().size() +
-         producer->node()->outputs().size()) > fusion_kernel_args_limit) {
+         producer->node()->outputs().size()) > subgraph_arg_limit_) {
       return false;
     }
 
@@ -1306,13 +1318,15 @@ void FuseGraph(std::shared_ptr<Graph>& graph) {
 void CustomFuseGraph(
     std::shared_ptr<Graph>& graph,
     std::function<bool(Node*)> fn,
-    Symbol kind) {
-  GraphFuser(
+    Symbol kind,
+    size_t arg_limit) {
+  auto g = GraphFuser(
       graph->block(),
       graph,
       [=](Node* n) { return fn(n) || n->kind() == kind; },
-      kind)
-      .run();
+      kind);
+  g.setInputArgLimit(arg_limit);
+  g.run();
 }
 
 } // namespace jit
@@ -5,24 +5,28 @@
 namespace torch {
 namespace jit {
 
-// nvrtc has a limit on the number of arguments allowed in a CUDA kernel.
-// The specific limit is a function of constant memory size, amount available
-// to pass arguments, and some implementation dependence. Select a safe
-// limit here.
-// This limit is also applied to other devices in the fuser, because we
-// don't consider a kernel with such a large number of arguments would be
-// profitable.
-constexpr size_t fusion_kernel_args_limit = 128;
-
 // NB: Be sure to run DCE before fusion, because dead instructions
 // can prevent fusion opportunities from being exploited.
 // On Windows will noop, NYI
 TORCH_API void FuseGraph(std::shared_ptr<Graph>& graph);
 
+// \brief Custom fusion pass using a node-level callback to
+// determine the inclusion of nodes in a subgraph.
+//
+// This helper omits aliased inputs and fusion across control flow
+// boundaries.
+//
+// \arg graph The graph to be modified in-place
+// \arg is_fusable A callback run on each fusable node in the graph.
+// \arg kind The label given to the resultant fused subgraph
+// \arg arg_limit The maximum number of args the resultant fused subgraph
+//                should have. Note: This will likely develop into a general
+//                post condition on the fused subgraph.
 TORCH_API void CustomFuseGraph(
     std::shared_ptr<Graph>& graph,
     std::function<bool(Node*)> is_fusable,
-    Symbol kind);
+    Symbol kind,
+    size_t arg_limit = std::numeric_limits<size_t>::max());
 
 TORCH_API bool trackSingleGradSumToSizeToOutputs(
     Value* gradSumToSizeOutput,
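To show how the new parameter is consumed from the outside, a hedged usage sketch (not part of the commit): only CustomFuseGraph and its four parameters come from the diff above; the include path, the "my::FusionGroup" label, the aten::relu predicate, and the cap of 32 are illustrative assumptions.

#include <torch/csrc/jit/passes/graph_fuser.h>  // assumed header location

// Hypothetical pass: fuse aten::relu nodes into subgraphs labeled
// "my::FusionGroup", capping each fused subgraph at 32 combined inputs
// and outputs via the new arg_limit parameter (which otherwise defaults
// to std::numeric_limits<size_t>::max()).
void runCustomFusionExample(std::shared_ptr<torch::jit::Graph>& graph) {
  auto relu = c10::Symbol::fromQualString("aten::relu");
  torch::jit::CustomFuseGraph(
      graph,
      [relu](torch::jit::Node* n) { return n->kind() == relu; },  // is_fusable
      c10::Symbol::fromQualString("my::FusionGroup"),              // kind
      /*arg_limit=*/32);
}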