[JIT] Add autocasting to freezing pass & enable autocast pass by default (#74178)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/74178

Autocasting + freezing should reduce model size in some scenarios, since half-precision constants are smaller than full-precision constants. This also enables the JIT autocast pass by default, so `torch._C._jit_set_autocast_mode(True)` no longer needs to be called to enable autocasting.
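A minimal usage sketch of the combined behavior (the module below is illustrative, and a CUDA build is assumed):

```python
import torch

class MatMul(torch.nn.Module):
    def forward(self, x, y):
        # ops inside the autocast region are eligible for reduced precision
        with torch.cuda.amp.autocast():
            return torch.mm(x, y)

# The JIT autocast pass now runs by default; no prior call to
# torch._C._jit_set_autocast_mode(True) is required.
scripted = torch.jit.script(MatMul().eval())
frozen = torch.jit.freeze(scripted)  # folds attributes into constants,
                                     # pre-casting them where possible

out = frozen(torch.rand(3, 4, device="cuda"), torch.rand(4, 5, device="cuda"))
```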

Test Plan: Imported from OSS

Reviewed By: zou3519, eellison

Differential Revision: D34914245

Pulled By: davidberard98

fbshipit-source-id: 301f3669431feabbd695ebbdfc9c17bd1be3b565
(cherry picked from commit 0530cd365ae1f148910100a5c2981e80d04e4883)
Author: David Berard, 2022-03-23 16:02:16 -07:00 (committed by PyTorch MergeBot)
parent f5a9c36d0b
commit 981baadf47
3 changed files with 66 additions and 6 deletions

test/jit/test_autocast.py

@@ -659,6 +659,55 @@ class TestAutocast(JitTestCase):
         # isn't enabled
         self.assertRaises(RuntimeError, lambda: scripted_thing1.forward(x, y))
 
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_basic(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+
+            def forward(self, x, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(x, y)
+
+        x = torch.rand((3, 4), dtype=torch.float).cuda()
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+        mod = TestModule().eval()
+
+        # sanity check
+        self._test_autocast(mod, "aten::_autocast_to_reduced_precision", x, y)
+
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(frozen_mod.graph)
+
+        # make sure that the runtime pass doesn't duplicate autocast nodes
+        frozen_mod(x, y)
+        optimized_graph = frozen_mod.graph_for(x, y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 2, True).run(optimized_graph)
+
+    @unittest.skipIf(not TEST_CUDA, "No cuda")
+    def test_jit_freeze_autocast_constants(self):
+        class TestModule(torch.nn.Module):
+            def __init__(self):
+                super(TestModule, self).__init__()
+                self.x = torch.rand((3, 4), dtype=torch.float).cuda()
+
+            def forward(self, y):
+                with torch.cuda.amp.autocast():
+                    return torch.mm(self.x, y)
+
+        y = torch.rand((4, 5), dtype=torch.float).cuda()
+        mod = TestModule().eval()
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod).eval())
+
+        # freezing should pre-cast the constant self.x to remove one autocast call
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(frozen_mod.graph)
+
+        # the runtime autocasting pass will re-insert the second autocast call,
+        # but constant propagation will merge it with the constant that it's casting.
+        frozen_mod(y)
+        optimized_graph = frozen_mod.graph_for(y)
+        FileCheck().check_count("aten::_autocast_to_reduced_precision", 1, True).run(optimized_graph)
+
 if __name__ == "__main__":
     run_tests()

torch/csrc/jit/passes/autocast.cpp

@@ -12,15 +12,14 @@
 #include <stack>
 #include <unordered_set>
 #include <vector>
 
 namespace torch {
 namespace jit {
 
 namespace {
 
-// TODO: Turn on autocast by default. default turned off to avoid tests failures
-// as we prototype the support
-bool autocast_enabled = false;
+bool autocast_enabled = true;
 
 struct AutocastContext {
   bool gpu_enabled = false;
@@ -149,17 +148,23 @@ void castTensorInputs(
   const auto graph = node->owningGraph();
 
   std::unordered_set<Value*> casted_inputs;
+  // need to also keep the inputs in order, otherwise tracing fails
+  // sanity checks because casting ops are inserted in random order
+  std::vector<Value*> casted_inputs_ordered;
   for (auto input : node->inputs()) {
     // TODO: update cast_op signature to take dynamic context flags
     auto input_tensor_type = input->type()->cast<TensorType>();
     if (input_tensor_type && input->node()->kind() != cast_op) {
-      casted_inputs.insert(input);
+      auto has_inserted = casted_inputs.insert(input);
+      if (has_inserted.second) {
+        casted_inputs_ordered.push_back(input);
+      }
     }
   }
 
   WithInsertPoint insert_point(node);
 
-  for (auto input : casted_inputs) {
+  for (auto input : casted_inputs_ordered) {
     if (cast_op == aten::_autocast_to_full_precision) {
       const auto new_input = graph->insert(
           cast_op,
@@ -437,7 +442,9 @@ void handleBlock(Block* block, AutocastContext initial_state) {
       // Banned in autocast, see binary_cross_entropy_banned()
       case aten::binary_cross_entropy:
-        AT_ERROR("Unsafe to autocast");
+        if (current_state()) {
+          AT_ERROR("Unsafe to autocast");
+        }
     }
 
     // process sub-blocks, if any
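
For readers less familiar with the C++ idiom in the `castTensorInputs` hunk above, here is a minimal Python analogue of the ordered-dedup pattern: `std::unordered_set` iterates in unspecified order, so the pass keeps a parallel vector that records first-seen order (names here are illustrative):

```python
def dedup_preserving_order(inputs):
    seen = set()   # mirrors the unordered_set membership check
    ordered = []   # mirrors casted_inputs_ordered
    for inp in inputs:
        if inp not in seen:  # only record each input the first time it appears
            seen.add(inp)
            ordered.append(inp)
    return ordered
```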

torch/csrc/jit/passes/freeze_module.cpp

@@ -5,6 +5,7 @@
 #include <c10/util/irange.h>
 #include <torch/csrc/jit/api/function_impl.h>
 #include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/passes/autocast.h>
 #include <torch/csrc/jit/passes/clear_profiling.h>
 #include <torch/csrc/jit/passes/eliminate_no_ops.h>
 #include <torch/csrc/jit/passes/inliner.h>
@@ -101,6 +102,9 @@ class AttributePropagator {
       ClearProfilingInformation(subgraph);
     };
     auto applyOptimizations = [](std::shared_ptr<Graph>& subgraph) {
+#ifndef C10_MOBILE
+      Autocast(subgraph);
+#endif
       runOptimization(
           subgraph,
           /* unroll_non_constant_loops? */ false,