mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
[3/N] Add -Wdeprecated and related fixes (#109698)
This PR follows #108626. Hopefully we can enable the warning in the next PR. Pull Request resolved: https://github.com/pytorch/pytorch/pull/109698 Approved by: https://github.com/Skylion007, https://github.com/ezyang
This commit is contained in:
parent
836ba6430a
commit
c31fcdaa4f
|
|
@ -15,12 +15,14 @@ class VulkanPackedContext {
|
||||||
|
|
||||||
public:
|
public:
|
||||||
VulkanPackedContext() : packed_{c10::AnyType::get()} {}
|
VulkanPackedContext() : packed_{c10::AnyType::get()} {}
|
||||||
|
VulkanPackedContext(const VulkanPackedContext&) = default;
|
||||||
|
VulkanPackedContext(VulkanPackedContext&&) = default;
|
||||||
|
|
||||||
inline const c10::IValue get_val(int64_t i) const {
|
inline const c10::IValue get_val(int64_t i) const {
|
||||||
return packed_.get(i);
|
return packed_.get(i);
|
||||||
}
|
}
|
||||||
|
|
||||||
inline void set_val(int64_t i, c10::IValue val) const {
|
inline void set_val(int64_t i, const c10::IValue& val) const {
|
||||||
return packed_.set(i, val);
|
return packed_.set(i, val);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -62,6 +62,14 @@ namespace impl {
|
||||||
* those uses will be devirtualized.
|
* those uses will be devirtualized.
|
||||||
*/
|
*/
|
||||||
struct C10_API DeviceGuardImplInterface {
|
struct C10_API DeviceGuardImplInterface {
|
||||||
|
DeviceGuardImplInterface() = default;
|
||||||
|
DeviceGuardImplInterface(const DeviceGuardImplInterface&) = default;
|
||||||
|
DeviceGuardImplInterface& operator=(const DeviceGuardImplInterface&) =
|
||||||
|
default;
|
||||||
|
DeviceGuardImplInterface(DeviceGuardImplInterface&&) noexcept = default;
|
||||||
|
DeviceGuardImplInterface& operator=(DeviceGuardImplInterface&&) noexcept =
|
||||||
|
default;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Return the type of device managed by this guard implementation.
|
* Return the type of device managed by this guard implementation.
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,11 @@ class TORCH_API IMethod {
|
||||||
using IValueList = std::vector<c10::IValue>;
|
using IValueList = std::vector<c10::IValue>;
|
||||||
using IValueMap = std::unordered_map<std::string, at::IValue>;
|
using IValueMap = std::unordered_map<std::string, at::IValue>;
|
||||||
|
|
||||||
|
IMethod() = default;
|
||||||
|
IMethod(const IMethod&) = default;
|
||||||
|
IMethod& operator=(const IMethod&) = default;
|
||||||
|
IMethod(IMethod&&) noexcept = default;
|
||||||
|
IMethod& operator=(IMethod&&) noexcept = default;
|
||||||
virtual ~IMethod() = default;
|
virtual ~IMethod() = default;
|
||||||
|
|
||||||
virtual c10::IValue operator()(
|
virtual c10::IValue operator()(
|
||||||
|
|
|
||||||
|
|
@ -21,7 +21,7 @@ namespace nn {
|
||||||
/// because then storing a module would always require templatizing it.
|
/// because then storing a module would always require templatizing it.
|
||||||
template <typename Derived>
|
template <typename Derived>
|
||||||
// NOLINTNEXTLINE(bugprone-exception-escape)
|
// NOLINTNEXTLINE(bugprone-exception-escape)
|
||||||
class Cloneable : public virtual Module {
|
class Cloneable : public Module {
|
||||||
public:
|
public:
|
||||||
using Module::Module;
|
using Module::Module;
|
||||||
|
|
||||||
|
|
@ -90,7 +90,7 @@ class Cloneable : public virtual Module {
|
||||||
clone != nullptr,
|
clone != nullptr,
|
||||||
"Attempted to clone submodule, but it is of a "
|
"Attempted to clone submodule, but it is of a "
|
||||||
"different type than the submodule it was to be cloned into");
|
"different type than the submodule it was to be cloned into");
|
||||||
static_cast<Derived&>(*this) = std::move(*clone);
|
static_cast<Derived&>(*this) = *clone;
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -81,6 +81,10 @@ class TORCH_API Module : public std::enable_shared_from_this<Module> {
|
||||||
/// The name of the submodule is inferred via RTTI (if possible) the first
|
/// The name of the submodule is inferred via RTTI (if possible) the first
|
||||||
/// time `.name()` is invoked.
|
/// time `.name()` is invoked.
|
||||||
Module();
|
Module();
|
||||||
|
Module(const Module&) = default;
|
||||||
|
Module& operator=(const Module&) = default;
|
||||||
|
Module(Module&&) noexcept = default;
|
||||||
|
Module& operator=(Module&&) noexcept = default;
|
||||||
|
|
||||||
virtual ~Module() = default;
|
virtual ~Module() = default;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -137,7 +137,23 @@ class BatchNormImplBase : public NormImplBase<D, Derived, BatchNormOptions> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
|
/// Pretty prints the `BatchNorm{1,2,3}d` module into the given `stream`.
|
||||||
void pretty_print(std::ostream& stream) const override;
|
void pretty_print(std::ostream& stream) const override {
|
||||||
|
stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d("
|
||||||
|
<< this->options.num_features() << ", "
|
||||||
|
<< "eps=" << this->options.eps() << ", "
|
||||||
|
<< "momentum=";
|
||||||
|
|
||||||
|
if (this->options.momentum().has_value()) {
|
||||||
|
stream << this->options.momentum().value();
|
||||||
|
} else {
|
||||||
|
stream << "None";
|
||||||
|
}
|
||||||
|
|
||||||
|
stream << ", "
|
||||||
|
<< "affine=" << this->options.affine() << ", "
|
||||||
|
<< "track_running_stats=" << this->options.track_running_stats()
|
||||||
|
<< ")";
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ BatchNorm1d
|
||||||
|
|
|
||||||
|
|
@ -91,6 +91,8 @@ class AnyValue {
|
||||||
struct Placeholder {
|
struct Placeholder {
|
||||||
explicit Placeholder(const std::type_info& type_info_) noexcept
|
explicit Placeholder(const std::type_info& type_info_) noexcept
|
||||||
: type_info(type_info_) {}
|
: type_info(type_info_) {}
|
||||||
|
Placeholder(const Placeholder&) = default;
|
||||||
|
Placeholder(Placeholder&&) = default;
|
||||||
virtual ~Placeholder() = default;
|
virtual ~Placeholder() = default;
|
||||||
virtual std::unique_ptr<Placeholder> clone() const {
|
virtual std::unique_ptr<Placeholder> clone() const {
|
||||||
TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
|
TORCH_CHECK(false, "clone() should only be called on `AnyValue::Holder`");
|
||||||
|
|
|
||||||
|
|
@ -45,7 +45,15 @@ class InstanceNormImpl
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
|
/// Pretty prints the `InstanceNorm{1,2,3}d` module into the given `stream`.
|
||||||
void pretty_print(std::ostream& stream) const override;
|
void pretty_print(std::ostream& stream) const override {
|
||||||
|
stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
|
||||||
|
<< this->options.num_features() << ", "
|
||||||
|
<< "eps=" << this->options.eps() << ", "
|
||||||
|
<< "momentum=" << this->options.momentum() << ", "
|
||||||
|
<< "affine=" << this->options.affine() << ", "
|
||||||
|
<< "track_running_stats=" << this->options.track_running_stats()
|
||||||
|
<< ")";
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
|
// ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ InstanceNorm1d
|
||||||
|
|
|
||||||
|
|
@ -164,8 +164,6 @@ struct TORCH_API RNNCellOptionsBase {
|
||||||
int64_t hidden_size,
|
int64_t hidden_size,
|
||||||
bool bias,
|
bool bias,
|
||||||
int64_t num_chunks);
|
int64_t num_chunks);
|
||||||
virtual ~RNNCellOptionsBase() = default;
|
|
||||||
|
|
||||||
TORCH_ARG(int64_t, input_size);
|
TORCH_ARG(int64_t, input_size);
|
||||||
TORCH_ARG(int64_t, hidden_size);
|
TORCH_ARG(int64_t, hidden_size);
|
||||||
TORCH_ARG(bool, bias);
|
TORCH_ARG(bool, bias);
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ struct TORCH_API AdagradOptions
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdagradOptions& lhs,
|
const AdagradOptions& lhs,
|
||||||
const AdagradOptions& rhs);
|
const AdagradOptions& rhs);
|
||||||
~AdagradOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -45,12 +44,16 @@ struct TORCH_API AdagradParamState
|
||||||
TORCH_ARG(int64_t, step) = 0;
|
TORCH_ARG(int64_t, step) = 0;
|
||||||
|
|
||||||
public:
|
public:
|
||||||
|
AdagradParamState() = default;
|
||||||
|
AdagradParamState(const AdagradParamState&) = default;
|
||||||
|
AdagradParamState& operator=(const AdagradParamState&) = default;
|
||||||
|
AdagradParamState(AdagradParamState&&) noexcept = default;
|
||||||
|
AdagradParamState& operator=(AdagradParamState&&) noexcept = default;
|
||||||
void serialize(torch::serialize::InputArchive& archive) override;
|
void serialize(torch::serialize::InputArchive& archive) override;
|
||||||
void serialize(torch::serialize::OutputArchive& archive) const override;
|
void serialize(torch::serialize::OutputArchive& archive) const override;
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdagradParamState& lhs,
|
const AdagradParamState& lhs,
|
||||||
const AdagradParamState& rhs);
|
const AdagradParamState& rhs);
|
||||||
~AdagradParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API Adagrad : public Optimizer {
|
class TORCH_API Adagrad : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ struct TORCH_API AdamOptions : public OptimizerCloneableOptions<AdamOptions> {
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdamOptions& lhs,
|
const AdamOptions& lhs,
|
||||||
const AdamOptions& rhs);
|
const AdamOptions& rhs);
|
||||||
~AdamOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -50,7 +49,6 @@ struct TORCH_API AdamParamState
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdamParamState& lhs,
|
const AdamParamState& lhs,
|
||||||
const AdamParamState& rhs);
|
const AdamParamState& rhs);
|
||||||
~AdamParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API Adam : public Optimizer {
|
class TORCH_API Adam : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -32,7 +32,6 @@ struct TORCH_API AdamWOptions : public OptimizerCloneableOptions<AdamWOptions> {
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdamWOptions& lhs,
|
const AdamWOptions& lhs,
|
||||||
const AdamWOptions& rhs);
|
const AdamWOptions& rhs);
|
||||||
~AdamWOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -50,7 +49,6 @@ struct TORCH_API AdamWParamState
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const AdamWParamState& lhs,
|
const AdamWParamState& lhs,
|
||||||
const AdamWParamState& rhs);
|
const AdamWParamState& rhs);
|
||||||
~AdamWParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API AdamW : public Optimizer {
|
class TORCH_API AdamW : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ struct TORCH_API LBFGSOptions : public OptimizerCloneableOptions<LBFGSOptions> {
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const LBFGSOptions& lhs,
|
const LBFGSOptions& lhs,
|
||||||
const LBFGSOptions& rhs);
|
const LBFGSOptions& rhs);
|
||||||
~LBFGSOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -54,7 +53,6 @@ struct TORCH_API LBFGSParamState
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const LBFGSParamState& lhs,
|
const LBFGSParamState& lhs,
|
||||||
const LBFGSParamState& rhs);
|
const LBFGSParamState& rhs);
|
||||||
~LBFGSParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API LBFGS : public Optimizer {
|
class TORCH_API LBFGS : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -34,6 +34,11 @@ namespace optim {
|
||||||
|
|
||||||
class TORCH_API OptimizerParamState {
|
class TORCH_API OptimizerParamState {
|
||||||
public:
|
public:
|
||||||
|
OptimizerParamState() = default;
|
||||||
|
OptimizerParamState(const OptimizerParamState&) = default;
|
||||||
|
OptimizerParamState& operator=(const OptimizerParamState&) = default;
|
||||||
|
OptimizerParamState(OptimizerParamState&&) noexcept = default;
|
||||||
|
OptimizerParamState& operator=(OptimizerParamState&&) noexcept = default;
|
||||||
virtual std::unique_ptr<OptimizerParamState> clone() const;
|
virtual std::unique_ptr<OptimizerParamState> clone() const;
|
||||||
virtual void serialize(torch::serialize::InputArchive& archive);
|
virtual void serialize(torch::serialize::InputArchive& archive);
|
||||||
virtual void serialize(torch::serialize::OutputArchive& archive) const;
|
virtual void serialize(torch::serialize::OutputArchive& archive) const;
|
||||||
|
|
@ -49,6 +54,11 @@ class OptimizerCloneableParamState : public OptimizerParamState {
|
||||||
|
|
||||||
class TORCH_API OptimizerOptions {
|
class TORCH_API OptimizerOptions {
|
||||||
public:
|
public:
|
||||||
|
OptimizerOptions() = default;
|
||||||
|
OptimizerOptions(const OptimizerOptions&) = default;
|
||||||
|
OptimizerOptions& operator=(const OptimizerOptions&) = default;
|
||||||
|
OptimizerOptions(OptimizerOptions&&) noexcept = default;
|
||||||
|
OptimizerOptions& operator=(OptimizerOptions&&) noexcept = default;
|
||||||
virtual std::unique_ptr<OptimizerOptions> clone() const;
|
virtual std::unique_ptr<OptimizerOptions> clone() const;
|
||||||
virtual void serialize(torch::serialize::InputArchive& archive);
|
virtual void serialize(torch::serialize::InputArchive& archive);
|
||||||
virtual void serialize(torch::serialize::OutputArchive& archive) const;
|
virtual void serialize(torch::serialize::OutputArchive& archive) const;
|
||||||
|
|
|
||||||
|
|
@ -37,7 +37,6 @@ struct TORCH_API RMSpropOptions
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const RMSpropOptions& lhs,
|
const RMSpropOptions& lhs,
|
||||||
const RMSpropOptions& rhs);
|
const RMSpropOptions& rhs);
|
||||||
~RMSpropOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -55,7 +54,6 @@ struct TORCH_API RMSpropParamState
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const RMSpropParamState& lhs,
|
const RMSpropParamState& lhs,
|
||||||
const RMSpropParamState& rhs);
|
const RMSpropParamState& rhs);
|
||||||
~RMSpropParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API RMSprop : public Optimizer {
|
class TORCH_API RMSprop : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -34,7 +34,6 @@ struct TORCH_API SGDOptions : public OptimizerCloneableOptions<SGDOptions> {
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const SGDOptions& lhs,
|
const SGDOptions& lhs,
|
||||||
const SGDOptions& rhs);
|
const SGDOptions& rhs);
|
||||||
~SGDOptions() override = default;
|
|
||||||
double get_lr() const override;
|
double get_lr() const override;
|
||||||
void set_lr(const double lr) override;
|
void set_lr(const double lr) override;
|
||||||
};
|
};
|
||||||
|
|
@ -49,7 +48,6 @@ struct TORCH_API SGDParamState
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const SGDParamState& lhs,
|
const SGDParamState& lhs,
|
||||||
const SGDParamState& rhs);
|
const SGDParamState& rhs);
|
||||||
~SGDParamState() override = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
class TORCH_API SGD : public Optimizer {
|
class TORCH_API SGD : public Optimizer {
|
||||||
|
|
|
||||||
|
|
@ -432,7 +432,7 @@ void ThresholdImpl::pretty_print(std::ostream& stream) const {
|
||||||
|
|
||||||
MultiheadAttentionImpl::MultiheadAttentionImpl(
|
MultiheadAttentionImpl::MultiheadAttentionImpl(
|
||||||
const MultiheadAttentionOptions& options_)
|
const MultiheadAttentionOptions& options_)
|
||||||
: Module("torch::nn::MultiheadAttention"), options(options_) {
|
: Cloneable("torch::nn::MultiheadAttention"), options(options_) {
|
||||||
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
// NOLINTNEXTLINE(clang-analyzer-optin.cplusplus.VirtualCall)
|
||||||
reset();
|
reset();
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -14,25 +14,6 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace nn {
|
namespace nn {
|
||||||
|
|
||||||
template <size_t D, typename Derived>
|
|
||||||
void BatchNormImplBase<D, Derived>::pretty_print(std::ostream& stream) const {
|
|
||||||
stream << std::boolalpha << "torch::nn::BatchNorm" << D << "d("
|
|
||||||
<< this->options.num_features() << ", "
|
|
||||||
<< "eps=" << this->options.eps() << ", "
|
|
||||||
<< "momentum=";
|
|
||||||
|
|
||||||
if (this->options.momentum().has_value()) {
|
|
||||||
stream << this->options.momentum().value();
|
|
||||||
} else {
|
|
||||||
stream << "None";
|
|
||||||
}
|
|
||||||
|
|
||||||
stream << ", "
|
|
||||||
<< "affine=" << this->options.affine() << ", "
|
|
||||||
<< "track_running_stats=" << this->options.track_running_stats()
|
|
||||||
<< ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
|
void BatchNorm1dImpl::_check_input_dim(const Tensor& input) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
input.dim() == 2 || input.dim() == 3,
|
input.dim() == 2 || input.dim() == 3,
|
||||||
|
|
|
||||||
|
|
@ -4,17 +4,6 @@
|
||||||
namespace torch {
|
namespace torch {
|
||||||
namespace nn {
|
namespace nn {
|
||||||
|
|
||||||
template <size_t D, typename Derived>
|
|
||||||
void InstanceNormImpl<D, Derived>::pretty_print(std::ostream& stream) const {
|
|
||||||
stream << std::boolalpha << "torch::nn::InstanceNorm" << D << "d("
|
|
||||||
<< this->options.num_features() << ", "
|
|
||||||
<< "eps=" << this->options.eps() << ", "
|
|
||||||
<< "momentum=" << this->options.momentum() << ", "
|
|
||||||
<< "affine=" << this->options.affine() << ", "
|
|
||||||
<< "track_running_stats=" << this->options.track_running_stats()
|
|
||||||
<< ")";
|
|
||||||
}
|
|
||||||
|
|
||||||
void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
|
void InstanceNorm1dImpl::_check_input_dim(const Tensor& input) {
|
||||||
if (input.dim() != 3 && input.dim() != 2) {
|
if (input.dim() != 3 && input.dim() != 2) {
|
||||||
TORCH_CHECK(
|
TORCH_CHECK(
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,9 @@ class TORCH_API Store : public torch::CustomClassHolder {
|
||||||
explicit Store(const std::chrono::milliseconds& timeout)
|
explicit Store(const std::chrono::milliseconds& timeout)
|
||||||
: timeout_(timeout) {}
|
: timeout_(timeout) {}
|
||||||
|
|
||||||
|
Store(const Store&) = default;
|
||||||
|
Store(Store&&) noexcept = default;
|
||||||
|
|
||||||
~Store() override = default;
|
~Store() override = default;
|
||||||
|
|
||||||
void set(const std::string& key, const std::string& value);
|
void set(const std::string& key, const std::string& value);
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,7 @@ class PYBIND11_EXPORT PyRRef {
|
||||||
// for more explanations.
|
// for more explanations.
|
||||||
explicit PyRRef(const py::object& value, const py::object& type_hint);
|
explicit PyRRef(const py::object& value, const py::object& type_hint);
|
||||||
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
|
explicit PyRRef(c10::intrusive_ptr<RRef> rref);
|
||||||
|
PyRRef(const PyRRef&) = default;
|
||||||
~PyRRef();
|
~PyRRef();
|
||||||
|
|
||||||
bool isOwner() const;
|
bool isOwner() const;
|
||||||
|
|
|
||||||
|
|
@ -142,8 +142,7 @@ std::unique_ptr<RRefUserDelete> RRefUserDelete::fromMessage(
|
||||||
const Message& message) {
|
const Message& message) {
|
||||||
auto pair =
|
auto pair =
|
||||||
ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
|
ForkMessageBase::fromMessage(message, MessageType::RREF_USER_DELETE);
|
||||||
return std::make_unique<RRefUserDelete>(
|
return std::make_unique<RRefUserDelete>(pair.first, pair.second);
|
||||||
RRefUserDelete(pair.first, pair.second));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
|
std::unique_ptr<RemoteRet> RemoteRet::fromMessage(const Message& message) {
|
||||||
|
|
|
||||||
|
|
@ -161,11 +161,9 @@ C10_DEFINE_REGISTRY_WITHOUT_WARNING(
|
||||||
|
|
||||||
const std::string& TensorPipeAgent::guessAddress() {
|
const std::string& TensorPipeAgent::guessAddress() {
|
||||||
static const std::string uvAddress = []() {
|
static const std::string uvAddress = []() {
|
||||||
tensorpipe::Error error;
|
|
||||||
std::string result;
|
|
||||||
char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
|
char* ifnameEnv = std::getenv(kSocketIfnameEnvVar.c_str());
|
||||||
if (ifnameEnv != nullptr) {
|
if (ifnameEnv != nullptr) {
|
||||||
std::tie(error, result) =
|
auto [error, result] =
|
||||||
tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
|
tensorpipe::transport::uv::lookupAddrForIface(ifnameEnv);
|
||||||
if (error) {
|
if (error) {
|
||||||
LOG(WARNING) << "Failed to look up the IP address for interface "
|
LOG(WARNING) << "Failed to look up the IP address for interface "
|
||||||
|
|
@ -173,15 +171,13 @@ const std::string& TensorPipeAgent::guessAddress() {
|
||||||
<< kDefaultUvAddress;
|
<< kDefaultUvAddress;
|
||||||
return kDefaultUvAddress;
|
return kDefaultUvAddress;
|
||||||
}
|
}
|
||||||
} else {
|
return result;
|
||||||
std::tie(error, result) =
|
}
|
||||||
tensorpipe::transport::uv::lookupAddrForHostname();
|
auto [error, result] = tensorpipe::transport::uv::lookupAddrForHostname();
|
||||||
if (error) {
|
if (error) {
|
||||||
LOG(WARNING) << "Failed to look up the IP address for the hostname ("
|
LOG(WARNING) << "Failed to look up the IP address for the hostname ("
|
||||||
<< error.what() << "), defaulting to "
|
<< error.what() << "), defaulting to " << kDefaultUvAddress;
|
||||||
<< kDefaultUvAddress;
|
return kDefaultUvAddress;
|
||||||
return kDefaultUvAddress;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}();
|
}();
|
||||||
|
|
@ -1226,8 +1222,8 @@ const std::string& TensorPipeAgent::findWorkerURL(
|
||||||
|
|
||||||
void TensorPipeAgent::updateGroupMembership(
|
void TensorPipeAgent::updateGroupMembership(
|
||||||
const WorkerInfo& workerInfo,
|
const WorkerInfo& workerInfo,
|
||||||
const std::vector<c10::Device> devices,
|
const std::vector<c10::Device>& devices,
|
||||||
const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
|
const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
|
||||||
bool isJoin) {
|
bool isJoin) {
|
||||||
std::string name = workerInfo.name_;
|
std::string name = workerInfo.name_;
|
||||||
worker_id_t id = workerInfo.id_;
|
worker_id_t id = workerInfo.id_;
|
||||||
|
|
|
||||||
|
|
@ -194,8 +194,8 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
|
||||||
std::vector<WorkerInfo> getWorkerInfos() const override;
|
std::vector<WorkerInfo> getWorkerInfos() const override;
|
||||||
void updateGroupMembership(
|
void updateGroupMembership(
|
||||||
const WorkerInfo& workerInfo,
|
const WorkerInfo& workerInfo,
|
||||||
const std::vector<c10::Device> devices,
|
const std::vector<c10::Device>& devices,
|
||||||
const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
|
const std::unordered_map<std::string, DeviceMap>& reverseDeviceMaps,
|
||||||
bool isJoin);
|
bool isJoin);
|
||||||
|
|
||||||
std::unordered_map<std::string, std::string> getMetrics() override;
|
std::unordered_map<std::string, std::string> getMetrics() override;
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ class SGDParamState {
|
||||||
static_cast<const SGDParamState&>(*this));
|
static_cast<const SGDParamState&>(*this));
|
||||||
}
|
}
|
||||||
friend bool operator==(const SGDParamState& lhs, const SGDParamState& rhs);
|
friend bool operator==(const SGDParamState& lhs, const SGDParamState& rhs);
|
||||||
~SGDParamState() = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct TORCH_API SGDOptions {
|
struct TORCH_API SGDOptions {
|
||||||
|
|
@ -40,7 +39,6 @@ struct TORCH_API SGDOptions {
|
||||||
TORCH_API friend bool operator==(
|
TORCH_API friend bool operator==(
|
||||||
const SGDOptions& lhs,
|
const SGDOptions& lhs,
|
||||||
const SGDOptions& rhs);
|
const SGDOptions& rhs);
|
||||||
~SGDOptions() = default;
|
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Stores parameters in the param_group and stores a pointer to the SGDOptions
|
/// Stores parameters in the param_group and stores a pointer to the SGDOptions
|
||||||
|
|
|
||||||
|
|
@ -1100,7 +1100,6 @@ Code::Code(
|
||||||
remaining_bailout_depth)) {}
|
remaining_bailout_depth)) {}
|
||||||
|
|
||||||
Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}
|
Code::Code(CodeImpl* codeImpl) : pImpl(codeImpl) {}
|
||||||
Code::~Code() = default;
|
|
||||||
|
|
||||||
MobileCode::MobileCode(
|
MobileCode::MobileCode(
|
||||||
const std::shared_ptr<Graph>& graph,
|
const std::shared_ptr<Graph>& graph,
|
||||||
|
|
@ -1117,8 +1116,6 @@ MobileCode::MobileCode(
|
||||||
emit_promoted_ops,
|
emit_promoted_ops,
|
||||||
remaining_bailout_depth)) {}
|
remaining_bailout_depth)) {}
|
||||||
|
|
||||||
MobileCode::~MobileCode() = default;
|
|
||||||
|
|
||||||
const std::vector<GraphExecutor*>& Code::grad_executors() {
|
const std::vector<GraphExecutor*>& Code::grad_executors() {
|
||||||
return pImpl->grad_executors();
|
return pImpl->grad_executors();
|
||||||
}
|
}
|
||||||
|
|
@ -1172,7 +1169,6 @@ InterpreterState::InterpreterState(const Code& code, TaskLauncher taskLauncher)
|
||||||
: pImpl(c10::make_intrusive<InterpreterStateImpl>(
|
: pImpl(c10::make_intrusive<InterpreterStateImpl>(
|
||||||
code,
|
code,
|
||||||
std::move(taskLauncher))) {}
|
std::move(taskLauncher))) {}
|
||||||
InterpreterState::~InterpreterState() = default;
|
|
||||||
|
|
||||||
void InterpreterState::run(Stack& stack) {
|
void InterpreterState::run(Stack& stack) {
|
||||||
static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
|
static_cast<InterpreterStateImpl*>(pImpl.get())->run(stack);
|
||||||
|
|
|
||||||
|
|
@ -9,11 +9,6 @@
|
||||||
#include <torch/csrc/Export.h>
|
#include <torch/csrc/Export.h>
|
||||||
#include <torch/csrc/jit/frontend/source_range.h>
|
#include <torch/csrc/jit/frontend/source_range.h>
|
||||||
|
|
||||||
C10_CLANG_DIAGNOSTIC_PUSH()
|
|
||||||
#if C10_CLANG_HAS_WARNING("-Wdeprecated-copy-dtor")
|
|
||||||
C10_CLANG_DIAGNOSTIC_IGNORE("-Wdeprecated-copy-dtor")
|
|
||||||
#endif
|
|
||||||
|
|
||||||
C10_DECLARE_bool(torch_jit_disable_warning_prints);
|
C10_DECLARE_bool(torch_jit_disable_warning_prints);
|
||||||
C10_DECLARE_bool(torch_jit_enable_rethrow_caught_exception);
|
C10_DECLARE_bool(torch_jit_enable_rethrow_caught_exception);
|
||||||
|
|
||||||
|
|
@ -55,7 +50,6 @@ struct TORCH_API Code {
|
||||||
const std::shared_ptr<Graph>& graph,
|
const std::shared_ptr<Graph>& graph,
|
||||||
std::string function_name,
|
std::string function_name,
|
||||||
size_t remaining_bailout_depth = 0);
|
size_t remaining_bailout_depth = 0);
|
||||||
~Code();
|
|
||||||
|
|
||||||
const std::vector<GraphExecutor*>& grad_executors();
|
const std::vector<GraphExecutor*>& grad_executors();
|
||||||
const std::vector<GraphExecutor*>& diff_graph_op_executors();
|
const std::vector<GraphExecutor*>& diff_graph_op_executors();
|
||||||
|
|
@ -89,7 +83,6 @@ struct TORCH_API MobileCode : Code {
|
||||||
bool support_default_args_before_out = true,
|
bool support_default_args_before_out = true,
|
||||||
bool emit_promoted_ops = true,
|
bool emit_promoted_ops = true,
|
||||||
size_t remaining_bailout_depth = 0);
|
size_t remaining_bailout_depth = 0);
|
||||||
~MobileCode();
|
|
||||||
};
|
};
|
||||||
|
|
||||||
struct InterpreterState {
|
struct InterpreterState {
|
||||||
|
|
@ -99,7 +92,6 @@ struct InterpreterState {
|
||||||
TORCH_API void run(Stack& stack);
|
TORCH_API void run(Stack& stack);
|
||||||
TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
|
TORCH_API c10::intrusive_ptr<Future> runAsync(Stack& stack);
|
||||||
c10::intrusive_ptr<Future> getFuture();
|
c10::intrusive_ptr<Future> getFuture();
|
||||||
TORCH_API ~InterpreterState();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
|
InterpreterState(c10::intrusive_ptr<c10::intrusive_ptr_target> pImpl);
|
||||||
|
|
@ -127,18 +119,19 @@ struct Suspend : public std::exception {
|
||||||
// through (and only through) the forward pass manually, other
|
// through (and only through) the forward pass manually, other
|
||||||
// thread local settings are propagated with ThreadLocalState
|
// thread local settings are propagated with ThreadLocalState
|
||||||
struct InterpreterContinuation {
|
struct InterpreterContinuation {
|
||||||
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
|
|
||||||
InterpreterContinuation(
|
InterpreterContinuation(
|
||||||
const InterpreterState& state_,
|
InterpreterState state_,
|
||||||
Stack stack_,
|
Stack stack_,
|
||||||
int64_t dist_autograd_context_id = 0,
|
int64_t dist_autograd_context_id = 0,
|
||||||
c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
|
c10::optional<at::ThreadLocalState> tls_state = c10::nullopt)
|
||||||
: state(state_),
|
: state(std::move(state_)),
|
||||||
stack(std::move(stack_)),
|
stack(std::move(stack_)),
|
||||||
tls_state_(std::move(tls_state)) {
|
tls_state_(std::move(tls_state))
|
||||||
#ifdef USE_DISTRIBUTED
|
#ifdef USE_DISTRIBUTED
|
||||||
dist_autograd_context_id_ = dist_autograd_context_id;
|
,
|
||||||
|
dist_autograd_context_id_(dist_autograd_context_id)
|
||||||
#endif
|
#endif
|
||||||
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
void operator()();
|
void operator()();
|
||||||
|
|
@ -163,5 +156,3 @@ TORCH_API std::vector<StackEntry> currentCallstack();
|
||||||
TORCH_API std::vector<std::string> currentModuleHierarchy();
|
TORCH_API std::vector<std::string> currentModuleHierarchy();
|
||||||
|
|
||||||
} // namespace torch::jit
|
} // namespace torch::jit
|
||||||
|
|
||||||
C10_CLANG_DIAGNOSTIC_POP()
|
|
||||||
|
|
|
||||||
|
|
@ -1677,8 +1677,6 @@ uint64_t PythonPrint::minVersion() const {
|
||||||
return pImpl->min_version_;
|
return pImpl->min_version_;
|
||||||
}
|
}
|
||||||
|
|
||||||
PythonPrint::~PythonPrint() = default;
|
|
||||||
|
|
||||||
static std::vector<IValue> traverseIValueAndGetObjects(IValue ivalue) {
|
static std::vector<IValue> traverseIValueAndGetObjects(IValue ivalue) {
|
||||||
std::vector<IValue> result;
|
std::vector<IValue> result;
|
||||||
std::vector<IValue> stack;
|
std::vector<IValue> stack;
|
||||||
|
|
|
||||||
|
|
@ -42,8 +42,6 @@ struct TORCH_API PythonPrint {
|
||||||
const SourceRangeRecords& ranges() const;
|
const SourceRangeRecords& ranges() const;
|
||||||
uint64_t minVersion() const;
|
uint64_t minVersion() const;
|
||||||
|
|
||||||
~PythonPrint();
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
std::shared_ptr<PythonPrintImpl> pImpl;
|
std::shared_ptr<PythonPrintImpl> pImpl;
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,6 @@ class TORCH_API Reducer {
|
||||||
template <typename RI>
|
template <typename RI>
|
||||||
Reducer(ExprHandle init, RI interaction)
|
Reducer(ExprHandle init, RI interaction)
|
||||||
: init_(init.node()), interaction_(std::move(interaction)) {}
|
: init_(init.node()), interaction_(std::move(interaction)) {}
|
||||||
virtual ~Reducer() = default;
|
|
||||||
|
|
||||||
ExprPtr initializer() const {
|
ExprPtr initializer() const {
|
||||||
return init_;
|
return init_;
|
||||||
|
|
|
||||||
|
|
@ -67,6 +67,8 @@ class TORCH_API LazyTensor : public c10::intrusive_ptr_target {
|
||||||
// used to rely on a LazyTensor obj with a null Data can now rely on a null
|
// used to rely on a LazyTensor obj with a null Data can now rely on a null
|
||||||
// LazyTensorPtr instead.
|
// LazyTensorPtr instead.
|
||||||
LazyTensor() = delete;
|
LazyTensor() = delete;
|
||||||
|
LazyTensor(const LazyTensor&) = default;
|
||||||
|
LazyTensor(LazyTensor&&) noexcept = default;
|
||||||
|
|
||||||
~LazyTensor() override = default;
|
~LazyTensor() override = default;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -49,7 +49,6 @@ struct TORCH_API ExperimentalConfig {
|
||||||
std::vector<std::string> performance_events = {},
|
std::vector<std::string> performance_events = {},
|
||||||
bool enable_cuda_sync_events = false,
|
bool enable_cuda_sync_events = false,
|
||||||
bool adjust_timestamps = false);
|
bool adjust_timestamps = false);
|
||||||
~ExperimentalConfig() = default;
|
|
||||||
explicit operator bool() const;
|
explicit operator bool() const;
|
||||||
|
|
||||||
std::vector<std::string> profiler_metrics;
|
std::vector<std::string> profiler_metrics;
|
||||||
|
|
@ -88,7 +87,6 @@ struct TORCH_API ProfilerConfig {
|
||||||
bool with_flops = false,
|
bool with_flops = false,
|
||||||
bool with_modules = false,
|
bool with_modules = false,
|
||||||
ExperimentalConfig experimental_config = ExperimentalConfig());
|
ExperimentalConfig experimental_config = ExperimentalConfig());
|
||||||
~ProfilerConfig() = default;
|
|
||||||
|
|
||||||
bool disabled() const;
|
bool disabled() const;
|
||||||
bool global() const;
|
bool global() const;
|
||||||
|
|
|
||||||
|
|
@ -6,7 +6,9 @@
|
||||||
|
|
||||||
#include <algorithm>
|
#include <algorithm>
|
||||||
#include <memory>
|
#include <memory>
|
||||||
|
#include <string_view>
|
||||||
#include <unordered_map>
|
#include <unordered_map>
|
||||||
|
#include <unordered_set>
|
||||||
|
|
||||||
namespace torch {
|
namespace torch {
|
||||||
|
|
||||||
|
|
@ -17,12 +19,17 @@ std::string py_typename(PyObject* object) {
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Type {
|
struct Type {
|
||||||
|
Type() = default;
|
||||||
|
Type(const Type&) = default;
|
||||||
|
Type& operator=(const Type&) = default;
|
||||||
|
Type(Type&&) noexcept = default;
|
||||||
|
Type& operator=(Type&&) noexcept = default;
|
||||||
virtual bool is_matching(PyObject* object) = 0;
|
virtual bool is_matching(PyObject* object) = 0;
|
||||||
virtual ~Type() = default;
|
virtual ~Type() = default;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct SimpleType : public Type {
|
struct SimpleType : public Type {
|
||||||
SimpleType(std::string& name) : name(name){};
|
SimpleType(std::string_view name) : name(name){};
|
||||||
|
|
||||||
bool is_matching(PyObject* object) override {
|
bool is_matching(PyObject* object) override {
|
||||||
return py_typename(object) == name;
|
return py_typename(object) == name;
|
||||||
|
|
@ -36,11 +43,10 @@ struct MultiType : public Type {
|
||||||
: types(accepted_types){};
|
: types(accepted_types){};
|
||||||
|
|
||||||
bool is_matching(PyObject* object) override {
|
bool is_matching(PyObject* object) override {
|
||||||
auto it = std::find(types.begin(), types.end(), py_typename(object));
|
return types.find(py_typename(object)) != types.end();
|
||||||
return it != types.end();
|
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::string> types;
|
std::unordered_set<std::string> types;
|
||||||
};
|
};
|
||||||
|
|
||||||
struct NullableType : public Type {
|
struct NullableType : public Type {
|
||||||
|
|
@ -93,8 +99,8 @@ struct SequenceType : public Type {
|
||||||
};
|
};
|
||||||
|
|
||||||
struct Argument {
|
struct Argument {
|
||||||
Argument(std::string name, std::unique_ptr<Type> type)
|
Argument(std::string_view name, std::unique_ptr<Type> type)
|
||||||
: name(std::move(name)), type(std::move(type)){};
|
: name(name), type(std::move(type)){};
|
||||||
|
|
||||||
std::string name;
|
std::string name;
|
||||||
std::unique_ptr<Type> type;
|
std::unique_ptr<Type> type;
|
||||||
|
|
@ -118,13 +124,13 @@ struct Option {
|
||||||
bool has_out;
|
bool has_out;
|
||||||
};
|
};
|
||||||
|
|
||||||
std::vector<std::string> _splitString(
|
std::vector<std::string_view> _splitString(
|
||||||
const std::string& s,
|
std::string_view s,
|
||||||
const std::string& delim) {
|
std::string_view delim) {
|
||||||
std::vector<std::string> tokens;
|
std::vector<std::string_view> tokens;
|
||||||
size_t start = 0;
|
size_t start = 0;
|
||||||
size_t end = 0;
|
size_t end = 0;
|
||||||
while ((end = s.find(delim, start)) != std::string::npos) {
|
while ((end = s.find(delim, start)) != std::string_view::npos) {
|
||||||
tokens.push_back(s.substr(start, end - start));
|
tokens.push_back(s.substr(start, end - start));
|
||||||
start = end + delim.length();
|
start = end + delim.length();
|
||||||
}
|
}
|
||||||
|
|
@ -132,7 +138,7 @@ std::vector<std::string> _splitString(
|
||||||
return tokens;
|
return tokens;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
|
std::unique_ptr<Type> _buildType(std::string_view type_name, bool is_nullable) {
|
||||||
std::unique_ptr<Type> result;
|
std::unique_ptr<Type> result;
|
||||||
if (type_name == "float") {
|
if (type_name == "float") {
|
||||||
result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
|
result = std::make_unique<MultiType>(MultiType{"float", "int", "long"});
|
||||||
|
|
@ -140,14 +146,16 @@ std::unique_ptr<Type> _buildType(std::string type_name, bool is_nullable) {
|
||||||
result = std::make_unique<MultiType>(MultiType{"int", "long"});
|
result = std::make_unique<MultiType>(MultiType{"int", "long"});
|
||||||
} else if (type_name.find("tuple[") == 0) {
|
} else if (type_name.find("tuple[") == 0) {
|
||||||
auto type_list = type_name.substr(6);
|
auto type_list = type_name.substr(6);
|
||||||
type_list.pop_back();
|
type_list.remove_suffix(1);
|
||||||
|
auto sub_string_views = _splitString(type_list, ",");
|
||||||
std::vector<std::unique_ptr<Type>> types;
|
std::vector<std::unique_ptr<Type>> types;
|
||||||
for (auto& type : _splitString(type_list, ","))
|
types.reserve(sub_string_views.size());
|
||||||
|
for (auto& type : sub_string_views)
|
||||||
types.emplace_back(_buildType(type, false));
|
types.emplace_back(_buildType(type, false));
|
||||||
result = std::make_unique<TupleType>(std::move(types));
|
result = std::make_unique<TupleType>(std::move(types));
|
||||||
} else if (type_name.find("sequence[") == 0) {
|
} else if (type_name.find("sequence[") == 0) {
|
||||||
auto subtype = type_name.substr(9);
|
auto subtype = type_name.substr(9);
|
||||||
subtype.pop_back();
|
subtype.remove_suffix(1);
|
||||||
result = std::make_unique<SequenceType>(_buildType(subtype, false));
|
result = std::make_unique<SequenceType>(_buildType(subtype, false));
|
||||||
} else {
|
} else {
|
||||||
result = std::make_unique<SimpleType>(type_name);
|
result = std::make_unique<SimpleType>(type_name);
|
||||||
|
|
@ -194,7 +202,7 @@ std::pair<Option, std::string> _parseOption(
|
||||||
if (arg[type_start_idx] == '[') {
|
if (arg[type_start_idx] == '[') {
|
||||||
is_nullable = true;
|
is_nullable = true;
|
||||||
type_start_idx++;
|
type_start_idx++;
|
||||||
arg.erase(arg.length() - std::string(" or None]").length());
|
arg.remove_suffix(std::string(" or None]").length());
|
||||||
}
|
}
|
||||||
|
|
||||||
auto type_end_idx = arg.find_last_of(' ');
|
auto type_end_idx = arg.find_last_of(' ');
|
||||||
|
|
@ -203,17 +211,15 @@ std::pair<Option, std::string> _parseOption(
|
||||||
// "type ... name" => "type ... name"
|
// "type ... name" => "type ... name"
|
||||||
// ^ ^
|
// ^ ^
|
||||||
auto dots_idx = arg.find("...");
|
auto dots_idx = arg.find("...");
|
||||||
if (dots_idx != std::string::npos)
|
if (dots_idx != std::string_view::npos)
|
||||||
type_end_idx -= 4;
|
type_end_idx -= 4;
|
||||||
|
|
||||||
std::string type_name =
|
auto type_name = arg.substr(type_start_idx, type_end_idx - type_start_idx);
|
||||||
arg.substr(type_start_idx, type_end_idx - type_start_idx);
|
auto name = arg.substr(name_start_idx);
|
||||||
std::string name = arg.substr(name_start_idx);
|
|
||||||
|
|
||||||
arguments.emplace_back(name, _buildType(type_name, is_nullable));
|
arguments.emplace_back(name, _buildType(type_name, is_nullable));
|
||||||
}
|
}
|
||||||
|
|
||||||
bool is_variadic = option_str.find("...") != std::string::npos;
|
bool is_variadic = option_str.find("...") != std::string_view::npos;
|
||||||
return std::pair<Option, std::string>(
|
return std::pair<Option, std::string>(
|
||||||
Option(std::move(arguments), is_variadic, has_out),
|
Option(std::move(arguments), is_variadic, has_out),
|
||||||
std::move(printable_option));
|
std::move(printable_option));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue
Block a user