Mirror of https://github.com/zebrajr/tensorflow.git, synced 2025-12-06 12:20:11 +01:00
Expand the set of 64-bit type tests in LocalClientExecuteTest.ShapeBufferToLiteralConversion64bit and factor them out into their own test.
PiperOrigin-RevId: 171043047
This commit is contained in:
parent cc521eb06c
commit 8c9ef44668
@@ -543,7 +543,6 @@ cc_library(
     ],
     deps = [
         ":ir_emission_utils",
-        ":parallel_task_assignment",
         ":shape_partition",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",

@@ -653,18 +652,6 @@ tf_cc_test(
     ],
 )
 
-cc_library(
-    name = "parallel_task_assignment",
-    srcs = ["parallel_task_assignment.cc"],
-    hdrs = ["parallel_task_assignment.h"],
-    deps = [
-        ":ir_emission_utils",
-        ":shape_partition",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
-    ],
-)
-
 cc_library(
     name = "cpu_options",
     srcs = ["cpu_options.cc"],

@@ -17,7 +17,6 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"

@@ -110,11 +109,10 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     HloModule* module) {
   VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_;
   bool changed = false;
-  // Initialize ParallelTaskAssignment.
-  ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_,
-                                                  module);
-  // Assign parallel tasks to HLOs in entry computation.
+  // Run cost analysis on entry computation.
+  HloCostAnalysis cost_analysis(shape_size_);
   HloComputation* computation = module->entry_computation();
+  Status cost_status = computation->root_instruction()->Accept(&cost_analysis);
   for (auto* instruction : computation->instructions()) {
     // Currently, we do not assign parallel tasks to instructions with at least
     // one of the following properties:

@@ -137,8 +135,8 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     }
 
     // Calculate target parallel task count in [1, max_parallelism_].
-    const int64 target_parallel_task_count =
-        parallel_task_assignment.GetTargetParallelTaskCount(instruction);
+    const int64 target_parallel_task_count = GetTargetParallelTaskCount(
+        cost_status.ok() ? &cost_analysis : nullptr, instruction);
     if (target_parallel_task_count == 1) {
       continue;
     }

@@ -161,6 +159,30 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
   return changed;
 }
 
+int64 ParallelizationPreparation::GetTargetParallelTaskCount(
+    const HloCostAnalysis* cost_analysis, HloInstruction* instruction) {
+  // Default to a simple cost model based on hlo size and typical L2 cache size.
+  // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an
+  // error status (likely because HLOs like CustomCall are not yet implemented
+  // in the HloCostAnalysis).
+  int64 instruction_cost = shape_size_(instruction->shape());
+  int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
+  if (cost_analysis != nullptr) {
+    // Calculate the instruction cost in cycles.
+    // TODO(29630486) Improve on this linear cost model.
+    // Consider making 'min_cost_per_thread' be a function of the target
+    // bandwidth limit for instructions with low arithmetic complexity.
+    instruction_cost = 1 * cost_analysis->flop_count(*instruction) +
+        2 * cost_analysis->transcendental_count(*instruction) +
+        10 * cost_analysis->bytes_accessed(*instruction);
+    // Minimum per-thread cost is 100us of work on a 2GHz core.
+    min_cost_per_thread = 100000;
+  }
+  // Return target parallel task count in [1, max_parallelism_].
+  return std::min(max_parallelism_,
+                  std::max(1LL, instruction_cost / min_cost_per_thread));
+}
+
 bool ParallelizationPreparation::OutlineParallelizableInstruction(
     HloInstruction* instruction) {
   if (instruction->outer_dimension_partitions().empty()) {

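Note: the reinstated GetTargetParallelTaskCount boils down to clamping a linear cycle estimate into [1, max_parallelism_]. A minimal standalone sketch of that arithmetic, in Python with illustrative names (this is not XLA's API):

def target_parallel_task_count(max_parallelism, flops, transcendentals,
                               bytes_accessed, output_bytes,
                               have_cost_analysis):
    if have_cost_analysis:
        # Linear cycle estimate: 1 per flop, 2 per transcendental, 10 per byte.
        instruction_cost = flops + 2 * transcendentals + 10 * bytes_accessed
        min_cost_per_thread = 100000  # "100us of work on a 2GHz core".
    else:
        # Fallback: HLO output size against a typical 256KB L2 cache.
        instruction_cost = output_bytes
        min_cost_per_thread = 256 << 10
    # Clamp into [1, max_parallelism].
    return min(max_parallelism,
               max(1, instruction_cost // min_cost_per_thread))

# A 1M-element elementwise f32 op: ~1M flops, ~8MB accessed. The cycle
# estimate (~85M) comfortably saturates all 8 threads.
print(target_parallel_task_count(8, 1 << 20, 0, 8 << 20, 4 << 20, True))  # 8
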
@@ -55,6 +55,12 @@ class ParallelizationPreparation : public HloPassInterface {
   // Returns true on success or error status otherwise.
   StatusOr<bool> RunParallelTaskAssignment(HloModule* module);
 
+  // Returns the target parallel task count for 'instruction'.
+  // Utilizes 'cost_analysis' if non-null.
+  // Otherwise defaults to a simple HLO output size-based cost model.
+  int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis,
+                                   HloInstruction* instruction);
+
   // Outlines 'instruction' from entry computation, if it had
   // been assigned parallel tasks in an earlier pass through the computation.
   // Returns true if 'instruction' was successfully outlined, false otherwise.

@@ -1,125 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
-
-#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
-#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
-#include "tensorflow/compiler/xla/service/hlo_computation.h"
-#include "tensorflow/compiler/xla/service/hlo_instruction.h"
-#include "tensorflow/compiler/xla/service/hlo_opcode.h"
-
-namespace xla {
-namespace cpu {
-
-class SimpleCostModel : public ParallelCostModel {
- public:
-  SimpleCostModel(const int64 max_parallelism,
-                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
-      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
-  ~SimpleCostModel() override {}
-
-  int64 GetParallelTaskCount(HloInstruction* instruction) override {
-    // Simple cost model based on hlo size and typical L2 cache size.
-    const int64 instruction_cost = shape_size_(instruction->shape());
-    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
-    // Return target parallel task count in [1, max_parallelism_].
-    return std::min(max_parallelism_,
-                    std::max(1LL, instruction_cost / min_cost_per_thread));
-  }
-
- private:
-  const int64 max_parallelism_;
-  const HloCostAnalysis::ShapeSizeFunction shape_size_;
-};
-
-class DefaultCostModel : public ParallelCostModel {
- public:
-  DefaultCostModel(const int64 max_parallelism,
-                   std::unique_ptr<HloCostAnalysis> cost_analysis)
-      : max_parallelism_(max_parallelism),
-        cost_analysis_(std::move(cost_analysis)) {}
-  ~DefaultCostModel() override {}
-
-  int64 GetParallelTaskCount(HloInstruction* instruction) override {
-    // Calculate the instruction cost in cycles.
-    // TODO(29630486) Improve on this linear cost model.
-    // Consider making 'min_cost_per_thread' be a function of the target
-    // bandwidth limit for instructions with low arithmetic complexity.
-    const int64 instruction_cost =
-        1 * cost_analysis_->flop_count(*instruction) +
-        2 * cost_analysis_->transcendental_count(*instruction) +
-        10 * cost_analysis_->bytes_accessed(*instruction);
-    // Minimum per-thread cost is 100us of work on a 2GHz core.
-    const int64 min_cost_per_thread = 100000;
-    // Return target parallel task count in [1, max_parallelism_].
-    return std::min(max_parallelism_,
-                    std::max(1LL, instruction_cost / min_cost_per_thread));
-  }
-
- private:
-  const int64 max_parallelism_;
-  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
-};
-
-
-ParallelTaskAssignment::ParallelTaskAssignment(
-    const int64 max_parallelism,
-    const HloCostAnalysis::ShapeSizeFunction& shape_size,
-    HloModule* module) {
-  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
-  // Run cost analysis on 'module'.
-  auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
-  HloComputation* computation = module->entry_computation();
-  Status status = computation->root_instruction()->Accept(cost_analysis.get());
-  if (status.ok()) {
-    // Set default cost model based on 'cost_analysis'.
-    cost_model_.reset(new DefaultCostModel(max_parallelism,
-                                           std::move(cost_analysis)));
-  } else {
-    // Fall back to a simple cost model based on hlo size and L2 cache size.
-    // Note that HloCostAnalysis can returns an error status (likely because
-    // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
-    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
-  }
-}
-
-int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
-    HloInstruction* instruction) {
-  // Currently, we do not assign parallel tasks to instructions with at least
-  // one of the following properties:
-  // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
-  // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
-  // *) Tuple-shaped.
-  // TODO(b/27458679) Parallelize instructions which are skipped here.
-  if (instruction->opcode() == HloOpcode::kParameter ||
-      instruction->opcode() == HloOpcode::kConstant ||
-      instruction->opcode() == HloOpcode::kCall ||
-      instruction->opcode() == HloOpcode::kCustomCall ||
-      instruction->opcode() == HloOpcode::kSelectAndScatter ||
-      (instruction->opcode() == HloOpcode::kConvolution &&
-       PotentiallyImplementedAsEigenConvolution(*instruction)) ||
-      PotentiallyImplementedAsEigenDot(*instruction) ||
-      (instruction->opcode() == HloOpcode::kFusion &&
-       instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
-      ShapeUtil::IsTuple(instruction->shape())) {
-    return 1;
-  }
-  // Consult 'cost_model_' to compute target parallel task count.
-  return cost_model_->GetParallelTaskCount(instruction);
-}
-
-}  // namespace cpu
-}  // namespace xla

@@ -1,55 +0,0 @@
-/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License.
-==============================================================================*/
-
-#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
-#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
-
-#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
-#include "tensorflow/compiler/xla/service/hlo_module.h"
-
-namespace xla {
-namespace cpu {
-
-// Simple interface for different parallel cost model implementations.
-class ParallelCostModel {
- public:
-  virtual ~ParallelCostModel() = default;
-  virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
-};
-
-// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'.
-class ParallelTaskAssignment {
- public:
-  // 'max_parallelism': the maximum parallel task count per instruction.
-  // 'shape_size': shape size function used by HloCostAnalysis during parallel
-  //  task assignment.
-  // 'module': the containing HloModule.
-  ParallelTaskAssignment(
-      const int64 max_parallelism,
-      const HloCostAnalysis::ShapeSizeFunction& shape_size,
-      HloModule* module);
-  ~ParallelTaskAssignment() {}
-
-  // Computes and returns the target parallel task count for 'instruction'.
-  int64 GetTargetParallelTaskCount(HloInstruction* instruction);
-
- private:
-  std::unique_ptr<ParallelCostModel> cost_model_;
-};
-
-}  // namespace cpu
-}  // namespace xla
-
-#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_

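Note: the two deleted classes differed only in the cost estimate they fed into the same clamp; the constructor chose DefaultCostModel when HloCostAnalysis succeeded and SimpleCostModel otherwise. A hedged Python sketch of that selection pattern (names and the dict-based analysis result are illustrative, not the XLA API):

class SimpleCostModel:
    """Fallback: cost is just the HLO output size, chunked by L2 cache."""
    def __init__(self, max_parallelism):
        self._max_parallelism = max_parallelism

    def parallel_task_count(self, output_bytes):
        return min(self._max_parallelism,
                   max(1, output_bytes // (256 << 10)))


class DefaultCostModel:
    """Preferred: a per-instruction cycle estimate from cost analysis."""
    def __init__(self, max_parallelism, analysis):
        self._max_parallelism = max_parallelism
        self._analysis = analysis  # e.g. {'flops': ..., 'transcendentals': ..., 'bytes': ...}

    def parallel_task_count(self, output_bytes):
        cycles = (self._analysis["flops"]
                  + 2 * self._analysis["transcendentals"]
                  + 10 * self._analysis["bytes"])
        return min(self._max_parallelism, max(1, cycles // 100000))


def make_cost_model(max_parallelism, analysis_or_none):
    # Mirrors the deleted constructor: fall back when cost analysis failed
    # (e.g. on HLOs like CustomCall that HloCostAnalysis could not handle).
    if analysis_or_none is not None:
        return DefaultCostModel(max_parallelism, analysis_or_none)
    return SimpleCostModel(max_parallelism)
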
@@ -106,6 +106,16 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
          for state_element
          in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
 
+  def _transform(self, data):
+    """Normalize data based on input statistics to encourage stable training."""
+    mean, variance = self._input_statistics.overall_feature_moments
+    return (data - mean) / variance
+
+  def _de_transform(self, data):
+    """Transform data back to the input scale."""
+    mean, variance = self._input_statistics.overall_feature_moments
+    return data * variance + mean
+
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Update model state based on observations.
 

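Note: `_transform` here divides by the variance rather than the standard deviation; the pair is nonetheless an exact inverse. A quick NumPy check of the round trip (plain NumPy, illustrative only, not part of the change):

import numpy as np

def transform(data, mean, variance):
    # Matches _transform above: center, then divide by the variance.
    return (data - mean) / variance

def de_transform(data, mean, variance):
    # Matches _de_transform above: the exact inverse of transform().
    return data * variance + mean

values = np.array([2.0, 4.0, 9.0])
mean, variance = values.mean(), values.var()
np.testing.assert_allclose(
    de_transform(transform(values, mean, variance), mean, variance), values)
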
@@ -130,10 +140,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
     state_from_time, prediction, lstm_state = state
     with tf.control_dependencies(
         [tf.assert_equal(current_times, state_from_time)]):
-      # Subtract the mean and divide by the variance of the series. Slightly
-      # more efficient if done for a whole window (using the normalize_features
-      # argument to SequentialTimeSeriesModel).
-      transformed_values = self._scale_data(current_values)
+      transformed_values = self._transform(current_values)
       # Use mean squared error across features for the loss.
       predictions["loss"] = tf.reduce_mean(
           (prediction - transformed_values) ** 2, axis=-1)

@@ -149,7 +156,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         inputs=previous_observation_or_prediction, state=lstm_state)
     next_prediction = self._predict_from_lstm_output(lstm_output)
     new_state_tuple = (current_times, next_prediction, new_lstm_state)
-    return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
+    return new_state_tuple, {"mean": self._de_transform(next_prediction)}
 
   def _imputation_step(self, current_times, state):
     """Advance model state across a gap."""

@@ -89,6 +89,8 @@ class ARModel(model.TimeSeriesModel):
     self.hidden_layer_sizes = hidden_layer_sizes
     self.window_size = self.input_window_size + self.output_window_size
     self.loss = loss
+    self.stats_means = None
+    self.stats_sigmas = None
     super(ARModel, self).__init__(
         num_features=num_features)
     assert num_time_buckets > 0

@@ -104,6 +106,32 @@ class ARModel(model.TimeSeriesModel):
     assert len(self._periods) or self.input_window_size
     assert output_window_size > 0
 
+  def scale_data(self, data):
+    """Scale data according to stats."""
+    if self._input_statistics is not None:
+      return (data - self.stats_means) / self.stats_sigmas
+    else:
+      return data
+
+  def scale_back_data(self, data):
+    if self._input_statistics is not None:
+      return (data * self.stats_sigmas) + self.stats_means
+    else:
+      return data
+
+  def scale_back_variance(self, var):
+    if self._input_statistics is not None:
+      return var * self.stats_sigmas * self.stats_sigmas
+    else:
+      return var
+
+  def initialize_graph(self, input_statistics=None):
+    super(ARModel, self).initialize_graph(input_statistics=input_statistics)
+    if self._input_statistics:
+      self.stats_means, variances = (
+          self._input_statistics.overall_feature_moments)
+      self.stats_sigmas = math_ops.sqrt(variances)
+
   def get_start_state(self):
     # State which matches the format we'll return later. Typically this will not
     # be used by the model directly, but the shapes and dtypes should match so

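Note: the three new ARModel helpers are ordinary z-score (de)normalization, scale(x) = (x - mean) / sigma and scale_back(y) = y * sigma + mean, with variances mapped back by sigma**2. A small NumPy sanity check of that variance identity (NumPy assumed; illustrative only):

import numpy as np

# Var((x - mu) / sigma) * sigma**2 recovers Var(x), which is why
# scale_back_variance multiplies by stats_sigmas twice.
rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=100_000)
mu, sigma = x.mean(), x.std()
scaled = (x - mu) / sigma
assert np.isclose(scaled.var() * sigma**2, x.var())
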
@@ -360,8 +388,8 @@ class ARModel(model.TimeSeriesModel):
       predicted_covariance = array_ops.ones_like(predicted_mean)
 
     # Transform and scale the mean and covariance appropriately.
-    predicted_mean = self._scale_back_data(predicted_mean)
-    predicted_covariance = self._scale_back_variance(predicted_covariance)
+    predicted_mean = self.scale_back_data(predicted_mean)
+    predicted_covariance = self.scale_back_variance(predicted_covariance)
 
     return {"mean": predicted_mean,
             "covariance": predicted_covariance}

@@ -390,7 +418,7 @@ class ARModel(model.TimeSeriesModel):
             times_feature=TrainEvalFeatures.TIMES,
             window_size=self.window_size,
             times_shape=times.get_shape()))
-    values = self._scale_data(values)
+    values = self.scale_data(values)
     if self.input_window_size > 0:
       input_values = values[:, :self.input_window_size, :]
     else:

@@ -407,14 +435,14 @@ class ARModel(model.TimeSeriesModel):
       # (observed - predicted) ** 2.
       # Note that this affects only evaluation; the training loss is unaffected.
       loss = self.loss_op(
-          self._scale_back_data(targets),
-          {"mean": self._scale_back_data(prediction_ops["mean"])})
+          self.scale_back_data(targets),
+          {"mean": self.scale_back_data(prediction_ops["mean"])})
     else:
       loss = self.loss_op(targets, prediction_ops)
 
     # Scale back the prediction.
-    prediction = self._scale_back_data(prediction)
-    covariance = self._scale_back_variance(covariance)
+    prediction = self.scale_back_data(prediction)
+    covariance = self.scale_back_variance(covariance)
 
     return model.ModelOutputs(
         loss=loss,

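Note: scaling targets and predictions back before computing the evaluation loss keeps the reported MSE in the input scale; for z-scored data the two scales differ by exactly sigma**2, as this NumPy sketch (illustrative, not part of the change) confirms:

import numpy as np

rng = np.random.default_rng(1)
target = rng.normal(size=1000)  # values in the normalized scale
pred = target + rng.normal(scale=0.1, size=1000)
mu, sigma = 5.0, 3.0            # hypothetical input statistics
mse_norm = np.mean((pred - target) ** 2)
mse_input = np.mean((pred * sigma + mu - (target * sigma + mu)) ** 2)
assert np.isclose(mse_input, sigma ** 2 * mse_norm)
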
@@ -537,7 +565,7 @@ class ARModel(model.TimeSeriesModel):
       new_state_times.set_shape((None, self.input_window_size))
       new_state_values = array_ops.concat(
           [previous_state_values,
-           self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
+           self.scale_data(values)], axis=1)[:, -self.input_window_size:, :]
       new_state_values.set_shape((None, self.input_window_size,
                                   self.num_features))
     else:

@@ -936,7 +936,8 @@ class InputStatisticsFromMiniBatch(object):
       start_time = variable_scope.get_variable(
           name="start_time",
           dtype=dtypes.int64,
-          initializer=dtypes.int64.max,
+          initializer=init_ops.zeros_initializer(),
+          shape=[],
           trainable=False)
       total_observation_count = variable_scope.get_variable(
           name="total_observation_count",

@@ -80,8 +80,6 @@ class TimeSeriesModel(object):
     self.dtype = dtype
     self._input_statistics = None
     self._graph_initialized = False
-    self._stats_means = None
-    self._stats_sigmas = None
 
     # TODO(allenl): Move more of the generic machinery for generating and
     # predicting into TimeSeriesModel, and possibly share it between generate()

@@ -122,38 +120,6 @@ class TimeSeriesModel(object):
     """
     self._graph_initialized = True
     self._input_statistics = input_statistics
-    if self._input_statistics:
-      self._stats_means, variances = (
-          self._input_statistics.overall_feature_moments)
-      self._stats_sigmas = math_ops.sqrt(variances)
-
-  def _scale_data(self, data):
-    """Scale data according to stats (input scale -> model scale)."""
-    if self._input_statistics is not None:
-      return (data - self._stats_means) / self._stats_sigmas
-    else:
-      return data
-
-  def _scale_variance(self, variance):
-    """Scale variances according to stats (input scale -> model scale)."""
-    if self._input_statistics is not None:
-      return variance / self._input_statistics.overall_feature_moments.variance
-    else:
-      return variance
-
-  def _scale_back_data(self, data):
-    """Scale back data according to stats (model scale -> input scale)."""
-    if self._input_statistics is not None:
-      return (data * self._stats_sigmas) + self._stats_means
-    else:
-      return data
-
-  def _scale_back_variance(self, variance):
-    """Scale back variances according to stats (model scale -> input scale)."""
-    if self._input_statistics is not None:
-      return variance * self._input_statistics.overall_feature_moments.variance
-    else:
-      return variance
 
   def _check_graph_initialized(self):
     if not self._graph_initialized:

@@ -338,7 +304,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
                train_output_names,
                predict_output_names,
                num_features,
-               normalize_features=False,
                dtype=dtypes.float32,
                exogenous_feature_columns=None,
                exogenous_update_condition=None,

@@ -351,12 +316,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
       predict_output_names: A list of products/predictions returned from
         _prediction_step.
       num_features: Number of features for the time series
-      normalize_features: Boolean. If True, `values` are passed normalized to
-        the model (via self._scale_data). Scaling is done for the whole window
-        as a batch, which is slightly more efficient than scaling inside the
-        window loop. The model must then define _scale_back_predictions, which
-        may use _scale_back_data or _scale_back_variance to return predictions
-        to the input scale.
       dtype: The floating point datatype to use.
       exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
         objects. See `TimeSeriesModel`.

@@ -385,25 +344,9 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._exogenous_update_condition = exogenous_update_condition
     self._train_output_names = train_output_names
     self._predict_output_names = predict_output_names
-    self._normalize_features = normalize_features
     self._static_unrolling_window_size_threshold = (
         static_unrolling_window_size_threshold)
 
-  def _scale_back_predictions(self, predictions):
-    """Return a window of predictions to input scale.
-
-    Args:
-      predictions: A dictionary mapping from prediction names to Tensors.
-    Returns:
-      A dictionary with values corrected for input normalization (e.g. with
-      self._scale_back_mean and possibly self._scale_back_variance). May be a
-      mutated version of the argument.
-    """
-    raise NotImplementedError(
-        "SequentialTimeSeriesModel normalized input data"
-        " (normalize_features=True), but no method was provided to transform "
-        "the predictions back to the input scale.")
-
   @abc.abstractmethod
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Compute a single-step loss for a batch of data.

@@ -581,8 +524,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._check_graph_initialized()
     times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
     values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
-    if self._normalize_features:
-      values = self._scale_data(values)
     exogenous_regressors = self._process_exogenous_features(
         times=times,
         features={key: value for key, value in features.items()

@@ -615,8 +556,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     # Since we have window-level additions to the loss, its per-step value is
     # misleading, so we avoid returning it.
     del outputs["loss"]
-    if self._normalize_features:
-      outputs = self._scale_back_predictions(outputs)
     return per_observation_loss, state, outputs
 
   def predict(self, features):

@@ -644,8 +583,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
         times=predict_times, state=start_state,
         state_update_fn=_call_prediction_step,
         outputs=self._predict_output_names)
-    if self._normalize_features:
-      predictions = self._scale_back_predictions(predictions)
     return predictions
 
 class _FakeTensorArray(object):

@@ -57,9 +57,7 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel):
     # TODO(allenl): Better support for multivariate series here.
     initial_value = array_ops.stack([
         math_ops.reduce_mean(
-            self._scale_data(
-                self._input_statistics.series_start_moments.mean)),
-        0.
+            self._input_statistics.series_start_moments.mean), 0.
     ])
     return initial_value + variable_scope.get_variable(
         name="prior_state_mean",

@@ -232,7 +232,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
                          + filtering_postprocessor_names),
         predict_output_names=["mean", "covariance"],
         num_features=configuration.num_features,
-        normalize_features=True,
         dtype=configuration.dtype,
         exogenous_feature_columns=configuration.exogenous_feature_columns,
         exogenous_update_condition=configuration.exogenous_update_condition,

@@ -310,10 +309,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     _, _, priors_from_time = state
     times = ops.convert_to_tensor(times)
     priors_from_time = ops.convert_to_tensor(priors_from_time)
+    with ops.control_dependencies([
+        control_flow_ops.Assert(
+            math_ops.reduce_all(priors_from_time <= times[:, 0]),
+            [priors_from_time, times[:, 0]],
+            summarize=100)
+    ]):
+      times = array_ops.identity(times)
     intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
-    # Ignore negative starting gaps, since there will be transient start times
-    # as inputs statistics are computed.
-    starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
+    starting_gaps = times[:, 0] - priors_from_time
     # Pre-define transition matrices raised to powers (and their sums) for every
     # gap in this window. This avoids duplicate computation (for example many
     # steps will use the transition matrix raised to the first power) and

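Note: the added check follows the standard graph-mode idiom: an Assert op only executes if something depends on it, so the guarded tensor is re-issued through identity() inside a control_dependencies block. A minimal sketch using today's tf.compat.v1 API (an assumption on my part; the diff itself uses the internal ops/control_flow_ops modules):

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()
times = tf.placeholder(tf.int64, shape=[None, 2])
priors_from_time = tf.placeholder(tf.int64, shape=[None])
check = tf.Assert(tf.reduce_all(priors_from_time <= times[:, 0]),
                  [priors_from_time, times[:, 0]], summarize=100)
with tf.control_dependencies([check]):
    # identity() ties the assertion to every downstream use of `times`.
    times = tf.identity(times)
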
@@ -365,15 +369,20 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       Imputed model state corresponding to the `state` argument.
     """
     estimated_state, estimated_state_var, previous_times = state
-    # Ignore negative imputation intervals due to transient start time
-    # estimates.
-    catchup_times = math_ops.maximum(current_times - previous_times, 0)
-    transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
-        self._cached_transition_powers_and_sums(catchup_times))
-    estimated_state = self._kalman_filter.predict_state_mean(
-        estimated_state, transition_matrices)
-    estimated_state_var = self._kalman_filter.predict_state_var(
-        estimated_state_var, transition_matrices, transition_noise_sums)
+    catchup_times = current_times - previous_times
+    non_negative_assertion = control_flow_ops.Assert(
+        math_ops.reduce_all(catchup_times >= 0), [
+            "Negative imputation interval", catchup_times, current_times,
+            previous_times
+        ],
+        summarize=100)
+    with ops.control_dependencies([non_negative_assertion]):
+      transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
+          self._cached_transition_powers_and_sums(catchup_times))
+      estimated_state = self._kalman_filter.predict_state_mean(
+          estimated_state, transition_matrices)
+      estimated_state_var = self._kalman_filter.predict_state_var(
+          estimated_state_var, transition_matrices, transition_noise_sums)
     return (estimated_state, estimated_state_var,
             previous_times + catchup_times)
 

@@ -428,13 +437,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         outputs=predictions)
     return (filtered_state, predictions)
 
-  def _scale_back_predictions(self, predictions):
-    """Return a window of predictions to input scale."""
-    predictions["mean"] = self._scale_back_data(predictions["mean"])
-    predictions["covariance"] = self._scale_back_variance(
-        predictions["covariance"])
-    return predictions
-
   def _prediction_step(self, current_times, state):
     """Make a prediction based on `state`.
 

@@ -456,7 +458,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     """
     estimated_state, estimated_state_var, previous_times = state
     advanced_to_current_assert = control_flow_ops.Assert(
-        math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
+        math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
        ["Attempted to predict without imputation"])
     with ops.control_dependencies([advanced_to_current_assert]):
       observation_model = self.get_broadcasted_observation_model(current_times)

@@ -473,9 +475,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
             (self.num_features,)))
     predicted_obs_var.set_shape(current_times.get_shape().concatenate(
         (self.num_features, self.num_features)))
-    # Not scaled back to input-scale, since this also feeds into the
-    # loss. Instead, predictions are scaled back before being returned to the
-    # user in _scale_back_predictions.
     predictions = {
         "mean": predicted_obs,
         "covariance": predicted_obs_var}

@@ -723,8 +722,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       # Make sure initial latent value uncertainty is at least on the same
       # scale as noise in the data.
       covariance_multiplier = math_ops.reduce_max(
-          self._scale_variance(
-              self._input_statistics.series_start_moments.variance))
+          self._input_statistics.series_start_moments.variance)
       return base_covariance * gen_math_ops.maximum(
           covariance_multiplier, 1.0)
     else:

@@ -922,8 +920,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         self.get_noise_transform(), dtype=self.dtype)
     state_noise_dimension = state_noise_transform.get_shape()[1].value
     if self._input_statistics is not None:
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
       initial_transition_noise_scale = math_ops.log(
           gen_math_ops.maximum(
               math_ops.reduce_mean(feature_variance) / math_ops.cast(

@@ -948,8 +945,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     if self._input_statistics is not None:
       # Get variance across the first few values in each batch for each
       # feature, for an initial observation noise (over-)estimate.
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
     else:
       feature_variance = None
     if feature_variance is not None:

@@ -605,7 +605,6 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
     super(TimeDependentStateSpaceModel, self).__init__(
         configuration=state_space_model.StateSpaceModelConfiguration(
            use_observation_noise=False,
-            transition_covariance_initial_log_scale_bias=5.,
            static_unrolling_window_size_threshold=
            static_unrolling_window_size_threshold))
 

@@ -182,8 +182,7 @@ class VARMA(state_space_model.StateSpaceModel):
     # modeled as transition noise in VARMA, we set its initial value based on a
     # slight over-estimate empirical observation noise.
     if self._input_statistics is not None:
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
       initial_transition_noise_scale = math_ops.log(
           math_ops.maximum(
               math_ops.reduce_mean(feature_variance), minimum_initial_variance))

@@ -39,8 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
 
   // x is from the feed.
   const int batch_size = tensor_size < 0 ? 1 : tensor_size;
-  Output x = RandomNormal(s.WithOpName("x").WithDevice("/CPU:0"),
-                          {batch_size, 1}, DataType::DT_FLOAT);
+  Output x =
+      RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
 
   // Create stages.
   std::vector<Output> last_stage;

@@ -64,19 +64,16 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
   }
 
   if (insert_queue) {
-    FIFOQueue queue(s.WithOpName("queue").WithDevice("/CPU:0"),
-                    {DataType::DT_FLOAT});
-    QueueEnqueue enqueue(s.WithOpName("enqueue").WithDevice("/CPU:0"), queue,
-                         last_stage);
-    QueueDequeue dequeue(s.WithOpName("dequeue").WithDevice("/CPU:0"), queue,
-                         {DataType::DT_FLOAT});
-    QueueClose cancel(s.WithOpName("cancel").WithDevice("/CPU:0"), queue,
+    FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
+    QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
+    QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
+    QueueClose cancel(s.WithOpName("cancel"), queue,
                       QueueClose::CancelPendingEnqueues(true));
     last_stage = {dequeue[0]};
   }
 
   // Create output.
-  AddN output(s.WithOpName("y").WithDevice("/CPU:0"), last_stage);
+  AddN output(s.WithOpName("y"), last_stage);
 
   GraphDef def;
   TF_CHECK_OK(s.ToGraphDef(&def));