mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
29 lines
687 B
C++
29 lines
687 B
C++
#pragma once
|
|
|
|
#include "torch/csrc/python_headers.h"

#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <ATen/ATen.h>
|
|
|
|
namespace torch {
|
|
|
|
struct TupleParser {
|
|
TupleParser(PyObject* args, int num_args=-1);
|
|
|
|
void parse(bool& x, const std::string& param_name);
|
|
void parse(int& x, const std::string& param_name);
|
|
void parse(double& x, const std::string& param_name);
|
|
void parse(std::vector<int>& x, const std::string& param_name);
|
|
void parse(std::string& x, const std::string& param_name);
|
|
|
|
protected:
|
|
std::runtime_error invalid_type(const std::string& expected, const std::string& param_name);
|
|
PyObject* next_arg();
|
|
|
|
private:
|
|
PyObject* args;
|
|
int idx;
|
|
};
|
|
|
|
} // namespace torch
|