mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44840 Pull Request resolved: https://github.com/pytorch/pytorch/pull/44762 Move CostInferenceForFCGradient to fc_inference.cc/h to be used in multiple .cc files. Test Plan: CI Reviewed By: qizzzh Differential Revision: D23714877 fbshipit-source-id: d27f33e270a93b0e053f2af592dc4a24e35526cd
29 lines
775 B
C++
29 lines
775 B
C++
#pragma once
|
|
#include "caffe2/core/context.h"
|
|
#include "caffe2/core/operator.h"
|
|
#include "caffe2/utils/conversions.h"
|
|
#include "caffe2/utils/math.h"
|
|
|
|
namespace caffe2 {
|
|
std::vector<TensorShape> FCShapeInference(
|
|
const OperatorDef& def,
|
|
const std::vector<TensorShape>& in,
|
|
bool pretransposed_weight);
|
|
|
|
OpSchema::Cost CostInferenceForFC(
|
|
const OperatorDef& def,
|
|
const std::vector<TensorShape>& in,
|
|
bool pretransposed_weight = false);
|
|
|
|
std::vector<TensorShape> FCGradientShapeInference(
|
|
const OperatorDef& def,
|
|
const std::vector<TensorShape>& in,
|
|
bool pretransposed_weight);
|
|
|
|
OpSchema::Cost CostInferenceForFCGradient(
|
|
const OperatorDef& def,
|
|
const std::vector<TensorShape>& in,
|
|
bool pretransposed_weight);
|
|
|
|
} // namespace caffe2
|