Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[JIT] Add autocasting to freezing pass & enable autocast pass by default (#74178)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/74178

Autocasting + freezing should reduce model size in some scenarios, since half-precision constants should be smaller than full-precision constants. This also enables the JIT autocast pass by default, so `torch._C._jit_set_autocast_mode(True)` no longer needs to be called to enable autocasting.

Test Plan: Imported from OSS

Reviewed By: zou3519, eellison

Differential Revision: D34914245

Pulled By: davidberard98

fbshipit-source-id: 301f3669431feabbd695ebbdfc9c17bd1be3b565
(cherry picked from commit 0530cd365ae1f148910100a5c2981e80d04e4883)
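With the pass on by default, a scripted module that wraps its compute in `torch.cuda.amp.autocast()` should pick up reduced-precision casts without any extra setup, and freezing should fold the casts of constants into the graph. A minimal usage sketch (assumes a CUDA build; the `MatMul` module name and the printed values are illustrative, not part of this PR):

```python
import torch

class MatMul(torch.nn.Module):          # illustrative module, not from the PR
    def forward(self, x, y):
        with torch.cuda.amp.autocast():
            return torch.mm(x, y)

# No torch._C._jit_set_autocast_mode(True) call needed: the pass is now on by default.
frozen = torch.jit.freeze(torch.jit.script(MatMul().eval()))

x = torch.rand(3, 4, device="cuda")
y = torch.rand(4, 5, device="cuda")
out = frozen(x, y)
print(out.dtype)      # expected to be the reduced-precision dtype (torch.float16)
print(frozen.graph)   # should show aten::_autocast_to_reduced_precision casts
```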
Commit: 981baadf47
Parent: f5a9c36d0b
@@ -659,6 +659,55 @@ class TestAutocast(JitTestCase):
         # isn't enabled
         self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
 
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_basic(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+
+            def forward(self, x, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(x, y)
+
+        x = torch.rand((3, 4), dtype=torch.float).cuda()
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+
+        mod = TestModule().eval()
+
+        # sanity check
+        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)
+
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)
+
+        # make sure that the runtime pass doesn't duplicate autocast nodes
+        frozen_mod(x, y)
+        optimized_graph = frozen_mod.graph_for(x, y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_constants(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.x = torch.rand((3, 4), dtype=torch.float).cuda()
+
+            def forward(self, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(self.x, y)
+
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+        mod = TestModule().eval()
+
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+        # freezing should pre-cast the constant self.x to remove one autocast call
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)
+
+        # the runtime autocasting pass will re-insert the second autocast call,
+        # but constant propagation will merge it with the constant that it's casting.
+        frozen_mod(y)
+        optimized_graph = frozen_mod.graph_for(y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)
+
 if __name__ == "__main__":
     run_tests()
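The tests above call `self._test_autocast`, a helper of the test class that is not part of this diff. A rough standalone sketch of the same graph check, using only `torch.testing.FileCheck` (assumes CUDA; the `MM` module is illustrative):

```python
import torch
from torch.testing import FileCheck

class MM(torch.nn.Module):
    def forward(self, x, y):
        with torch.cuda.amp.autocast():
            return torch.mm(x, y)

frozen = torch.jit.freeze(torch.jit.script(MM().eval()))
print(frozen.graph)  # inspect the frozen IR directly

# One reduced-precision cast is expected per tensor input of the mm node;
# the third argument asks FileCheck for an exact count.
FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen.graph)
```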
@@ -12,15 +12,14 @@
 #include <stack>
 #include <unordered_set>
 #include <vector>
 
 namespace torch {
 namespace jit {
 
 namespace {
 
-// TODO: Turn on autocast by default. default turned off to avoid tests failures
-// as we prototype the support
-bool autocast_enabled = false;
+bool autocast_enabled = true;
 
 struct AutocastContext {
   bool gpu_enabled = false;
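Since the default flips to `true` in the hunk above, code that depended on the old behavior can still turn the pass off explicitly via the toggle named in the commit message. A small sketch (the return value of the toggle is not relied on here):

```python
import torch

# Turn the now-default JIT autocast pass off for this process.
torch._C._jit_set_autocast_mode(False)
# ... script / freeze models without the pass here ...
torch._C._jit_set_autocast_mode(True)   # restore the new default
```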
@@ -149,17 +148,23 @@ void castTensorInputs(
   const auto graph = node->owningGraph();
 
   std::unordered_set<Value*> casted_inputs;
+  // need to also keep the inputs in order, otherwise tracing fails
+  // sanity checks because casting ops are inserted in random order
+  std::vector<Value*> casted_inputs_ordered;
   for (auto input : node->inputs()) {
     // TODO: update cast_op signature to take dynamic context flags
     auto input_tensor_type = input->type()->cast<TensorType>();
     if (input_tensor_type && input->node()->kind() != cast_op) {
-      casted_inputs.insert(input);
+      auto has_inserted = casted_inputs.insert(input);
+      if (has_inserted.second) {
+        casted_inputs_ordered.push_back(input);
+      }
     }
   }
 
   WithInsertPoint insert_point(node);
 
-  for (auto input : casted_inputs) {
+  for (auto input : casted_inputs_ordered) {
     if (cast_op == aten::_autocast_to_full_precision) {
       const auto new_input = graph->insert(
           cast_op,
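The change above pairs the existing `unordered_set` (cheap duplicate checks) with a `vector` that records first-seen order, so cast ops are inserted deterministically instead of in hash order. A Python analogue of the idea, for illustration only:

```python
# Illustrative analogue of the casted_inputs / casted_inputs_ordered pairing:
# the set gives O(1) duplicate checks, the list preserves first-seen order.
def dedupe_preserving_order(inputs):
    seen = set()
    ordered = []
    for value in inputs:
        if value not in seen:      # mirrors casted_inputs.insert(...).second
            seen.add(value)
            ordered.append(value)
    return ordered

assert dedupe_preserving_order(["x", "y", "x", "z"]) == ["x", "y", "z"]
```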
@@ -437,7 +442,9 @@ void handleBlock(Block* block, AutocastContext initial_state) {
 
       // Banned in autocast, see binary_cross_entropy_banned()
       case aten::binary_cross_entropy:
-        AT_ERROR("Unsafe to autocast");
+        if (current_state()) {
+          AT_ERROR("Unsafe to autocast");
+        }
     }
 
     // process sub-blocks, if any
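With this change the hard error is raised only when autocast is actually enabled in the surrounding region, rather than unconditionally. This mirrors eager-mode AMP, which also rejects `binary_cross_entropy` and points users to the autocast-safe `binary_cross_entropy_with_logits`. A short eager-mode sketch of that behavior (assumes CUDA):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(8, device="cuda")
target = torch.rand(8, device="cuda")

with torch.cuda.amp.autocast():
    # Supported: the fused logits variant is autocast-safe.
    loss = F.binary_cross_entropy_with_logits(logits, target)
    # F.binary_cross_entropy(torch.sigmoid(logits), target) would raise an
    # "unsafe to autocast" error here.
```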
@@ -5,6 +5,7 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/passes/autocast.h>
 #include <torch/csrc/jit/passes/clear_profiling.h>
 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
 #include <torch/csrc/jit/passes/inliner.h>
@@ -101,6 +102,9 @@ class AttributePropagator {
       ClearProfilingInformation(subgraph);
     };
     auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
+#ifndef C10_MOBILE
+      Autocast(subgraph);
+#endif
       runOptimization(
           subgraph,
           /* unroll_non_constant_loops? */ false,