Fix formatting issues for onnx

Summary:
These are formatting changes automatically done with `arc f` to deal with issues landing the onnx changes in this stack

{F703786210}

Test Plan: yeah_sandcastle

Reviewed By: malfet

Differential Revision: D34402111

fbshipit-source-id: 06eb352d1e4f8b1439a580148fe1060fb5c9e102
(cherry picked from commit 7bbf29ed8e)
This commit is contained in:
Eli Uriegas 2022-02-22 14:39:41 -08:00 committed by PyTorch MergeBot
parent cc2aad2ef2
commit 4267e6e55e

View File

@ -702,7 +702,8 @@ void SetShapeValueFromListConstructNode(Node* lc_node) {
}
}
std::vector<::c10::ShapeSymbol> Broadcast(const std::vector<::c10::ShapeSymbol> &input_shape_value_0,
std::vector<::c10::ShapeSymbol> Broadcast(
const std::vector<::c10::ShapeSymbol>& input_shape_value_0,
const std::vector<::c10::ShapeSymbol>& input_shape_value_1) {
size_t rank_0 = input_shape_value_0.size();
size_t rank_1 = input_shape_value_1.size();
@ -714,10 +715,8 @@ std::vector<::c10::ShapeSymbol> Broadcast(const std::vector<::c10::ShapeSymbol>
final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
}
for (auto idx = 0; idx < rank_min; idx++) {
const c10::ShapeSymbol& ss_shape_0 =
input_shape_value_0[rank_0 - 1 - idx];
const c10::ShapeSymbol& ss_shape_1 =
input_shape_value_1[rank_1 - 1 - idx];
const c10::ShapeSymbol& ss_shape_0 = input_shape_value_0[rank_0 - 1 - idx];
const c10::ShapeSymbol& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx];
bool is_static_0 = ss_shape_0.is_static();
bool is_static_1 = ss_shape_1.is_static();
if (is_static_0 && is_static_1) {
@ -878,14 +877,17 @@ void ProcessMatMulNode(Node* n) {
is_rank_1_1 = true;
}
// Per https://pytorch.org/docs/stable/generated/torch.matmul.html
// the broadcasting logic only applies to the batch dimensions, and not the matrix dimensions
// so we remove the matrix dimensions which are the last 2 dimensions before broadcasting
// the broadcasting logic only applies to the batch dimensions, and not the
// matrix dimensions so we remove the matrix dimensions which are the last 2
// dimensions before broadcasting
auto final_shape = Broadcast(
std::vector<::c10::ShapeSymbol>(input_shape_value_0.begin(), input_shape_value_0.end() - 2),
std::vector<::c10::ShapeSymbol>(input_shape_value_1.begin(), input_shape_value_1.end() - 2)
);
// add the last 2 dimensions back, unless they do not exist in the first place and inserted by this function
// Then apply [n,k]X[k,m]=[n,m], where n=input_shape_value_0[rank_0 - 2], m=input_shape_value_1[rank_1 - 1]
std::vector<::c10::ShapeSymbol>(
input_shape_value_0.begin(), input_shape_value_0.end() - 2),
std::vector<::c10::ShapeSymbol>(
input_shape_value_1.begin(), input_shape_value_1.end() - 2));
// add the last 2 dimensions back, unless they do not exist in the first
// place and inserted by this function Then apply [n,k]X[k,m]=[n,m], where
// n=input_shape_value_0[rank_0 - 2], m=input_shape_value_1[rank_1 - 1]
if (!is_rank_0_1) {
final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
}