Fix formatting issues for onnx
Summary:
These are formatting changes made automatically with `arc f` to resolve issues with landing the ONNX changes in this stack.
{F703786210}
Test Plan: yeah_sandcastle
Reviewed By: malfet
Differential Revision: D34402111
fbshipit-source-id: 06eb352d1e4f8b1439a580148fe1060fb5c9e102
(cherry picked from commit 7bbf29ed8e)
parent cc2aad2ef2
commit 4267e6e55e
@@ -702,7 +702,8 @@ void SetShapeValueFromListConstructNode(Node* lc_node) {
   }
 }
 
-std::vector<::c10::ShapeSymbol> Broadcast(const std::vector<::c10::ShapeSymbol> &input_shape_value_0,
+std::vector<::c10::ShapeSymbol> Broadcast(
+    const std::vector<::c10::ShapeSymbol>& input_shape_value_0,
     const std::vector<::c10::ShapeSymbol>& input_shape_value_1) {
   size_t rank_0 = input_shape_value_0.size();
   size_t rank_1 = input_shape_value_1.size();
@@ -714,10 +715,8 @@ std::vector<::c10::ShapeSymbol> Broadcast(const std::vector<::c10::ShapeSymbol>
     final_shape.emplace_back(::c10::ShapeSymbol::newSymbol());
   }
   for (auto idx = 0; idx < rank_min; idx++) {
-    const c10::ShapeSymbol& ss_shape_0 =
-        input_shape_value_0[rank_0 - 1 - idx];
-    const c10::ShapeSymbol& ss_shape_1 =
-        input_shape_value_1[rank_1 - 1 - idx];
+    const c10::ShapeSymbol& ss_shape_0 = input_shape_value_0[rank_0 - 1 - idx];
+    const c10::ShapeSymbol& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx];
     bool is_static_0 = ss_shape_0.is_static();
     bool is_static_1 = ss_shape_1.is_static();
     if (is_static_0 && is_static_1) {
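For readers following the diff outside the PyTorch tree, here is a minimal standalone sketch of the symbolic broadcasting that Broadcast performs. Sym is a hypothetical stand-in for c10::ShapeSymbol, and the merge rule is simplified: a real implementation would verify that two static dims are compatible (equal, or one of them 1) and give every fresh symbol a distinct identity.

// sketch_broadcast.cpp -- illustrative only, not the PyTorch source.
#include <algorithm>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for c10::ShapeSymbol: v >= 0 is a static dim,
// v < 0 marks a symbolic (unknown) dim.
struct Sym {
  long v;
  bool is_static() const { return v >= 0; }
  static Sym newSymbol() { return Sym{-1}; }
};

std::vector<Sym> Broadcast(
    const std::vector<Sym>& input_shape_value_0,
    const std::vector<Sym>& input_shape_value_1) {
  size_t rank_0 = input_shape_value_0.size();
  size_t rank_1 = input_shape_value_1.size();
  size_t rank_max = std::max(rank_0, rank_1);
  size_t rank_min = std::min(rank_0, rank_1);
  // Start with every output dim unknown.
  std::vector<Sym> final_shape(rank_max, Sym::newSymbol());
  // Leading dims present only in the higher-rank input carry over.
  const std::vector<Sym>& longer =
      rank_0 >= rank_1 ? input_shape_value_0 : input_shape_value_1;
  for (size_t idx = 0; idx < rank_max - rank_min; idx++) {
    final_shape[idx] = longer[idx];
  }
  // Align the two shapes on their trailing dims and merge right-to-left,
  // mirroring the indexing in the hunk above.
  for (size_t idx = 0; idx < rank_min; idx++) {
    const Sym& ss_shape_0 = input_shape_value_0[rank_0 - 1 - idx];
    const Sym& ss_shape_1 = input_shape_value_1[rank_1 - 1 - idx];
    if (ss_shape_0.is_static() && ss_shape_1.is_static()) {
      // Simplified numpy rule: dims must be equal or one must be 1, so
      // the broadcast result is the larger of the two.
      final_shape[rank_max - 1 - idx] =
          Sym{std::max(ss_shape_0.v, ss_shape_1.v)};
    }
    // If either side is symbolic, the output dim stays a fresh symbol.
  }
  return final_shape;
}

int main() {
  // [2, 3, 1] broadcast [5] -> [2, 3, 5]
  for (const Sym& s : Broadcast({{2}, {3}, {1}}, {{5}})) {
    std::printf("%ld ", s.v);
  }
  std::printf("\n");
}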
@@ -878,14 +877,17 @@ void ProcessMatMulNode(Node* n) {
     is_rank_1_1 = true;
   }
   // Per https://pytorch.org/docs/stable/generated/torch.matmul.html
-  // the broadcasting logic only applies to the batch dimensions, and not the matrix dimensions
-  // so we remove the matrix dimensions which are the last 2 dimensions before broadcasting
+  // the broadcasting logic only applies to the batch dimensions, and not the
+  // matrix dimensions so we remove the matrix dimensions which are the last 2
+  // dimensions before broadcasting
   auto final_shape = Broadcast(
-      std::vector<::c10::ShapeSymbol>(input_shape_value_0.begin(), input_shape_value_0.end() - 2),
-      std::vector<::c10::ShapeSymbol>(input_shape_value_1.begin(), input_shape_value_1.end() - 2)
-  );
-  // add the last 2 dimensions back, unless they do not exist in the first place and inserted by this function
-  // Then apply [n,k]X[k,m]=[n,m], where n=input_shape_value_0[rank_0 - 2], m=input_shape_value_1[rank_1 - 1]
+      std::vector<::c10::ShapeSymbol>(
+          input_shape_value_0.begin(), input_shape_value_0.end() - 2),
+      std::vector<::c10::ShapeSymbol>(
+          input_shape_value_1.begin(), input_shape_value_1.end() - 2));
+  // add the last 2 dimensions back, unless they do not exist in the first
+  // place and inserted by this function Then apply [n,k]X[k,m]=[n,m], where
+  // n=input_shape_value_0[rank_0 - 2], m=input_shape_value_1[rank_1 - 1]
   if (!is_rank_0_1) {
     final_shape.emplace_back(input_shape_value_0[rank_0 - 2]);
   }
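The rewrapped comment block encodes the interesting part of ProcessMatMulNode: rank-1 inputs are promoted, broadcasting applies only to the batch dimensions (everything but the last two), and [n,k]x[k,m]=[n,m] adds the matrix dims back unless the promotion inserted them. Below is a compact sketch of that rule using plain longs with -1 for an unknown dim; MatMulShape and BroadcastBatch are illustrative names, not actual PyTorch helpers.

// sketch_matmul_shape.cpp -- illustrative only; -1 marks an unknown dim.
#include <algorithm>
#include <cstdio>
#include <vector>

// numpy-style broadcast of batch dims; unknown (-1) stays unknown. A real
// implementation would also reject incompatible static dims.
static std::vector<long> BroadcastBatch(std::vector<long> a,
                                        std::vector<long> b) {
  if (a.size() < b.size()) std::swap(a, b);
  for (size_t i = 0; i < b.size(); i++) {
    long& x = a[a.size() - 1 - i];
    long y = b[b.size() - 1 - i];
    x = (x < 0 || y < 0) ? -1 : std::max(x, y);
  }
  return a;
}

std::vector<long> MatMulShape(std::vector<long> s0, std::vector<long> s1) {
  bool is_rank_0_1 = s0.size() == 1;
  bool is_rank_1_1 = s1.size() == 1;
  if (is_rank_0_1) s0.insert(s0.begin(), 1);  // [k] -> [1, k]
  if (is_rank_1_1) s1.push_back(1);           // [k] -> [k, 1]
  // Broadcasting applies only to the batch dims, so strip the last 2
  // (matrix) dims before broadcasting, as the comment in the diff says.
  auto final_shape = BroadcastBatch(
      std::vector<long>(s0.begin(), s0.end() - 2),
      std::vector<long>(s1.begin(), s1.end() - 2));
  // Apply [n, k] x [k, m] = [n, m], adding n and m back unless they were
  // inserted by the rank-1 promotion above.
  if (!is_rank_0_1) final_shape.push_back(s0[s0.size() - 2]);  // n
  if (!is_rank_1_1) final_shape.push_back(s1[s1.size() - 1]);  // m
  return final_shape;
}

int main() {
  // [10, 3, 4] x [4, 5] -> [10, 3, 5]
  for (long d : MatMulShape({10, 3, 4}, {4, 5})) {
    std::printf("%ld ", d);
  }
  std::printf("\n");
}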