#include #include #include #include #include #include namespace torch { namespace optim { namespace detail { OptimizerBase::OptimizerBase(std::vector parameters) : parameters_(std::move(parameters)) {} OptimizerBase::OptimizerBase(const ParameterCursor& cursor) { add_parameters(cursor); } void OptimizerBase::add_parameters(const std::vector& parameters) { parameters_.insert(parameters_.end(), parameters.begin(), parameters.end()); } void OptimizerBase::add_parameters(const ParameterCursor& cursor) { std::vector tensors(cursor.size()); cursor.map(tensors.begin(), [](const Tensor& tensor) { return tensor; }); add_parameters(tensors); } void OptimizerBase::zero_grad() { for (auto& parameter : parameters_) { if (parameter.grad().defined()) { parameter.grad().detach_(); parameter.grad().zero_(); } } } const std::vector& OptimizerBase::parameters() const noexcept { return parameters_; } std::vector& OptimizerBase::parameters() noexcept { return parameters_; } size_t OptimizerBase::size() const noexcept { return parameters_.size(); } void OptimizerBase::save(serialize::OutputArchive& archive) const {} void OptimizerBase::load(serialize::InputArchive& archive) {} } // namespace detail } // namespace optim } // namespace torch