mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Follows #131986 Pull Request resolved: https://github.com/pytorch/pytorch/pull/131996 Approved by: https://github.com/ezyang
82 lines
2.3 KiB
C++
82 lines
2.3 KiB
C++
#pragma once
|
|
#include <ATen/core/ivalue.h>
|
|
#include <torch/csrc/jit/frontend/source_range.h>
|
|
#include <torch/csrc/jit/ir/constants.h>
|
|
#include <torch/csrc/utils/variadic.h>
|
|
|
|
namespace torch::jit {
|
|
|
|
// Forward declaration: this header only handles Value through pointers.
struct Value;
|
|
|
|
/**
|
|
* A value with optional extra name and location information. Used during
|
|
* schema matching to provide extra error information and resolve kwargs.
|
|
*/
|
|
struct NamedValue {
|
|
NamedValue(const SourceRange& loc, const std::string& name, Value* value)
|
|
: loc_(loc), name_(name), value_(value) {}
|
|
NamedValue(const SourceRange& loc, Value* value) : loc_(loc), value_(value) {}
|
|
|
|
/* implicit */ NamedValue(Value* value) : value_(value) {}
|
|
NamedValue(const std::string& name, Value* value)
|
|
: name_(name), value_(value) {}
|
|
|
|
/* implicit */ NamedValue(IValue value) : ivalue_(std::move(value)) {}
|
|
|
|
NamedValue(const std::string& name, IValue value)
|
|
: name_(name), ivalue_(std::move(value)) {}
|
|
|
|
template <
|
|
typename T,
|
|
typename = std::enable_if_t<
|
|
(!std::is_same_v<std::decay_t<T>, NamedValue> &&
|
|
!std::is_same_v<std::decay_t<T>, Value*> &&
|
|
!std::is_same_v<std::decay_t<T>, IValue>)>>
|
|
// NOLINTNEXTLINE(bugprone-forwarding-reference-overload)
|
|
NamedValue(T&& t) : NamedValue(IValue(std::forward<T>(t))) {}
|
|
|
|
template <
|
|
typename T,
|
|
typename = std::enable_if_t<
|
|
(!std::is_same_v<std::decay_t<T>, Value*> &&
|
|
!std::is_same_v<std::decay_t<T>, IValue>)>>
|
|
NamedValue(const std::string& name, T&& t)
|
|
: NamedValue(name, IValue(std::forward<T>(t))) {}
|
|
|
|
SourceRange locOr(const SourceRange& backup_location) const {
|
|
if (!loc_)
|
|
return backup_location;
|
|
return loc();
|
|
}
|
|
|
|
// note: this will insert a constant node into the graph at the current
|
|
// insert point if this NamedValue is actually a constant
|
|
Value* value(Graph& g) const {
|
|
if (!value_)
|
|
return insertConstant(
|
|
g, ivalue_); // use insertConstant to remove need to include ir.h here
|
|
return value_;
|
|
}
|
|
|
|
const std::string& name() const {
|
|
AT_ASSERT(name_);
|
|
return *name_;
|
|
}
|
|
|
|
const SourceRange& loc() const {
|
|
AT_ASSERT(loc_);
|
|
return *loc_;
|
|
}
|
|
|
|
at::TypePtr type() const;
|
|
|
|
private:
|
|
std::optional<SourceRange> loc_;
|
|
std::optional<std::string> name_;
|
|
Value* value_{nullptr};
|
|
// only valid if value_ == nullptr;
|
|
IValue ivalue_;
|
|
};
|
|
|
|
} // namespace torch::jit
|