Mirror of https://github.com/zebrajr/pytorch.git (synced 2025-12-06 12:20:52 +01:00)
[1/n][Optimus][Auto-AC] Support activation quantization without scaling (#148380)
Summary: We enable activation quantization in the forward pass, and users can customize the dtype they want to quantize to.

Test Plan:

# unit test

```
buck2 test 'fbcode//mode/dev-nosan' fbcode//caffe2/test/inductor:quantization -- test_activation_quantization_aten
```

Buck UI: https://www.internalfb.com/buck2/776d3911-bb86-4ac8-a527-540cf1510b9d
Test UI: https://www.internalfb.com/intern/testinfra/testrun/4785074873051017
Network: Up: 4.3MiB Down: 42MiB (reSessionID-fef7e727-68b1-4645-a519-5652854df38d)
Executing actions. Remaining 0/4, 6.7s exec time total
Command: test. Finished 2 local
Time elapsed: 3:11.5s
Tests finished: Pass 2. Fail 0. Fatal 0. Skip 0. Build failure 0

# E2E

### how to enable (you can override the dtype; if nothing is given, the default is fp8)

```
post_grad_fusion_options={
    "activation_quantization_aten_pass": {"quant_type": "torch.float8_e5m2"}
},
```

Differential Revision: D70522237

Pull Request resolved: https://github.com/pytorch/pytorch/pull/148380
Approved by: https://github.com/Mingming-Ding, https://github.com/Hahu803
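To make the E2E setting above concrete, here is a minimal sketch of enabling the pass around `torch.compile`. The tiny Linear model, the shapes, and the `size_in_mb` threshold are illustrative only; the config keys mirror the unit test added in this PR, and by default only bfloat16 activations are considered for quantization.

```
import torch
import torch._inductor

# Sketch only: enable the activation quantization pass for one compiled run.
# size_in_mb=0.0 forces even small activations to be quantized so the effect
# of the pass is easy to observe; omit it to use the pass's default threshold.
with torch._inductor.config.patch(
    post_grad_fusion_options={
        "activation_quantization_aten_pass": {
            "quant_type": "torch.float8_e5m2",  # fp8 e5m2 is also the default
            "size_in_mb": 0.0,
        }
    }
):
    model = torch.nn.Linear(64, 64).to("cuda", dtype=torch.bfloat16)
    compiled = torch.compile(model)
    out = compiled(torch.randn(16, 64, device="cuda", dtype=torch.bfloat16))
    out.sum().backward()
```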
parent 6f6fac6a41
commit 2d25e4d478
@@ -8,7 +8,7 @@ torch.fx._symbolic_trace.Tracer.trace(self, root: Union[torch.nn.modules.module.
torch.fx._symbolic_trace.symbolic_trace(root: Union[torch.nn.modules.module.Module, Callable[..., Any]], concrete_args: Optional[Dict[str, Any]] = None) -> torch.fx.graph_module.GraphModule
torch.fx._symbolic_trace.wrap(fn_or_name: Union[str, Callable])
torch.fx.graph.Graph.__init__(self, owning_module: Optional[GraphModule] = None, tracer_cls: Optional[Type[Tracer]] = None, tracer_extras: Optional[Dict[str, Any]] = None)
torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.call_function(self, the_function: Callable[..., Any], args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None, name: Optional[str] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.call_method(self, method_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.call_module(self, module_name: str, args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
torch.fx.graph.Graph.create_node(self, op: str, target: 'Target', args: Optional[Tuple[Argument, ...]] = None, kwargs: Optional[Dict[str, Argument]] = None, name: Optional[str] = None, type_expr: Optional[Any] = None) -> torch.fx.node.Node
163 test/inductor/test_quantization.py Normal file
@@ -0,0 +1,163 @@
# Owner(s): ["module: inductor"]

import logging

import numpy as np

import torch
import torch._inductor
import torch._inductor.fx_passes.group_batch_fusion
from torch._dynamo.utils import counters
from torch._inductor.test_case import run_tests, TestCase
from torch.testing._internal.common_utils import IS_LINUX
from torch.testing._internal.inductor_utils import GPU_TYPE, HAS_GPU, requires_gpu


log = logging.getLogger(__name__)


class TargetCPModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x1, x2):
        relued = torch.relu(x1)
        tanhed = torch.tanh(relued)
        tensor = torch.matmul(
            tanhed,
            x2,
        )
        return tensor


class FeedforwardNN(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(1, 64)
        self.fc2 = torch.nn.Linear(64, 64)
        self.fc3 = torch.nn.Linear(64, 64)
        self.fc4 = torch.nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        tanh_x = torch.tanh(x)
        x = torch.relu(self.fc2(x))
        x = torch.relu(self.fc3(tanh_x))
        x = self.fc4(x)
        return x


class TestQuantization(TestCase):
    def compare_dict_tensors(self, ref_dict, res_dict, rtol=1e-3, atol=1e-3):
        if len(set(ref_dict.keys())) != len(set(res_dict.keys())):
            return False
        for key1 in ref_dict.keys():
            key2 = "_orig_mod." + key1
            assert key2 in res_dict, f"{key1} does not exist in traced module"
            # if both of them are None, continue
            if (
                not isinstance(ref_dict[key1], torch.Tensor)
                and not isinstance(res_dict[key2], torch.Tensor)
                and ref_dict[key1] is None
                and res_dict[key2] is None
            ):
                log.info(
                    "None found with key1 and value 1: %s, %s, key2 and value2 %s, %s",
                    key1,
                    ref_dict[key1],
                    key2,
                    res_dict[key2],
                )
                continue
            elif not torch.allclose(
                ref_dict[key1], res_dict[key2], rtol=rtol, atol=atol, equal_nan=True
            ):
                log.info(
                    "gradient mismatch for eager and compiled modules, with eager: %s and compiled: %s",
                    ref_dict[key1],
                    res_dict[key2],
                )
                return False
        return True

    def compare_pred(self, module, traced, input, rtol=1e-3, atol=1e-3):
        ref = module(*input)
        res = traced(*input)
        self.assertEqual(ref, res, rtol=rtol, atol=atol)

    def compare_parameters(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_params = dict(module.named_parameters())
        res_params = dict(traced.named_parameters())
        self.assertTrue(self.compare_dict_tensors(ref_params, res_params, rtol, atol))

    def compare_gradients(self, module, traced, rtol=1e-3, atol=1e-3):
        ref_grad = {key: param.grad for key, param in module.named_parameters()}
        res_grad = {key: param.grad for key, param in traced.named_parameters()}
        self.assertTrue(
            self.compare_dict_tensors(ref_grad, res_grad, rtol=rtol, atol=atol)
        )

    @requires_gpu()
    @torch._inductor.config.patch(
        pre_grad_fusion_options={},
        post_grad_fusion_options={
            "activation_quantization_aten_pass": {
                "quant_type": "torch.float8_e5m2",
                "size_in_mb": 0.0,
            },
        },
    )
    def test_activation_quantization_aten(self):
        counters.clear()
        module = TargetCPModule().to(GPU_TYPE)
        input = [
            torch.rand(
                (16, 10), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
            torch.rand(
                (10, 16), requires_grad=True, device=GPU_TYPE, dtype=torch.bfloat16
            ),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()

        module = FeedforwardNN().to(GPU_TYPE)
        X = np.linspace(-10, 10, 100).reshape(-1, 1).astype(np.float32)
        input = [
            torch.from_numpy(X).to(GPU_TYPE),
        ]
        traced = torch.compile(module)
        ref = module(*input)
        res = traced(*input)
        self.compare_pred(module, traced, input)
        ref.sum().backward()
        res.sum().backward()
        self.compare_parameters(module, traced)
        self.compare_gradients(module, traced)
        self.assertEqual(
            counters["inductor"]["activation_quantization_fwd_aten_pass"], 1
        )
        self.assertEqual(
            counters["inductor"]["activation_quantization_bwd_aten_pass"], 1
        )
        self.assertTrue(torch.allclose(ref, res))
        counters.clear()


if __name__ == "__main__":
    if IS_LINUX and HAS_GPU:
        run_tests()
@@ -4588,3 +4588,7 @@ def maybe_disable_inference_mode_for_fake_prop() -> Generator[None, None, None]:
        yield
    else:
        yield


def is_node_meta_valid(node: Optional[torch.fx.Node]) -> bool:
    return node is None or "example_value" in node.meta or "val" in node.meta
@@ -17,9 +17,13 @@ import torch._inductor.inductor_prims
import torch.distributed
import torch.fx as fx
import torch.utils._pytree as pytree
from torch._dynamo.utils import counters, is_node_meta_valid
from torch._functorch._activation_checkpointing.ac_logging_utils import (
    create_structured_trace_for_min_cut_info,
)
from torch._inductor import config as inductor_config
from torch._logging import trace_structured
from torch._subclasses.fake_tensor import extract_tensor_metadata
from torch.fx.experimental._backward_state import BackwardState
from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
from torch.fx.experimental.sym_node import magic_methods, method_to_operator
@@ -296,12 +300,216 @@ def _remove_by_name(saved_values: list[fx.Node], name: str):
            break


def calculate_tensor_size(tensor: torch.Tensor) -> float:
    """
    Calculate the size of a PyTorch tensor in megabytes (MB).

    Args:
        tensor (torch.Tensor): Input tensor

    Returns:
        float: Memory size in MB
    """
    # Get number of elements and size per element
    num_elements = tensor.numel()
    element_size = tensor.element_size()

    return (num_elements * element_size) / (1024 * 1024)


def get_allowed_dtypes() -> list[torch.dtype]:
    allowed_dtypes = torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("allowed_dtypes", "torch.bfloat16")
    allowed_dtypes = [
        getattr(torch, dtype.split(".")[-1]) for dtype in allowed_dtypes.split(";")
    ]
    return allowed_dtypes


def should_quantize(node: torch.fx.Node) -> bool:
    allowed_dtypes = get_allowed_dtypes()
    if not is_node_meta_valid(node) or node.meta["val"].dtype not in allowed_dtypes:
        return False

    # calculate the size of the node
    size_in_mb = calculate_tensor_size(node.meta["val"])

    return size_in_mb >= torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("size_in_mb", 100)


def get_quant_type() -> torch.dtype:
    quant_type = torch._inductor.config.post_grad_fusion_options[
        "activation_quantization_aten_pass"
    ].get("quant_type", "torch.float8_e5m2")

    return getattr(torch, quant_type.split(".")[-1])


def quantize_activation_fw(graph: torch.fx.Graph) -> None:
    output = graph.find_nodes(op="output")[0]
    fwd_outputs = output.args[0]
    quant_type = get_quant_type()
    node_to_quant = dict()
    for node in fwd_outputs:
        # check if the activation node is the node saved for quantization
        if node.meta.get("saved_for_quantization", False):
            with graph.inserting_after(node):
                quant_node = graph.call_function(
                    torch.ops.prims.convert_element_type.default,
                    args=(node, quant_type),
                    name="quant_" + str(node.name),
                )
                quant_node.meta["val"] = torch.ops.prims.convert_element_type.default(
                    node.meta["val"], quant_type
                )
                quant_node.meta["tensor_meta"] = extract_tensor_metadata(
                    quant_node.meta["val"]
                )
                node_to_quant[node] = quant_node
    # only update the return node args, leaving all other users unchanged
    output_updated_args = tuple(
        node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs  # type: ignore[union-attr]
    )

    output.update_arg(0, output_updated_args)
    counters["inductor"]["activation_quantization_fwd_aten_pass"] += 1


def quantize_activation_bw(graph: torch.fx.Graph) -> None:
    bw_inputs = [node for node in graph.nodes if node.op == "placeholder"]
    for node in bw_inputs:
        if is_node_meta_valid(node) and node.meta.get("saved_for_quantization", False):
            node.meta.pop("saved_for_quantization")
            dequant_type = node.meta.pop("dequant_type")
            # dequantize the node
            with graph.inserting_after(node):
                dequant_node = graph.call_function(
                    torch.ops.prims.convert_element_type.default,
                    args=(node, dequant_type),
                    name="dequant_" + str(node.name),
                )
                dequant_node.meta["val"] = torch.ops.prims.convert_element_type.default(
                    node.meta["val"], dequant_type
                )
                dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
                    dequant_node.meta["val"]
                )
            # find the users of the node and replace them with the new node, except the dequant_node
            for user in list(node.users.keys()):
                if user != dequant_node:
                    user.replace_input_with(node, dequant_node)

    counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1


def enable_activation_quantization(
    saved_values: list[fx.Node],
    fwd_module: fx.GraphModule,
    bwd_module: fx.GraphModule,
    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
) -> None:
    if (
        inductor_config.post_grad_fusion_options.get(
            "activation_quantization_aten_pass", None
        )
        is None
    ):
        return

    static_input_names = (
        [node.name for node in static_lifetime_input_nodes]
        if static_lifetime_input_nodes
        else []
    )
    saved_values_names = {node.name: node for node in saved_values}
    fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
    bwd_module_inputs = {
        node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
    }
    for node in fwd_module_outputs:
        if node.name in saved_values_names and should_quantize(node):
            if node.name in static_input_names:
                log.debug("Skipping quantization of static input %s: ", node.name)
                continue
            node.meta["saved_for_quantization"] = True
            node.meta["dequant_type"] = node.meta["val"].dtype
            # some of the fwd outputs and bwd inputs do not share the same object
            bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
            bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "before_activation_quantization_fwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: fwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )

    quantize_activation_fw(fwd_module.graph)

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "after_activation_quantization_fwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: fwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )

    quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
    # update the corresponding bwd_inputs due to the fwd_outputs quantization
    for fwd_node in quant_fwd_module_outputs:
        if "quant_" in fwd_node.name:
            bwd_input = bwd_module_inputs[fwd_node.name.replace("quant_", "")]
            with bwd_module.graph.inserting_after(bwd_input):
                quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
                dequant_type = bwd_input.meta["dequant_type"]
                quant_bwd_input.meta.update(fwd_node.meta)
                quant_bwd_input.meta["saved_for_quantization"] = True
                quant_bwd_input.meta["dequant_type"] = dequant_type
                bwd_input.replace_all_uses_with(quant_bwd_input)
                bwd_module.graph.erase_node(bwd_input)

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "before_activation_quantization_bwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: bwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )

    quantize_activation_bw(bwd_module.graph)

    trace_structured(
        "artifact",
        metadata_fn=lambda: {
            "name": "after_activation_quantization_bwd_aten_pass",
            "encoding": "string",
        },
        payload_fn=lambda: bwd_module.print_readable(
            print_output=False, include_stride=True, include_device=True
        ),
    )


def _extract_fwd_bwd_modules(
    joint_module: fx.GraphModule,
    saved_values: list[fx.Node],
    saved_sym_nodes: list[fx.Node],
    *,
    num_fwd_outputs: int,
    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    fwd_outputs, bwd_outputs = _extract_fwd_bwd_outputs(
        joint_module, num_fwd_outputs=num_fwd_outputs
@@ -405,6 +613,9 @@ def _extract_fwd_bwd_modules(

    fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
    bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
    enable_activation_quantization(
        saved_values, fwd_module, bwd_module, static_lifetime_input_nodes
    )
    return fwd_module, bwd_module

@@ -414,6 +625,7 @@ def default_partition(
    *,
    num_fwd_outputs,
    static_lifetime_input_indices: Optional[list[int]] = None,
    static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the :attr:`joint_module` in a manner that closely resembles the
@@ -440,7 +652,10 @@ def default_partition(
    """
    if has_recomputable_ops(joint_module):
        return min_cut_rematerialization_partition(
            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
        )
    primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
    fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
@@ -495,6 +710,7 @@ def default_partition(
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
        static_lifetime_input_nodes=static_lifetime_input_nodes,
    )

@@ -2085,7 +2301,11 @@ def min_cut_rematerialization_partition(
    # this case, send our graph over to the default partitioner.
    if len(node_info.required_bw_nodes) == 0:
        return default_partition(
            joint_module, _joint_inputs, num_fwd_outputs=num_fwd_outputs
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
            static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
        )

    for node in reversed(joint_module.graph.nodes):
@@ -2120,6 +2340,7 @@ def min_cut_rematerialization_partition(
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
        static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
    )
    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
@@ -3,7 +3,7 @@ import logging

import torch
from torch import Tensor
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, is_node_meta_valid
from torch.fx.experimental.symbolic_shapes import statically_known_true

from .. import config
@@ -83,10 +83,6 @@ def should_decompose_mm(mat1, mat2) -> bool:
    )


def is_node_meta_valid(node: torch.fx.Node):
    return "val" in node.meta


def print_decompose_pattern(match: Match, inputs: list[torch.fx.Node]):
    node = match.nodes[-1]
    log.debug(
@@ -7,7 +7,7 @@ from collections.abc import Iterable, Iterator
from typing import Any, Optional

import torch
from torch._dynamo.utils import counters
from torch._dynamo.utils import counters, is_node_meta_valid
from torch._logging import trace_structured
from torch.fx.passes.graph_transform_observer import GraphTransformObserver
from torch.utils._ordered_set import OrderedSet
@@ -18,6 +18,7 @@ from ..pattern_matcher import (
    get_arg_value,
    stable_topological_sort,
)
from ..utils import OPTIMUS_EXCLUDE_POST_GRAD


try:
@@ -574,10 +575,6 @@ class BatchLinearLHSFusion(BatchFusion):
        counters["inductor"]["batch_linear_lhs"] += 1


def is_node_meta_valid(node: Optional[torch.fx.Node]):
    return node is None or "example_value" in node.meta or "val" in node.meta


# Poor person's check for if a node in the graph mutates its input.
# (the graph is torch IR, so we will see torch fns and python operators)
def _is_mutable_node(tgt):
@@ -1403,7 +1400,10 @@ def group_batch_fusion_passes(graph: torch.fx.Graph, pre_grad=True):
        fbgemm_fusion_keys = [
            x
            for x in config.post_grad_fusion_options
            if config.post_grad_fusion_options[x].get("require_fbgemm", False)
            if (
                x not in OPTIMUS_EXCLUDE_POST_GRAD
                and config.post_grad_fusion_options[x].get("require_fbgemm", False)
            )
        ]
        fbgemm_fusions = {
            fusion: config.post_grad_fusion_options[fusion]
@@ -46,7 +46,13 @@ from ..pattern_matcher import (
    register_replacement,
    stable_topological_sort,
)
from ..utils import decode_device, get_gpu_type, is_gpu, is_pointwise_use
from ..utils import (
    decode_device,
    get_gpu_type,
    is_gpu,
    is_pointwise_use,
    OPTIMUS_EXCLUDE_POST_GRAD,
)
from ..virtualized import V
from .b2b_gemm import B2B_GEMM_PASS
from .ddp_fusion import fuse_ddp_communication
@@ -137,8 +143,8 @@ def post_grad_passes(gm: torch.fx.GraphModule, is_inference: bool):
            patterns.apply
        )
    for pass_name in config.post_grad_fusion_options:
        # skip all patterns for group batch fusions
        if pass_name in POST_GRAD_FUSIONS:
        # skip all patterns for group batch fusions or quantization patterns
        if pass_name in POST_GRAD_FUSIONS or pass_name in OPTIMUS_EXCLUDE_POST_GRAD:
            continue
        pattern_matcher_pass = POST_GRAD_PATTERNS[pass_name]
        inductor_before_change = save_inductor_dict(
@@ -56,6 +56,10 @@ from torch.utils._ordered_set import OrderedSet
from torch.utils._pytree import tree_map_only


OPTIMUS_EXCLUDE_POST_GRAD = [
    "activation_quantization_aten_pass",
]

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence, ValuesView

@@ -1421,6 +1421,7 @@ class Graph:
        args: Optional[tuple["Argument", ...]] = None,
        kwargs: Optional[dict[str, "Argument"]] = None,
        type_expr: Optional[Any] = None,
        name: Optional[str] = None,
    ) -> Node:
        """
        Insert a ``call_function`` ``Node`` into the ``Graph``. A ``call_function`` node
@@ -1441,6 +1442,8 @@ class Graph:
            type_expr (Optional[Any]): an optional type annotation representing the
                Python type the output of this node will have.

            name (Optional[str]): The name of the node. If not specified, set to None

        Returns:

            The newly created and inserted ``call_function`` node.
@@ -1450,7 +1453,7 @@ class Graph:
        as :meth:`Graph.create_node`.
        """
        return self.create_node(
            "call_function", the_function, args, kwargs, type_expr=type_expr
            "call_function", the_function, args, kwargs, name=name, type_expr=type_expr
        )

    @compatibility(is_backward_compatible=True)