Revert "[Inductor][CPU] Fuse SmoothQuant int8 linear pattern (#139595)"
This reverts commit d72a308e77.
Reverted https://github.com/pytorch/pytorch/pull/139595 on behalf of https://github.com/ZainRizvi due to Sorry but the newly added tests in test_mkldnn_pattern_matcher.py fail internally. See D65661038 for more details ([comment](https://github.com/pytorch/pytorch/pull/139595#issuecomment-2465797016))
This commit is contained in:
parent 80d0356b11
commit a7724518c0
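For context, the SmoothQuant int8 linear pattern being un-fused by this revert is the subgraph Torchao emits around torch._int_mm. Below is a minimal eager-mode sketch of the no-bias form, distilled from the reverted test (test_smooth_quant_with_int_mm) and the pattern docstrings in the diff; the sizes and scale values are illustrative only:

    import torch

    M, K, N = 16, 32, 64                                      # sizes used by the reverted test
    a = torch.randint(-32, 31, (1, M, K), dtype=torch.int8)   # int8 activation
    b = torch.randint(-32, 31, (K, N), dtype=torch.int8)      # int8 weight
    a_scale = torch.rand(M, 1) * 0.01 + 0.01                  # per-row activation scale
    b_scale = torch.rand(N) * 0.01 + 0.01                     # per-output-channel weight scale

    # reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
    c = torch._int_mm(a.reshape(-1, K), b)   # int32 accumulation
    c = c.to(torch.float32)                  # convert_element_type
    c = c * a_scale.expand(c.shape)          # apply activation scale
    c = c * b_scale                          # apply weight scale
    out = c.reshape(1, M, N)

The reverted Inductor pass matched this subgraph at freezing time and replaced it with a prepacked onednn.qlinear_pointwise call. The diff below removes that pass, its test, and the supporting changes to the onednn qlinear kernel (int8 activations and empty zero-point tensors).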
@@ -932,8 +932,8 @@ static at::Tensor linear_int8_with_onednn_weight(
     c10::string_view& unary_post_op_algorithm) {
   using ideep::tensor;
   const int64_t dim = input.dim();
-  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte || input.scalar_type() == c10::ScalarType::Char,
-      "qlinear with mkldnn tensor: data type of input should be uint8 or int8 (unsigned char or char).");
+  TORCH_CHECK(input.scalar_type() == c10::ScalarType::Byte,
+      "qlinear with mkldnn tensor: data type of input should be uint8 (unsigned char).");
   TORCH_CHECK(onednn_weight.scalar_type() == c10::ScalarType::Char,
       "qlinear with mkldnn tensor: data type of weight should be int8 (char).");
   TORCH_CHECK(
@@ -1022,8 +1022,7 @@ static at::Tensor linear_int8_with_onednn_weight(
       empty_tensor;

   // Create onednn primitive
-  auto src_dtype = input.scalar_type() == c10::kByte ? ideep::data_type::u8 : ideep::data_type::s8;
-  auto src_desc = tensor::desc(src_dims, src_dtype, ideep::format_tag::any);
+  auto src_desc = tensor::desc(src_dims, ideep::data_type::u8, ideep::format_tag::any);
   auto weights_desc = packed_weight.get_desc();
   auto dst_dtype = dst.get_data_type();
   auto dst_desc = tensor::desc(dst_dims, dst_dtype, ideep::format_tag::any);
@@ -1120,14 +1119,12 @@ namespace at::native {
       torch::List<std::optional<at::Scalar>> post_op_args,
       c10::string_view post_op_algorithm) {
 #if AT_MKLDNN_ENABLED()
-    // act_zero_point.numel() == 0 for symmetric quantization
-    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
-        "onednn int8 linear: act scale/zp size should be 1/<=1");
+    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
+        "onednn int8 linear: act scale/zp size should be 1");
     static std::optional<at::Tensor> other = std::nullopt;
     static const c10::string_view binary_post_op = "none";
-    int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
     return linear_int8_with_onednn_weight(
-        act, act_scale.item().toDouble(), act_zp,
+        act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
         onednn_weight, weight_scales, weight_zero_points,
         bias, output_scale, output_zero_point, output_dtype,
         other, /*other scale*/1.0, /*other zp*/0,
@@ -1158,12 +1155,10 @@ namespace at::native {
       torch::List<std::optional<at::Scalar>> unary_post_op_args,
       c10::string_view unary_post_op_algorithm) {
 #if AT_MKLDNN_ENABLED()
-    // act_zero_point.numel() == 0 for symmetric quantization
-    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() <= 1,
-        "onednn int8 linear: act scale/zp size should be 1/<=1");
-    int64_t act_zp = act_zero_point.numel() == 1 ? act_zero_point.item().toLong() : 0;
+    TORCH_CHECK(act_scale.numel() == 1 && act_zero_point.numel() == 1,
+        "onednn int8 linear: act scale/zp size should be 1");
     return linear_int8_with_onednn_weight(
-        act, act_scale.item().toDouble(), act_zp,
+        act, act_scale.item().toDouble(), act_zero_point.item().toLong(),
         onednn_weight, weight_scales, weight_zero_points,
         bias, output_scale, output_zero_point, output_dtype,
         other, other_scale, other_zero_point,
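A note on the activation zero-point changes in the hunks above: the reverted PR let qlinear accept int8 activations and an empty act_zero_point tensor because the SmoothQuant pattern here quantizes symmetrically (the removed Inductor pass passes an empty dummy zero-point tensor), and with symmetric quantization the zero point is always 0. A minimal sketch of symmetric per-tensor int8 quantization, for illustration only:

    import torch

    x = torch.randn(16, 32)
    scale = x.abs().max() / 127.0
    # Symmetric quantization: no zero point is needed (it is implicitly 0).
    x_int8 = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    x_dequant = x_int8.to(torch.float32) * scale

The revert restores the stricter checks: uint8 activations only and act_zero_point.numel() == 1.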
@@ -145,7 +145,6 @@ class TestPatternMatcherBase(TestCase):
         dtype=None,
         is_dynamic=False,
         quantizer=None,
-        compile_options={},  # noqa: B006
     ):
         counters.clear()
         torch._dynamo.reset()
@@ -189,7 +188,7 @@ class TestPatternMatcherBase(TestCase):
         with torch.no_grad(), maybe_autocast:
             clone_inputs = self._clone_inputs(inputs)
             expected = mod(*inputs)
-            actual = torch.compile(mod, **compile_options)(*clone_inputs)
+            actual = torch.compile(mod)(*clone_inputs)
             torch.testing.assert_close(actual, expected, atol=atol, rtol=rtol)
             if matcher_count is not None:
                 self.assertEqual(
@@ -2825,94 +2824,6 @@ class TestPatternMatcher(TestPatternMatcherBase):
             rtol=0.07,
         )

-    @skipIfNoDynamoSupport
-    @skipIfNoONEDNN
-    def test_smooth_quant_with_int_mm(self):
-        r"""
-        This testcase check if we can match the SmoothQuant int8 linear pattern from Torchao.
-        The pattern is:
-          (no bias) reshape -> _int_mm -> convert_element_type -> (expand -> mul) -> mul -> reshape
-          or
-          (with bias) pattern_no_bias -> add -> reshape -> reshape
-        """
-        M = 16
-        in_feature = 32
-        out_feature = 64
-        q_min, q_max = -32, 31
-
-        class Mod(torch.nn.Module):
-            def __init__(
-                self, dtype: torch.dtype, has_bias: bool, per_channel_quant: bool
-            ):
-                super().__init__()
-                self.dtype = dtype
-                self.has_bias = has_bias
-                self.b = torch.randint(
-                    q_min, q_max, [in_feature, out_feature], dtype=torch.int8
-                )
-                self.per_channel_quant = per_channel_quant
-                a_scale_per_tensor = torch.rand([1], dtype=dtype) * 0.01 + 0.01
-                a_scale_per_channel = torch.rand([M, 1], dtype=dtype) * 0.01 + 0.01
-                self.a_scale = (
-                    a_scale_per_channel
-                    if self.per_channel_quant
-                    else a_scale_per_tensor
-                )
-                self.b_scale = torch.rand([out_feature]) * 0.01 + 0.01
-                self.b_scale = self.b_scale.to(dtype)
-                self.bias = torch.rand([out_feature], dtype=dtype) if has_bias else None
-
-            def forward(self, a):
-                out_shape = a.shape[:-1] + (self.b.size(-1),)
-                a_reshaped = a.reshape(-1, a.size(-1))
-                c = torch._int_mm(a_reshaped, self.b)
-                c = c.to(self.dtype)
-                c_shape = c.shape
-                a_scale = self.a_scale.expand(c.shape)
-                c = c * a_scale
-                c = c * self.b_scale
-                if self.has_bias:
-                    c = c.reshape([1, *list(c_shape)])
-                    c = c + self.bias
-                    c = c.reshape(c_shape)
-                c = c.reshape(out_shape)
-                return c
-
-        has_bias_list = [True, False]
-        dype_list = (
-            [torch.float, torch.bfloat16]
-            if torch.ops.mkldnn._is_mkldnn_bf16_supported()
-            else [torch.float]
-        )
-        per_channel_list = [True, False]
-        dynamic_list = [True, False]
-        for has_bias, dtype, per_channel_quant, dynamic in itertools.product(
-            has_bias_list, dype_list, per_channel_list, dynamic_list
-        ):
-            mod = Mod(dtype, has_bias, per_channel_quant).eval()
-            a = torch.randint(q_min, q_max, [1, M, in_feature], dtype=torch.int8)
-
-            def matcher_check_fn():
-                self.assertEqual(
-                    counters["inductor"]["qlinear_weight_prepack_matcher_count"], 1
-                )
-                if dynamic:
-                    nodes_count = 10 if has_bias else 7
-                else:
-                    nodes_count = 7 if has_bias else 6
-                self.assertEqual(
-                    counters["inductor"]["qlinear_weight_prepack_matcher_nodes"],
-                    nodes_count,
-                )
-
-            self._test_common(
-                mod,
-                (a,),
-                matcher_check_fn=matcher_check_fn,
-                check_autocast=dtype,
-                compile_options={"dynamic": dynamic},
-            )
-

 @dynamo_config.patch({"dynamic_shapes": True, "assume_static_by_default": False})
 class TestDynamicPatternMatcher(TestPatternMatcherBase):
@@ -2529,226 +2529,6 @@ def _register_qlinear_weight_prepack():
         )


-def _register_smooth_quant_int_mm_pattern():
-    """
-    The pattern is:
-      (no bias) reshape -> _int_mm -> convert_element_type -> (expand ->) mul -> mul -> reshape
-      or
-      (with bias) pattern_no_bias -> add (-> reshape -> reshape)
-    """
-
-    # When torch.compile'ing with dynamic=True, the expand node and the two tailing reshape nodes exist
-    # When torch.compile'ing with dynamic=False, they don't exist
-    def get_pattern_no_bias(expand_a_scale: bool):
-        return CallFunction(
-            aten.reshape.default,
-            CallFunction(
-                aten.mul.Tensor,
-                CallFunction(
-                    aten.mul.Tensor,
-                    CallFunction(
-                        prims.convert_element_type.default,
-                        CallFunction(
-                            aten._int_mm.default,
-                            CallFunction(
-                                aten.reshape.default,
-                                KeywordArg("a"),
-                                KeywordArg("in_shape"),
-                            ),
-                            KeywordArg("b"),
-                        ),
-                        KeywordArg("dtype"),
-                    ),
-                    (
-                        CallFunction(
-                            aten.expand.default,
-                            KeywordArg("x_scale"),
-                            Arg(),
-                        )
-                        if expand_a_scale
-                        else KeywordArg("x_scale")
-                    ),
-                ),
-                KeywordArg("w_scale"),
-            ),
-            KeywordArg("out_shape_no_bias"),
-        )
-
-    # for torch.compile(dynamic=False)
-    pattern_no_bias_1 = get_pattern_no_bias(expand_a_scale=False)
-    pattern_with_bias_1 = CallFunction(
-        aten.add.Tensor,
-        pattern_no_bias_1,
-        KeywordArg("bias"),
-    )
-    # for torch.compile(dynamic=True)
-    pattern_no_bias_2 = get_pattern_no_bias(expand_a_scale=True)
-    pattern_with_bias_2 = CallFunction(
-        aten.reshape.default,
-        CallFunction(
-            aten.reshape.default,
-            CallFunction(
-                aten.add.Tensor,
-                pattern_no_bias_2,
-                KeywordArg("bias"),
-            ),
-            Arg(),
-        ),
-        KeywordArg("out_shape_with_bias"),
-    )
-
-    def _validate_pattern(match: Match):
-        if len(match.nodes) not in [6, 7, 10]:
-            return False
-        if len(match.nodes) == 10:
-            # Check the two tailing reshape nodes can be fused
-            if match.nodes[9].args[1] != match.nodes[6].args[1]:
-                return False
-        if len(match.nodes) == 10 or (
-            len(match.nodes) == 7 and match.nodes[6].target is aten.add.Tensor
-        ):
-            bias_idx = 7 if len(match.nodes) == 10 else 6
-            # Check bias shape
-            bias_node = match.nodes[bias_idx].args[1]
-            if not isinstance(bias_node, torch.fx.node.Node):
-                return False
-            if len(bias_node.meta.get("tensor_meta").shape) != 1:  # type: ignore[union-attr]
-                return False
-        return True
-
-    pattern_to_pass_number = {
-        pattern_no_bias_2: 0,
-        pattern_with_bias_2: 0,
-        pattern_no_bias_1: 1,
-        pattern_with_bias_1: 1,
-    }
-    for pattern, pass_number in pattern_to_pass_number.items():
-
-        @register_freezing_graph_pattern(
-            pattern,
-            extra_check=_validate_pattern,
-            pass_number=pass_number,
-        )
-        def _int_mm_weight_prepack(match: Match, *args, **kwargs):
-            bias = kwargs.get("bias", None)
-            x = kwargs["a"]
-            weight = kwargs["b"]
-            dtype = kwargs["dtype"]
-            x_scale = kwargs["x_scale"]
-            w_scale = kwargs["w_scale"]
-            x_shape = x.meta.get("tensor_meta").shape
-            if has_free_symbols(x_shape):
-                # For dynamic shape case, we can't get activation shape ahead of runtime.
-                x_shape = None
-
-            out_node = match.output_node()
-            with match.graph.inserting_before(out_node):
-                transpose_node = match.graph.call_function(
-                    aten.permute.default, args=(weight, [1, 0])
-                )
-                contig_node = match.graph.call_function(
-                    aten.contiguous.default, args=(transpose_node,)
-                )
-                packed_weight_inputs = (
-                    contig_node,
-                    x_shape,
-                )
-                packed_weight_op = torch.ops.onednn.qlinear_prepack
-                prepack_weight_node = match.graph.call_function(
-                    packed_weight_op, args=packed_weight_inputs
-                )
-
-                dummy_zp = match.graph.call_function(aten.empty, args=([0],))
-                w_scale = match.graph.call_function(
-                    prims.convert_element_type.default, args=(w_scale, torch.float32)
-                )
-
-                x_scale_shape = x_scale.meta.get("tensor_meta").shape
-                x_scale_is_scalar = False
-                if not has_free_symbols(x_scale_shape):
-                    prod = 1
-                    for d in x_scale_shape:
-                        prod *= d
-                    x_scale_is_scalar = prod == 1
-
-                new_args: Tuple[Any, ...]
-                if x_scale_is_scalar:
-                    # in this case, we can call onednn.qlinear directly
-                    new_args = (
-                        x,
-                        x_scale,
-                        dummy_zp,  # x_zp
-                        prepack_weight_node,
-                        w_scale,
-                        dummy_zp,  # w_zp
-                        bias,
-                        1.0,  # output_scale
-                        0,  # output_zero_point
-                        dtype,  # output_dtype
-                        "none",  # post op name
-                        [],  # post op args
-                        "",  # post op algorithm
-                    )
-                    new_linear_node = match.graph.call_function(
-                        torch.ops.onednn.qlinear_pointwise.tensor, args=new_args
-                    )
-                    out_node.replace_all_uses_with(new_linear_node)
-                    new_linear_node.meta.update(out_node.meta)
-                else:
-                    # onednn.qlinear does not support per-channel quantization of x
-                    # so in this case, we have to apply x scale and add bias ourselves after qlinear
-                    x_reshaped = match.graph.call_function(
-                        aten.reshape.default, args=(x, kwargs["in_shape"])
-                    )
-                    new_args = (
-                        x_reshaped,
-                        1.0,  # x_scale
-                        0,  # x_zp
-                        prepack_weight_node,
-                        w_scale,
-                        dummy_zp,  # w_zp
-                        None,  # bias
-                        1.0,  # output_scale
-                        0,  # output_zero_point
-                        dtype,  # output_dtype
-                        "none",  # post op name
-                        [],  # post op args
-                        "",  # post op algorithm
-                    )
-                    new_linear_node = match.graph.call_function(
-                        torch.ops.onednn.qlinear_pointwise, args=new_args
-                    )
-                    # apply x scale
-                    new_out_node = match.graph.call_function(
-                        aten.mul.Tensor, args=(new_linear_node, x_scale)
-                    )
-                    # Add bias and reshape
-                    out_shape = kwargs.get(
-                        "out_shape_with_bias", kwargs["out_shape_no_bias"]
-                    )
-                    if bias is not None:
-                        new_out_node = match.graph.call_function(
-                            aten.add.Tensor, args=(new_out_node, bias)
-                        )
-                        new_out_node = match.graph.call_function(
-                            aten.reshape.default,
-                            args=(new_out_node, out_shape),
-                        )
-                    else:
-                        new_out_node = match.graph.call_function(
-                            aten.reshape.default,
-                            args=(new_out_node, out_shape),
-                        )
-                    out_node.replace_all_uses_with(new_out_node)
-                    new_out_node.meta.update(out_node.meta)
-                for node in reversed(match.nodes):
-                    match.graph.erase_node(node)
-                counters["inductor"]["qlinear_weight_prepack_matcher_count"] += 1
-                counters["inductor"]["qlinear_weight_prepack_matcher_nodes"] += len(
-                    match.nodes
-                )
-
-
 @functools.lru_cache(None)
 def _register_quantization_weight_pack_pass():
     # Step 1: Dequant promotion for int8-mixed-fp32/bf16
@@ -2760,9 +2540,6 @@ def _register_quantization_weight_pack_pass():
     # Step 3: QLinear weight prepack
     _register_qlinear_weight_prepack()

-    # Step 4: weight prepack for SmoothQuant from Torchao
-    _register_smooth_quant_int_mm_pattern()
-

 def quant_lift_up(graph_module: torch.fx.GraphModule):
     """
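One more note on the removed per-channel branch of _int_mm_weight_prepack above: onednn qlinear takes a single scalar activation scale, so when x_scale is per-row the handler called qlinear with x_scale = 1.0 and multiplied the output by the real activation scale (and added bias) afterwards. That rewrite is valid because the dequantization scales enter the result multiplicatively. A small numerical sketch of the identity, using plain float tensors in place of the int32 _int_mm accumulator (shapes illustrative):

    import torch

    M, N = 16, 64
    acc = torch.randn(M, N)      # stands in for _int_mm(a, b) cast to float
    s_a = torch.rand(M, 1)       # per-row activation scale
    s_w = torch.rand(N)          # per-output-channel weight scale
    bias = torch.rand(N)

    # Scales applied inside the matched subgraph:
    ref = acc * s_a.expand(acc.shape) * s_w + bias
    # Replacement order: "linear" part with x_scale = 1.0, then x scale and bias:
    out = (acc * s_w) * s_a + bias

    torch.testing.assert_close(out, ref)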