pytorch/torch/csrc/utils/schema_info.h
Elias Ellison 97b2dff600 Add Initial Support For Fake Tensor Constant Tracking (#84387)
Adds support for constant tensor tracking within FakeTensors. Copy-pasta'ing from `proxy_tensor.py` why this is useful:
```
# In some circumstances, we will be tracing in a situation where a tensor
# is *statically* known to be a constant (currently, this only happens if
# you run torch.tensor; deterministic factory functions like torch.arange
# don't get this treatment).  When the tensor in question is small, it's
# helpful to do constant propagation in case we call item() (in which
# case we can return the constant value that is known, rather than give
# an error.)
```
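The idea in the quoted comment can be illustrated with a small C++ toy model. This is a sketch under stated assumptions, not the FakeTensor implementation (which lives in Python). `ToyFakeTensor`, `fake_from_literal`, `add`, and the one-element `kMaxConstantNumel` cutoff are all hypothetical names chosen for illustration. A fake tensor normally carries only metadata. Here it also keeps the values of a small tensor created from literal data, propagates them through an elementwise op, and lets `item()` return the known value instead of failing.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <stdexcept>
#include <vector>

// Hypothetical toy model (not the FakeTensor API): a fake tensor records its
// shape and, only when the values are statically known, a constant payload.
struct ToyFakeTensor {
  std::vector<std::size_t> shape;
  std::optional<std::vector<double>> constant;

  std::size_t numel() const {
    std::size_t n = 1;
    for (std::size_t d : shape) n *= d;
    return n;
  }

  // With a tracked constant, item() can answer; otherwise the value is
  // data-dependent and tracing has to report an error.
  double item() const {
    if (numel() != 1) throw std::runtime_error("item() needs one element");
    if (!constant) throw std::runtime_error("value is unknown while tracing");
    return (*constant)[0];
  }
};

// Assumed cutoff for a "small" tensor whose values are worth keeping.
constexpr std::size_t kMaxConstantNumel = 1;

// Mirrors torch.tensor(data): the data is visible at trace time, so small
// tensors keep their values.
ToyFakeTensor fake_from_literal(const std::vector<double>& data) {
  ToyFakeTensor t{{data.size()}, std::nullopt};
  if (data.size() <= kMaxConstantNumel) t.constant = data;
  return t;
}

// Elementwise add with constant propagation: the result is a constant only
// when both inputs are. Broadcasting is ignored; shapes must match.
ToyFakeTensor add(const ToyFakeTensor& a, const ToyFakeTensor& b) {
  assert(a.shape == b.shape && "toy add requires equal shapes");
  ToyFakeTensor out{a.shape, std::nullopt};
  if (a.constant && b.constant) {
    std::vector<double> sum(a.constant->size());
    for (std::size_t i = 0; i < sum.size(); ++i)
      sum[i] = (*a.constant)[i] + (*b.constant)[i];
    out.constant = sum;
  }
  return out;
}
```

Only tensors built from small literal data get a constant. A result computed from any non-constant input drops it, so `item()` on that result still raises.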

This PR only attempts to add support for the tracing scenarios where we run each operation linearly: AOT Autograd and TorchDynamo. It does not yet address how constant tensors should be handled as part of the persistent FX graph. Additionally, it does not yet attempt to de-duplicate or interact with ProxyMode's own constant tensor handling.

Edit: the plan is to rely on functionalization for the FX graph.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/84387
Approved by: https://github.com/ezyang
2022-09-02 02:43:04 +00:00


#pragma once

#include <torch/csrc/jit/frontend/function_schema_parser.h>

#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

namespace torch {
namespace utils {

using SchemaSpecialCasePair =
    std::pair<c10::FunctionSchema, std::unordered_set<std::string>>;

/**
 * class SchemaInfo
 *
 * FunctionSchema wrapper that exposes argument-value-specific operator
 * behavior (mutation, aliasing, special cases, etc.)
 */
struct TORCH_API SchemaInfo {
 public:
  // std::move on a const reference would only copy, so copy explicitly.
  explicit SchemaInfo(const c10::FunctionSchema& schema)
      : schema_(schema), alias_maps_current_(false), has_init_(false) {}
  explicit SchemaInfo(const char* signature)
      : schema_(torch::jit::parseSchema(signature)),
        alias_maps_current_(false),
        has_init_(false) {}
  bool is_mutable();
  bool is_mutable(const c10::SchemaArgument& argument);
  bool is_mutable(c10::string_view name);
  bool has_argument(c10::string_view name);
  bool is_nondeterministic() const;

  // Returns whether lhs and rhs may alias directly.
  // This does not account for cases where lhs or rhs is a container that
  // may contain elements that alias the other argument.
  // Besides the checks already included in FunctionSchema::may_alias, this
  // method also accounts for special aliasing cases caused by aliasing
  // between argument values supplied through addArgumentValue.
  bool may_alias(
      const c10::SchemaArgument& lhs,
      const c10::SchemaArgument& rhs);

  // Returns whether lhs and rhs may alias directly, or whether lhs/rhs is a
  // container that may contain elements that alias the other argument.
  // Besides the checks already included in FunctionSchema::may_contain_alias,
  // this method also accounts for special aliasing cases caused by aliasing
  // between argument values supplied through addArgumentValue.
  // bidirectional = false only checks whether lhs may contain an alias of
  // rhs, while bidirectional = true checks both directions.
  bool may_contain_alias(
      const c10::SchemaArgument& lhs,
      const c10::SchemaArgument& rhs,
      bool bidirectional = true);

  void addArgumentValue(const std::string& name, const at::IValue& value);
  void addArgumentValues(
      const std::vector<c10::optional<at::IValue>>& value_list);
  void addArgumentValues(
      const std::unordered_map<std::string, at::IValue>& values);
  bool hasInputArgumentNamed(const std::string& name) const;

 private:
  // Enforces more conservative results when the implementation emits a
  // TORCH_WARN because an argument list contains duplicates.
  void ensureConservativity(
      const std::unordered_set<at::Symbol>& duplicates,
      const std::vector<c10::Argument>& arguments_list,
      c10::SchemaArgType type);

  void initSchemaInfo();

  void generateAliasMaps();

  bool mayContainAliasImpl(
      const c10::SchemaArgument& lhs,
      const c10::SchemaArgument& rhs);

  static std::vector<c10::FunctionSchema> getNonDeterministicOps();

  static std::vector<SchemaSpecialCasePair> getTrainingOps();

  const std::unordered_set<c10::SchemaArgument>& wildcardSet();

  const std::unordered_set<c10::SchemaArgument>& containerSet();

  // Set of all wildcard arguments
  std::unordered_set<c10::SchemaArgument> wildcard_set_;

  // Set of all container arguments
  std::unordered_set<c10::SchemaArgument> container_set_;

  // Map of argument IValues
  std::unordered_map<std::string, at::IValue> value_map_;

  // Alias map of inputs with each other
  std::vector<std::unordered_set<size_t>> input_alias_map_;

  // Alias map of outputs to inputs
  std::vector<std::unordered_set<size_t>> output_alias_map_;

  const c10::FunctionSchema schema_;

  bool alias_maps_current_;

  bool has_init_;
};
} // namespace utils
} // namespace torch
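To illustrate what `may_alias` and `may_contain_alias` add on top of the schema annotations, here is a minimal, self-contained toy model in C++. It is a sketch under simplifying assumptions, not the real implementation. `ToyArg`, `ToySchemaInfo`, and `add_value` are hypothetical names. Each supplied value is identified by an integer id instead of an `at::IValue`. Wildcard handling is reduced to "may alias anything", and outputs and the alias maps are ignored.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical toy model of value-aware aliasing; these names are not part of
// c10 or torch::utils.
struct ToyArg {
  std::string name;
  std::optional<char> alias_set;  // schema annotation, e.g. 'a' in Tensor(a)
  bool wildcard = false;          // simplified: a wildcard may alias anything
  std::optional<int> value_id;    // identity of the supplied value, if any
  std::vector<int> element_ids;   // element identities for list arguments
};

class ToySchemaInfo {
 public:
  explicit ToySchemaInfo(std::vector<ToyArg> args) : args_(std::move(args)) {}

  // Plays the role of addArgumentValue: record which object an argument holds.
  void add_value(const std::string& name, int id,
                 std::vector<int> elements = {}) {
    for (ToyArg& a : args_) {
      if (a.name == name) {
        a.value_id = id;
        a.element_ids = elements;
      }
    }
  }

  // Direct aliasing only: a wildcard, a shared alias annotation, or the same
  // supplied value. List elements are not inspected here.
  bool may_alias(std::size_t lhs, std::size_t rhs) const {
    assert(lhs < args_.size() && rhs < args_.size());
    const ToyArg& a = args_[lhs];
    const ToyArg& b = args_[rhs];
    if (a.wildcard || b.wildcard) return true;
    if (a.alias_set && b.alias_set && *a.alias_set == *b.alias_set) return true;
    return a.value_id && b.value_id && *a.value_id == *b.value_id;
  }

  // Adds containment: lhs holds an element that is rhs's value. With
  // bidirectional == true, rhs holding lhs's value also counts.
  bool may_contain_alias(std::size_t lhs, std::size_t rhs,
                         bool bidirectional = true) const {
    if (may_alias(lhs, rhs)) return true;
    if (contains(args_[lhs], args_[rhs])) return true;
    return bidirectional && contains(args_[rhs], args_[lhs]);
  }

 private:
  static bool contains(const ToyArg& list, const ToyArg& x) {
    if (!x.value_id) return false;
    const std::vector<int>& e = list.element_ids;
    return std::find(e.begin(), e.end(), *x.value_id) != e.end();
  }

  std::vector<ToyArg> args_;
};
```

For example, with two unannotated `Tensor` inputs `a` and `b`, the schema alone says they cannot alias. Passing the same tensor to both (the same id via `add_value`) makes `may_alias` return true. A list argument whose elements include `a`'s value satisfies `may_contain_alias(list, a, /*bidirectional=*/false)`, but the reverse direction only holds when `bidirectional` is true.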