Expand the set of 64-bit type tests in LocalClientExecuteTest.ShapeBufferToLiteralConversion64bit and factor them out into their own test.

PiperOrigin-RevId: 171043047
Mark Heffernan 2017-10-04 12:05:26 -07:00 committed by TensorFlower Gardener
parent cc521eb06c
commit 8c9ef44668
14 changed files with 120 additions and 323 deletions


@ -543,7 +543,6 @@ cc_library(
     ],
     deps = [
         ":ir_emission_utils",
-        ":parallel_task_assignment",
         ":shape_partition",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@ -653,18 +652,6 @@ tf_cc_test(
     ],
 )
 
-cc_library(
-    name = "parallel_task_assignment",
-    srcs = ["parallel_task_assignment.cc"],
-    hdrs = ["parallel_task_assignment.h"],
-    deps = [
-        ":ir_emission_utils",
-        ":shape_partition",
-        "//tensorflow/compiler/xla/service:hlo",
-        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
-    ],
-)
-
 cc_library(
     name = "cpu_options",
     srcs = ["cpu_options.cc"],


@ -17,7 +17,6 @@ limitations under the License.
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
@ -110,11 +109,10 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     HloModule* module) {
   VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_;
   bool changed = false;
-  // Initialize ParallelTaskAssignment.
-  ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_,
-                                                  module);
-  // Assign parallel tasks to HLOs in entry computation.
+  // Run cost analysis on entry computation.
+  HloCostAnalysis cost_analysis(shape_size_);
   HloComputation* computation = module->entry_computation();
+  Status cost_status = computation->root_instruction()->Accept(&cost_analysis);
   for (auto* instruction : computation->instructions()) {
     // Currently, we do not assign parallel tasks to instructions with at least
     // one of the following properties:
@ -137,8 +135,8 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     }
 
     // Calculate target parallel task count in [1, max_parallelism_].
-    const int64 target_parallel_task_count =
-        parallel_task_assignment.GetTargetParallelTaskCount(instruction);
+    const int64 target_parallel_task_count = GetTargetParallelTaskCount(
+        cost_status.ok() ? &cost_analysis : nullptr, instruction);
     if (target_parallel_task_count == 1) {
       continue;
     }
@ -161,6 +159,30 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
   return changed;
 }
 
+int64 ParallelizationPreparation::GetTargetParallelTaskCount(
+    const HloCostAnalysis* cost_analysis, HloInstruction* instruction) {
+  // Default to a simple cost model based on hlo size and typical L2 cache size.
+  // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an
+  // error status (likely because HLOs like CustomCall are not yet implemented
+  // in the HloCostAnalysis).
+  int64 instruction_cost = shape_size_(instruction->shape());
+  int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
+  if (cost_analysis != nullptr) {
+    // Calculate the instruction cost in cycles.
+    // TODO(29630486) Improve on this linear cost model.
+    // Consider making 'min_cost_per_thread' be a function of the target
+    // bandwidth limit for instructions with low arithmetic complexity.
+    instruction_cost = 1 * cost_analysis->flop_count(*instruction) +
+                       2 * cost_analysis->transcendental_count(*instruction) +
+                       10 * cost_analysis->bytes_accessed(*instruction);
+    // Minimum per-thread cost is 100us of work on a 2GHz core.
+    min_cost_per_thread = 100000;
+  }
+  // Return target parallel task count in [1, max_parallelism_].
+  return std::min(max_parallelism_,
+                  std::max(1LL, instruction_cost / min_cost_per_thread));
+}
+
 bool ParallelizationPreparation::OutlineParallelizableInstruction(
     HloInstruction* instruction) {
   if (instruction->outer_dimension_partitions().empty()) {

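To make the arithmetic above concrete, here is a small Python sketch of the same clamped linear cost model; the function and parameter names are illustrative, not part of the XLA API, and the constants mirror the comments in the hunk above.

```python
def target_parallel_task_count(max_parallelism, flop_count, transcendental_count,
                               bytes_accessed, output_bytes, have_cost_analysis):
    """Illustrative re-implementation of the linear cost model above."""
    if have_cost_analysis:
        # Cost in "cycles": flops + 2x transcendentals + 10x bytes accessed.
        instruction_cost = (1 * flop_count + 2 * transcendental_count +
                            10 * bytes_accessed)
        min_cost_per_thread = 100000  # per the comment: ~100us of work on a 2GHz core
    else:
        # Fallback: HLO output size vs. a typical 256KB L2 cache.
        instruction_cost = output_bytes
        min_cost_per_thread = 256 << 10
    # Clamp the target task count to [1, max_parallelism].
    return min(max_parallelism, max(1, instruction_cost // min_cost_per_thread))

# Example: 1M flops, no transcendentals, 4MB accessed ->
# cost = 1e6 + 10 * 4 * 2**20 ~= 42.9M "cycles" -> 429 tasks, clamped to 16.
print(target_parallel_task_count(16, 10**6, 0, 4 * 2**20, 0, True))  # -> 16
```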

@ -55,6 +55,12 @@ class ParallelizationPreparation : public HloPassInterface {
   // Returns true on success or error status otherwise.
   StatusOr<bool> RunParallelTaskAssignment(HloModule* module);
 
+  // Returns the target parallel task count for 'instruction'.
+  // Utilizes 'cost_analysis' if non-null.
+  // Otherwise defaults to a simple HLO output size-based cost model.
+  int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis,
+                                   HloInstruction* instruction);
+
   // Outlines 'instruction' from entry computation, if it had
   // been assigned parallel tasks in an earlier pass through the computation.
   // Returns true if 'instruction' was successfully outlined, false otherwise.


@ -1,125 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
#include "tensorflow/compiler/xla/service/hlo_computation.h"
#include "tensorflow/compiler/xla/service/hlo_instruction.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
namespace xla {
namespace cpu {
class SimpleCostModel : public ParallelCostModel {
public:
SimpleCostModel(const int64 max_parallelism,
const HloCostAnalysis::ShapeSizeFunction& shape_size)
: max_parallelism_(max_parallelism), shape_size_(shape_size) {}
~SimpleCostModel() override {}
int64 GetParallelTaskCount(HloInstruction* instruction) override {
// Simple cost model based on hlo size and typical L2 cache size.
const int64 instruction_cost = shape_size_(instruction->shape());
const int64 min_cost_per_thread = 256LL << 10; // 256KB L2 Cache size.
// Return target parallel task count in [1, max_parallelism_].
return std::min(max_parallelism_,
std::max(1LL, instruction_cost / min_cost_per_thread));
}
private:
const int64 max_parallelism_;
const HloCostAnalysis::ShapeSizeFunction shape_size_;
};
class DefaultCostModel : public ParallelCostModel {
public:
DefaultCostModel(const int64 max_parallelism,
std::unique_ptr<HloCostAnalysis> cost_analysis)
: max_parallelism_(max_parallelism),
cost_analysis_(std::move(cost_analysis)) {}
~DefaultCostModel() override {}
int64 GetParallelTaskCount(HloInstruction* instruction) override {
// Calculate the instruction cost in cycles.
// TODO(29630486) Improve on this linear cost model.
// Consider making 'min_cost_per_thread' be a function of the target
// bandwidth limit for instructions with low arithmetic complexity.
const int64 instruction_cost =
1 * cost_analysis_->flop_count(*instruction) +
2 * cost_analysis_->transcendental_count(*instruction) +
10 * cost_analysis_->bytes_accessed(*instruction);
// Minimum per-thread cost is 100us of work on a 2GHz core.
const int64 min_cost_per_thread = 100000;
// Return target parallel task count in [1, max_parallelism_].
return std::min(max_parallelism_,
std::max(1LL, instruction_cost / min_cost_per_thread));
}
private:
const int64 max_parallelism_;
const std::unique_ptr<HloCostAnalysis> cost_analysis_;
};
ParallelTaskAssignment::ParallelTaskAssignment(
const int64 max_parallelism,
const HloCostAnalysis::ShapeSizeFunction& shape_size,
HloModule* module) {
VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
// Run cost analysis on 'module'.
auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
HloComputation* computation = module->entry_computation();
Status status = computation->root_instruction()->Accept(cost_analysis.get());
if (status.ok()) {
// Set default cost model based on 'cost_analysis'.
cost_model_.reset(new DefaultCostModel(max_parallelism,
std::move(cost_analysis)));
} else {
// Fall back to a simple cost model based on hlo size and L2 cache size.
// Note that HloCostAnalysis can return an error status (likely because
// HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
}
}
int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
HloInstruction* instruction) {
// Currently, we do not assign parallel tasks to instructions with at least
// one of the following properties:
// *) Internal threading (library calls to kConv, kDot, and kCustomCall).
// *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
// *) Tuple-shaped.
// TODO(b/27458679) Parallelize instructions which are skipped here.
if (instruction->opcode() == HloOpcode::kParameter ||
instruction->opcode() == HloOpcode::kConstant ||
instruction->opcode() == HloOpcode::kCall ||
instruction->opcode() == HloOpcode::kCustomCall ||
instruction->opcode() == HloOpcode::kSelectAndScatter ||
(instruction->opcode() == HloOpcode::kConvolution &&
PotentiallyImplementedAsEigenConvolution(*instruction)) ||
PotentiallyImplementedAsEigenDot(*instruction) ||
(instruction->opcode() == HloOpcode::kFusion &&
instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
ShapeUtil::IsTuple(instruction->shape())) {
return 1;
}
// Consult 'cost_model_' to compute target parallel task count.
return cost_model_->GetParallelTaskCount(instruction);
}
} // namespace cpu
} // namespace xla

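The two classes in this deleted file follow a small strategy pattern: DefaultCostModel consults HloCostAnalysis, and SimpleCostModel is the fallback chosen once, at construction time, when cost analysis fails. A minimal Python sketch of that fallback structure, with illustrative names (this is not the removed C++ API):

```python
import abc

class ParallelCostModel(abc.ABC):
    """Interface mirroring the removed ParallelCostModel (illustrative)."""

    @abc.abstractmethod
    def parallel_task_count(self, instruction):
        ...

class SimpleCostModel(ParallelCostModel):
    """Fallback: cost is just the HLO output size vs. a 256KB L2 cache."""

    def __init__(self, max_parallelism, shape_size_fn):
        self._max_parallelism = max_parallelism
        self._shape_size_fn = shape_size_fn

    def parallel_task_count(self, instruction):
        cost = self._shape_size_fn(instruction)
        return min(self._max_parallelism, max(1, cost // (256 << 10)))

class DefaultCostModel(ParallelCostModel):
    """Preferred: linear model over flops, transcendentals, bytes accessed."""

    def __init__(self, max_parallelism, cost_analysis):
        self._max_parallelism = max_parallelism
        self._cost_analysis = cost_analysis

    def parallel_task_count(self, instruction):
        c = self._cost_analysis
        cost = (c.flops(instruction) + 2 * c.transcendentals(instruction) +
                10 * c.bytes_accessed(instruction))
        return min(self._max_parallelism, max(1, cost // 100000))

def make_cost_model(max_parallelism, shape_size_fn, cost_analysis_or_none):
    # Fall back to the size-based model when cost analysis is unavailable
    # (e.g. it failed on an op not yet implemented in the analysis).
    if cost_analysis_or_none is not None:
        return DefaultCostModel(max_parallelism, cost_analysis_or_none)
    return SimpleCostModel(max_parallelism, shape_size_fn)
```

The commit replaces this construct-once dispatch with a per-call `cost_status.ok() ? &cost_analysis : nullptr` check in the pass itself.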

@ -1,55 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
#include "tensorflow/compiler/xla/service/hlo_module.h"
namespace xla {
namespace cpu {
// Simple interface for different parallel cost model implementations.
class ParallelCostModel {
public:
virtual ~ParallelCostModel() = default;
virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
};
// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'.
class ParallelTaskAssignment {
public:
// 'max_parallelism': the maximum parallel task count per instruction.
// 'shape_size': shape size function used by HloCostAnalysis during parallel
// task assignment.
// 'module': the containing HloModule.
ParallelTaskAssignment(
const int64 max_parallelism,
const HloCostAnalysis::ShapeSizeFunction& shape_size,
HloModule* module);
~ParallelTaskAssignment() {}
// Computes and returns the target parallel task count for 'instruction'.
int64 GetTargetParallelTaskCount(HloInstruction* instruction);
private:
std::unique_ptr<ParallelCostModel> cost_model_;
};
} // namespace cpu
} // namespace xla
#endif // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_


@ -106,6 +106,16 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
           for state_element
           in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
 
+  def _transform(self, data):
+    """Normalize data based on input statistics to encourage stable training."""
+    mean, variance = self._input_statistics.overall_feature_moments
+    return (data - mean) / variance
+
+  def _de_transform(self, data):
+    """Transform data back to the input scale."""
+    mean, variance = self._input_statistics.overall_feature_moments
+    return data * variance + mean
+
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Update model state based on observations.
@ -130,10 +140,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
     state_from_time, prediction, lstm_state = state
     with tf.control_dependencies(
         [tf.assert_equal(current_times, state_from_time)]):
-      # Subtract the mean and divide by the variance of the series. Slightly
-      # more efficient if done for a whole window (using the normalize_features
-      # argument to SequentialTimeSeriesModel).
-      transformed_values = self._scale_data(current_values)
+      transformed_values = self._transform(current_values)
     # Use mean squared error across features for the loss.
     predictions["loss"] = tf.reduce_mean(
         (prediction - transformed_values) ** 2, axis=-1)
@ -149,7 +156,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         inputs=previous_observation_or_prediction, state=lstm_state)
     next_prediction = self._predict_from_lstm_output(lstm_output)
     new_state_tuple = (current_times, next_prediction, new_lstm_state)
-    return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
+    return new_state_tuple, {"mean": self._de_transform(next_prediction)}
 
   def _imputation_step(self, current_times, state):
     """Advance model state across a gap."""

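The `_transform`/`_de_transform` pair added above is an exact inverse built from the overall feature moments; note that this example divides by the variance rather than the standard deviation. A small NumPy sketch of the round trip, with hypothetical moment values:

```python
import numpy as np

mean, variance = np.array([2.0]), np.array([4.0])  # hypothetical feature moments

def transform(data):
    # Model scale: subtract the mean, divide by the variance (as in _transform).
    return (data - mean) / variance

def de_transform(data):
    # Input scale: exact inverse of transform (as in _de_transform).
    return data * variance + mean

values = np.array([[1.0], [2.0], [10.0]])
np.testing.assert_allclose(de_transform(transform(values)), values)
```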

@ -89,6 +89,8 @@ class ARModel(model.TimeSeriesModel):
     self.hidden_layer_sizes = hidden_layer_sizes
     self.window_size = self.input_window_size + self.output_window_size
     self.loss = loss
+    self.stats_means = None
+    self.stats_sigmas = None
     super(ARModel, self).__init__(
         num_features=num_features)
     assert num_time_buckets > 0
@ -104,6 +106,32 @@ class ARModel(model.TimeSeriesModel):
     assert len(self._periods) or self.input_window_size
     assert output_window_size > 0
 
+  def scale_data(self, data):
+    """Scale data according to stats."""
+    if self._input_statistics is not None:
+      return (data - self.stats_means) / self.stats_sigmas
+    else:
+      return data
+
+  def scale_back_data(self, data):
+    if self._input_statistics is not None:
+      return (data * self.stats_sigmas) + self.stats_means
+    else:
+      return data
+
+  def scale_back_variance(self, var):
+    if self._input_statistics is not None:
+      return var * self.stats_sigmas * self.stats_sigmas
+    else:
+      return var
+
+  def initialize_graph(self, input_statistics=None):
+    super(ARModel, self).initialize_graph(input_statistics=input_statistics)
+    if self._input_statistics:
+      self.stats_means, variances = (
+          self._input_statistics.overall_feature_moments)
+      self.stats_sigmas = math_ops.sqrt(variances)
+
   def get_start_state(self):
     # State which matches the format we'll return later. Typically this will not
     # be used by the model directly, but the shapes and dtypes should match so
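ARModel, by contrast with the LSTM example, scales by the standard deviation, so a model-scale variance is mapped back to the input scale by multiplying with sigma squared. A NumPy sketch of the three scaling methods added above, with hypothetical statistics:

```python
import numpy as np

stats_means = np.array([10.0, -1.0])           # hypothetical per-feature means
stats_sigmas = np.sqrt(np.array([4.0, 0.25]))  # hypothetical per-feature variances

def scale_data(data):
    # Input scale -> model scale (z-score).
    return (data - stats_means) / stats_sigmas

def scale_back_data(data):
    # Model scale -> input scale.
    return data * stats_sigmas + stats_means

def scale_back_variance(var):
    # Var[a*X + b] = a**2 * Var[X], hence the sigma-squared factor.
    return var * stats_sigmas * stats_sigmas

x = np.array([[12.0, -0.5]])
np.testing.assert_allclose(scale_back_data(scale_data(x)), x)
print(scale_back_variance(np.ones(2)))  # -> [4.   0.25]
```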
@ -360,8 +388,8 @@ class ARModel(model.TimeSeriesModel):
       predicted_covariance = array_ops.ones_like(predicted_mean)
 
     # Transform and scale the mean and covariance appropriately.
-    predicted_mean = self._scale_back_data(predicted_mean)
-    predicted_covariance = self._scale_back_variance(predicted_covariance)
+    predicted_mean = self.scale_back_data(predicted_mean)
+    predicted_covariance = self.scale_back_variance(predicted_covariance)
 
     return {"mean": predicted_mean,
             "covariance": predicted_covariance}
@ -390,7 +418,7 @@ class ARModel(model.TimeSeriesModel):
             times_feature=TrainEvalFeatures.TIMES,
             window_size=self.window_size,
             times_shape=times.get_shape()))
-    values = self._scale_data(values)
+    values = self.scale_data(values)
     if self.input_window_size > 0:
       input_values = values[:, :self.input_window_size, :]
     else:
@ -407,14 +435,14 @@ class ARModel(model.TimeSeriesModel):
       # (observed - predicted) ** 2.
       # Note that this affects only evaluation; the training loss is unaffected.
       loss = self.loss_op(
-          self._scale_back_data(targets),
-          {"mean": self._scale_back_data(prediction_ops["mean"])})
+          self.scale_back_data(targets),
+          {"mean": self.scale_back_data(prediction_ops["mean"])})
     else:
       loss = self.loss_op(targets, prediction_ops)
 
     # Scale back the prediction.
-    prediction = self._scale_back_data(prediction)
-    covariance = self._scale_back_variance(covariance)
+    prediction = self.scale_back_data(prediction)
+    covariance = self.scale_back_variance(covariance)
 
     return model.ModelOutputs(
         loss=loss,
@ -537,7 +565,7 @@ class ARModel(model.TimeSeriesModel):
       new_state_times.set_shape((None, self.input_window_size))
       new_state_values = array_ops.concat(
           [previous_state_values,
-           self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
+           self.scale_data(values)], axis=1)[:, -self.input_window_size:, :]
       new_state_values.set_shape((None, self.input_window_size,
                                   self.num_features))
     else:


@ -936,7 +936,8 @@ class InputStatisticsFromMiniBatch(object):
     start_time = variable_scope.get_variable(
         name="start_time",
         dtype=dtypes.int64,
-        initializer=dtypes.int64.max,
+        initializer=init_ops.zeros_initializer(),
+        shape=[],
         trainable=False)
     total_observation_count = variable_scope.get_variable(
         name="total_observation_count",


@ -80,8 +80,6 @@ class TimeSeriesModel(object):
     self.dtype = dtype
     self._input_statistics = None
     self._graph_initialized = False
-    self._stats_means = None
-    self._stats_sigmas = None
 
     # TODO(allenl): Move more of the generic machinery for generating and
     # predicting into TimeSeriesModel, and possibly share it between generate()
@ -122,38 +120,6 @@ class TimeSeriesModel(object):
"""
self._graph_initialized = True
self._input_statistics = input_statistics
if self._input_statistics:
self._stats_means, variances = (
self._input_statistics.overall_feature_moments)
self._stats_sigmas = math_ops.sqrt(variances)
def _scale_data(self, data):
"""Scale data according to stats (input scale -> model scale)."""
if self._input_statistics is not None:
return (data - self._stats_means) / self._stats_sigmas
else:
return data
def _scale_variance(self, variance):
"""Scale variances according to stats (input scale -> model scale)."""
if self._input_statistics is not None:
return variance / self._input_statistics.overall_feature_moments.variance
else:
return variance
def _scale_back_data(self, data):
"""Scale back data according to stats (model scale -> input scale)."""
if self._input_statistics is not None:
return (data * self._stats_sigmas) + self._stats_means
else:
return data
def _scale_back_variance(self, variance):
"""Scale back variances according to stats (model scale -> input scale)."""
if self._input_statistics is not None:
return variance * self._input_statistics.overall_feature_moments.variance
else:
return variance
def _check_graph_initialized(self):
if not self._graph_initialized:
@ -338,7 +304,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
                train_output_names,
                predict_output_names,
                num_features,
-               normalize_features=False,
                dtype=dtypes.float32,
                exogenous_feature_columns=None,
                exogenous_update_condition=None,
@ -351,12 +316,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
       predict_output_names: A list of products/predictions returned from
         _prediction_step.
       num_features: Number of features for the time series
-      normalize_features: Boolean. If True, `values` are passed normalized to
-        the model (via self._scale_data). Scaling is done for the whole window
-        as a batch, which is slightly more efficient than scaling inside the
-        window loop. The model must then define _scale_back_predictions, which
-        may use _scale_back_data or _scale_back_variance to return predictions
-        to the input scale.
       dtype: The floating point datatype to use.
       exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
         objects. See `TimeSeriesModel`.
@ -385,25 +344,9 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._exogenous_update_condition = exogenous_update_condition
     self._train_output_names = train_output_names
     self._predict_output_names = predict_output_names
-    self._normalize_features = normalize_features
     self._static_unrolling_window_size_threshold = (
         static_unrolling_window_size_threshold)
 
-  def _scale_back_predictions(self, predictions):
-    """Return a window of predictions to input scale.
-
-    Args:
-      predictions: A dictionary mapping from prediction names to Tensors.
-
-    Returns:
-      A dictionary with values corrected for input normalization (e.g. with
-      self._scale_back_mean and possibly self._scale_back_variance). May be a
-      mutated version of the argument.
-    """
-    raise NotImplementedError(
-        "SequentialTimeSeriesModel normalized input data"
-        " (normalize_features=True), but no method was provided to transform "
-        "the predictions back to the input scale.")
-
   @abc.abstractmethod
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Compute a single-step loss for a batch of data.
@ -581,8 +524,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._check_graph_initialized()
     times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
     values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
-    if self._normalize_features:
-      values = self._scale_data(values)
     exogenous_regressors = self._process_exogenous_features(
         times=times,
         features={key: value for key, value in features.items()
@ -615,8 +556,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     # Since we have window-level additions to the loss, its per-step value is
     # misleading, so we avoid returning it.
     del outputs["loss"]
-    if self._normalize_features:
-      outputs = self._scale_back_predictions(outputs)
     return per_observation_loss, state, outputs
 
   def predict(self, features):
def predict(self, features):
@ -644,8 +583,6 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
times=predict_times, state=start_state,
state_update_fn=_call_prediction_step,
outputs=self._predict_output_names)
if self._normalize_features:
predictions = self._scale_back_predictions(predictions)
return predictions
class _FakeTensorArray(object):

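For context, the machinery deleted above scaled values once per window on the way into the model and required subclasses to scale predictions back on the way out. A compact NumPy sketch of that removed flow, with illustrative names (not the TimeSeriesModel API):

```python
import numpy as np

class WindowNormalizer(object):
    """Sketch of the removed normalize_features flow (illustrative)."""

    def __init__(self, means, sigmas):
        self._means, self._sigmas = means, sigmas

    def scale_window(self, values):
        # Applied once to the whole window before the per-step loop,
        # which is slightly more efficient than scaling inside the loop.
        return (values - self._means) / self._sigmas

    def scale_back_predictions(self, predictions):
        # Models had to implement this hook to undo the normalization.
        predictions["mean"] = predictions["mean"] * self._sigmas + self._means
        predictions["covariance"] = (
            predictions["covariance"] * self._sigmas ** 2)
        return predictions

normalizer = WindowNormalizer(np.array([5.0]), np.array([2.0]))
window = normalizer.scale_window(np.array([[3.0], [7.0]]))
preds = normalizer.scale_back_predictions(
    {"mean": window, "covariance": np.ones_like(window)})
np.testing.assert_allclose(preds["mean"], [[3.0], [7.0]])
```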

@ -57,9 +57,7 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel):
     # TODO(allenl): Better support for multivariate series here.
     initial_value = array_ops.stack([
         math_ops.reduce_mean(
-            self._scale_data(
-                self._input_statistics.series_start_moments.mean)),
-        0.
+            self._input_statistics.series_start_moments.mean), 0.
     ])
     return initial_value + variable_scope.get_variable(
         name="prior_state_mean",


@ -232,7 +232,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
                            + filtering_postprocessor_names),
         predict_output_names=["mean", "covariance"],
         num_features=configuration.num_features,
-        normalize_features=True,
         dtype=configuration.dtype,
         exogenous_feature_columns=configuration.exogenous_feature_columns,
         exogenous_update_condition=configuration.exogenous_update_condition,
@ -310,10 +309,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     _, _, priors_from_time = state
     times = ops.convert_to_tensor(times)
     priors_from_time = ops.convert_to_tensor(priors_from_time)
+    with ops.control_dependencies([
+        control_flow_ops.Assert(
+            math_ops.reduce_all(priors_from_time <= times[:, 0]),
+            [priors_from_time, times[:, 0]],
+            summarize=100)
+    ]):
+      times = array_ops.identity(times)
     intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
-    # Ignore negative starting gaps, since there will be transient start times
-    # as inputs statistics are computed.
-    starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
+    starting_gaps = times[:, 0] - priors_from_time
     # Pre-define transition matrices raised to powers (and their sums) for every
     # gap in this window. This avoids duplicate computation (for example many
     # steps will use the transition matrix raised to the first power) and
@ -365,15 +369,20 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       Imputed model state corresponding to the `state` argument.
     """
     estimated_state, estimated_state_var, previous_times = state
-    # Ignore negative imputation intervals due to transient start time
-    # estimates.
-    catchup_times = math_ops.maximum(current_times - previous_times, 0)
-    transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
-        self._cached_transition_powers_and_sums(catchup_times))
-    estimated_state = self._kalman_filter.predict_state_mean(
-        estimated_state, transition_matrices)
-    estimated_state_var = self._kalman_filter.predict_state_var(
-        estimated_state_var, transition_matrices, transition_noise_sums)
+    catchup_times = current_times - previous_times
+    non_negative_assertion = control_flow_ops.Assert(
+        math_ops.reduce_all(catchup_times >= 0), [
+            "Negative imputation interval", catchup_times, current_times,
+            previous_times
+        ],
+        summarize=100)
+    with ops.control_dependencies([non_negative_assertion]):
+      transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
+          self._cached_transition_powers_and_sums(catchup_times))
+      estimated_state = self._kalman_filter.predict_state_mean(
+          estimated_state, transition_matrices)
+      estimated_state_var = self._kalman_filter.predict_state_var(
+          estimated_state_var, transition_matrices, transition_noise_sums)
     return (estimated_state, estimated_state_var,
             previous_times + catchup_times)
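Both hunks above guard timing invariants in-graph: control_flow_ops.Assert builds an op that fails at run time unless its predicate holds, and ops.control_dependencies forces it to execute before the guarded computation. A minimal runnable sketch of the same pattern; it uses TF 1.x-style public APIs via tf.compat.v1 (an assumption, since the diff itself uses internal module names):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

current_times = tf.placeholder(tf.int64, [None])
previous_times = tf.placeholder(tf.int64, [None])

catchup_times = current_times - previous_times
non_negative_assertion = tf.Assert(
    tf.reduce_all(catchup_times >= 0),
    ["Negative imputation interval", catchup_times],
    summarize=100)
with tf.control_dependencies([non_negative_assertion]):
    # Guarded computation: only runs if the assertion passes.
    advanced = previous_times + catchup_times

with tf.Session() as sess:
    print(sess.run(advanced, {current_times: [3, 5], previous_times: [1, 5]}))
    # Feeding current_times < previous_times would raise InvalidArgumentError.
```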
@ -428,13 +437,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         outputs=predictions)
     return (filtered_state, predictions)
 
-  def _scale_back_predictions(self, predictions):
-    """Return a window of predictions to input scale."""
-    predictions["mean"] = self._scale_back_data(predictions["mean"])
-    predictions["covariance"] = self._scale_back_variance(
-        predictions["covariance"])
-    return predictions
-
   def _prediction_step(self, current_times, state):
     """Make a prediction based on `state`.
@ -456,7 +458,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
"""
estimated_state, estimated_state_var, previous_times = state
advanced_to_current_assert = control_flow_ops.Assert(
math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
["Attempted to predict without imputation"])
with ops.control_dependencies([advanced_to_current_assert]):
observation_model = self.get_broadcasted_observation_model(current_times)
@ -473,9 +475,6 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
             (self.num_features,)))
     predicted_obs_var.set_shape(current_times.get_shape().concatenate(
         (self.num_features, self.num_features)))
-    # Not scaled back to input-scale, since this also feeds into the
-    # loss. Instead, predictions are scaled back before being returned to the
-    # user in _scale_back_predictions.
     predictions = {
         "mean": predicted_obs,
         "covariance": predicted_obs_var}
@ -723,8 +722,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       # Make sure initial latent value uncertainty is at least on the same
       # scale as noise in the data.
       covariance_multiplier = math_ops.reduce_max(
-          self._scale_variance(
-              self._input_statistics.series_start_moments.variance))
+          self._input_statistics.series_start_moments.variance)
       return base_covariance * gen_math_ops.maximum(
           covariance_multiplier, 1.0)
     else:
@ -922,8 +920,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         self.get_noise_transform(), dtype=self.dtype)
     state_noise_dimension = state_noise_transform.get_shape()[1].value
     if self._input_statistics is not None:
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
       initial_transition_noise_scale = math_ops.log(
           gen_math_ops.maximum(
               math_ops.reduce_mean(feature_variance) / math_ops.cast(
@ -948,8 +945,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     if self._input_statistics is not None:
       # Get variance across the first few values in each batch for each
       # feature, for an initial observation noise (over-)estimate.
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
     else:
       feature_variance = None
     if feature_variance is not None:


@ -605,7 +605,6 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
     super(TimeDependentStateSpaceModel, self).__init__(
         configuration=state_space_model.StateSpaceModelConfiguration(
             use_observation_noise=False,
-            transition_covariance_initial_log_scale_bias=5.,
             static_unrolling_window_size_threshold=
             static_unrolling_window_size_threshold))


@ -182,8 +182,7 @@ class VARMA(state_space_model.StateSpaceModel):
     # modeled as transition noise in VARMA, we set its initial value based on a
     # slight over-estimate empirical observation noise.
     if self._input_statistics is not None:
-      feature_variance = self._scale_variance(
-          self._input_statistics.series_start_moments.variance)
+      feature_variance = self._input_statistics.series_start_moments.variance
       initial_transition_noise_scale = math_ops.log(
           math_ops.maximum(
               math_ops.reduce_mean(feature_variance), minimum_initial_variance))

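With input normalization removed, VARMA now reads the start-of-series variance directly (no pass through _scale_variance) and uses it to initialize a noise parameter in log space. A NumPy sketch of that initialization formula; the default floor value here is hypothetical:

```python
import numpy as np

def initial_transition_noise_scale(feature_variance,
                                   minimum_initial_variance=1e-5):
    # log(max(mean(variance), floor)): parameterizing the noise scale in log
    # space keeps the optimized parameter unconstrained while the implied
    # variance stays positive.
    return np.log(np.maximum(np.mean(feature_variance),
                             minimum_initial_variance))

print(initial_transition_noise_scale(np.array([0.5, 2.0])))  # log(1.25) ~= 0.223
```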

@ -39,8 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
   // x is from the feed.
   const int batch_size = tensor_size < 0 ? 1 : tensor_size;
-  Output x = RandomNormal(s.WithOpName("x").WithDevice("/CPU:0"),
-                          {batch_size, 1}, DataType::DT_FLOAT);
+  Output x =
+      RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
 
   // Create stages.
   std::vector<Output> last_stage;
@ -64,19 +64,16 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
   }
 
   if (insert_queue) {
-    FIFOQueue queue(s.WithOpName("queue").WithDevice("/CPU:0"),
-                    {DataType::DT_FLOAT});
-    QueueEnqueue enqueue(s.WithOpName("enqueue").WithDevice("/CPU:0"), queue,
-                         last_stage);
-    QueueDequeue dequeue(s.WithOpName("dequeue").WithDevice("/CPU:0"), queue,
-                         {DataType::DT_FLOAT});
-    QueueClose cancel(s.WithOpName("cancel").WithDevice("/CPU:0"), queue,
+    FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
+    QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
+    QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
+    QueueClose cancel(s.WithOpName("cancel"), queue,
                       QueueClose::CancelPendingEnqueues(true));
     last_stage = {dequeue[0]};
   }
 
   // Create output.
-  AddN output(s.WithOpName("y").WithDevice("/CPU:0"), last_stage);
+  AddN output(s.WithOpName("y"), last_stage);
   GraphDef def;
   TF_CHECK_OK(s.ToGraphDef(&def));