Error out if the evaluator has task id > 0.

PiperOrigin-RevId: 171047652
commit 943c6d7af7
parent 8c9ef44668
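The behavior change is easiest to see from the caller's side. A minimal sketch, assuming the public tf.estimator.train_and_evaluate entry point and an illustrative TF_CONFIG (only the evaluator index matters here):

import json
import os

# A cluster where the evaluator claims task index 1: train_and_evaluate
# now rejects this with a ValueError instead of silently running a
# second evaluator.
os.environ['TF_CONFIG'] = json.dumps({
    'cluster': {'chief': ['host0:0']},
    'task': {'type': 'evaluator', 'index': 1},  # index must be 0
})
# tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)
# raises: ValueError: For distributed training, there can only be one
# `evaluator` task (with task id 0). Given task id 1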
@@ -543,6 +543,7 @@ cc_library(
     ],
     deps = [
         ":ir_emission_utils",
+        ":parallel_task_assignment",
         ":shape_partition",
         "//tensorflow/compiler/xla:types",
         "//tensorflow/compiler/xla:util",
@@ -652,6 +653,18 @@ tf_cc_test(
     ],
 )
 
+cc_library(
+    name = "parallel_task_assignment",
+    srcs = ["parallel_task_assignment.cc"],
+    hdrs = ["parallel_task_assignment.h"],
+    deps = [
+        ":ir_emission_utils",
+        ":shape_partition",
+        "//tensorflow/compiler/xla/service:hlo",
+        "//tensorflow/compiler/xla/service:hlo_cost_analysis",
+    ],
+)
+
 cc_library(
     name = "cpu_options",
     srcs = ["cpu_options.cc"],
@@ -17,6 +17,7 @@ limitations under the License.
 
 #include "tensorflow/compiler/xla/map_util.h"
 #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
 #include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
 #include "tensorflow/compiler/xla/service/hlo_computation.h"
 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
@@ -109,10 +110,11 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     HloModule* module) {
   VLOG(1) << "RunParallelTaskAssignment max_parallelism_: " << max_parallelism_;
   bool changed = false;
-  // Run cost analysis on entry computation.
-  HloCostAnalysis cost_analysis(shape_size_);
+  // Initialize ParallelTaskAssignment.
+  ParallelTaskAssignment parallel_task_assignment(max_parallelism_, shape_size_,
+                                                  module);
+  // Assign parallel tasks to HLOs in entry computation.
   HloComputation* computation = module->entry_computation();
-  Status cost_status = computation->root_instruction()->Accept(&cost_analysis);
   for (auto* instruction : computation->instructions()) {
     // Currently, we do not assign parallel tasks to instructions with at least
     // one of the following properties:
@@ -135,8 +137,8 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
     }
 
     // Calculate target parallel task count in [1, max_parallelism_].
-    const int64 target_parallel_task_count = GetTargetParallelTaskCount(
-        cost_status.ok() ? &cost_analysis : nullptr, instruction);
+    const int64 target_parallel_task_count =
+        parallel_task_assignment.GetTargetParallelTaskCount(instruction);
     if (target_parallel_task_count == 1) {
       continue;
     }
@@ -159,30 +161,6 @@ StatusOr<bool> ParallelizationPreparation::RunParallelTaskAssignment(
   return changed;
 }
 
-int64 ParallelizationPreparation::GetTargetParallelTaskCount(
-    const HloCostAnalysis* cost_analysis, HloInstruction* instruction) {
-  // Default to a simple cost model based on hlo size and typical L2 cache size.
-  // Note that 'cost_analysis' can be 'nullptr' if HloCostAnalysis returns an
-  // error status (likely because HLOs like CustomCall are not yet implemented
-  // in the HloCostAnalysis).
-  int64 instruction_cost = shape_size_(instruction->shape());
-  int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
-  if (cost_analysis != nullptr) {
-    // Calculate the instruction cost in cycles.
-    // TODO(29630486) Improve on this linear cost model.
-    // Consider making 'min_cost_per_thread' be a function of the target
-    // bandwidth limit for instructions with low arithmetic complexity.
-    instruction_cost = 1 * cost_analysis->flop_count(*instruction) +
-        2 * cost_analysis->transcendental_count(*instruction) +
-        10 * cost_analysis->bytes_accessed(*instruction);
-    // Minimum per-thread cost is 100us of work on a 2GHz core.
-    min_cost_per_thread = 100000;
-  }
-  // Return target parallel task count in [1, max_parallelism_].
-  return std::min(max_parallelism_,
-                  std::max(1LL, instruction_cost / min_cost_per_thread));
-}
-
 bool ParallelizationPreparation::OutlineParallelizableInstruction(
     HloInstruction* instruction) {
   if (instruction->outer_dimension_partitions().empty()) {
@@ -55,12 +55,6 @@ class ParallelizationPreparation : public HloPassInterface {
   // Returns true on success or error status otherwise.
   StatusOr<bool> RunParallelTaskAssignment(HloModule* module);
 
-  // Returns the target parallel task count for 'instruction'.
-  // Utilizes 'cost_analysis' if non-null.
-  // Otherwise defaults to a simple HLO output size-based cost model.
-  int64 GetTargetParallelTaskCount(const HloCostAnalysis* cost_analysis,
-                                   HloInstruction* instruction);
-
   // Outlines 'instruction' from entry computation, if it had
   // been assigned parallel tasks in an earlier pass through the computation.
   // Returns true if 'instruction' was successfully outlined, false otherwise.
tensorflow/compiler/xla/service/cpu/parallel_task_assignment.cc (new file, 125 lines)
@@ -0,0 +1,125 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h"
+
+#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
+#include "tensorflow/compiler/xla/service/cpu/shape_partition.h"
+#include "tensorflow/compiler/xla/service/hlo_computation.h"
+#include "tensorflow/compiler/xla/service/hlo_instruction.h"
+#include "tensorflow/compiler/xla/service/hlo_opcode.h"
+
+namespace xla {
+namespace cpu {
+
+class SimpleCostModel : public ParallelCostModel {
+ public:
+  SimpleCostModel(const int64 max_parallelism,
+                  const HloCostAnalysis::ShapeSizeFunction& shape_size)
+      : max_parallelism_(max_parallelism), shape_size_(shape_size) {}
+  ~SimpleCostModel() override {}
+
+  int64 GetParallelTaskCount(HloInstruction* instruction) override {
+    // Simple cost model based on hlo size and typical L2 cache size.
+    const int64 instruction_cost = shape_size_(instruction->shape());
+    const int64 min_cost_per_thread = 256LL << 10;  // 256KB L2 Cache size.
+    // Return target parallel task count in [1, max_parallelism_].
+    return std::min(max_parallelism_,
+                    std::max(1LL, instruction_cost / min_cost_per_thread));
+  }
+
+ private:
+  const int64 max_parallelism_;
+  const HloCostAnalysis::ShapeSizeFunction shape_size_;
+};
+
+class DefaultCostModel : public ParallelCostModel {
+ public:
+  DefaultCostModel(const int64 max_parallelism,
+                   std::unique_ptr<HloCostAnalysis> cost_analysis)
+      : max_parallelism_(max_parallelism),
+        cost_analysis_(std::move(cost_analysis)) {}
+  ~DefaultCostModel() override {}
+
+  int64 GetParallelTaskCount(HloInstruction* instruction) override {
+    // Calculate the instruction cost in cycles.
+    // TODO(29630486) Improve on this linear cost model.
+    // Consider making 'min_cost_per_thread' be a function of the target
+    // bandwidth limit for instructions with low arithmetic complexity.
+    const int64 instruction_cost =
+        1 * cost_analysis_->flop_count(*instruction) +
+        2 * cost_analysis_->transcendental_count(*instruction) +
+        10 * cost_analysis_->bytes_accessed(*instruction);
+    // Minimum per-thread cost is 100us of work on a 2GHz core.
+    const int64 min_cost_per_thread = 100000;
+    // Return target parallel task count in [1, max_parallelism_].
+    return std::min(max_parallelism_,
+                    std::max(1LL, instruction_cost / min_cost_per_thread));
+  }
+
+ private:
+  const int64 max_parallelism_;
+  const std::unique_ptr<HloCostAnalysis> cost_analysis_;
+};
+
+ParallelTaskAssignment::ParallelTaskAssignment(
+    const int64 max_parallelism,
+    const HloCostAnalysis::ShapeSizeFunction& shape_size,
+    HloModule* module) {
+  VLOG(1) << "ParallelTaskAssignment max_parallelism: " << max_parallelism;
+  // Run cost analysis on 'module'.
+  auto cost_analysis = MakeUnique<HloCostAnalysis>(shape_size);
+  HloComputation* computation = module->entry_computation();
+  Status status = computation->root_instruction()->Accept(cost_analysis.get());
+  if (status.ok()) {
+    // Set default cost model based on 'cost_analysis'.
+    cost_model_.reset(new DefaultCostModel(max_parallelism,
+                                           std::move(cost_analysis)));
+  } else {
+    // Fall back to a simple cost model based on hlo size and L2 cache size.
+    // Note that HloCostAnalysis can returns an error status (likely because
+    // HLOs like CustomCall are not yet implemented in the HloCostAnalysis).
+    cost_model_.reset(new SimpleCostModel(max_parallelism, shape_size));
+  }
+}
+
+int64 ParallelTaskAssignment::GetTargetParallelTaskCount(
+    HloInstruction* instruction) {
+  // Currently, we do not assign parallel tasks to instructions with at least
+  // one of the following properties:
+  // *) Internal threading (library calls to kConv, kDot, and kCustomCall).
+  // *) Emit custom loops (kSelectAndScatter, FusionKind::kTransposeDot).
+  // *) Tuple-shaped.
+  // TODO(b/27458679) Parallelize instructions which are skipped here.
+  if (instruction->opcode() == HloOpcode::kParameter ||
+      instruction->opcode() == HloOpcode::kConstant ||
+      instruction->opcode() == HloOpcode::kCall ||
+      instruction->opcode() == HloOpcode::kCustomCall ||
+      instruction->opcode() == HloOpcode::kSelectAndScatter ||
+      (instruction->opcode() == HloOpcode::kConvolution &&
+       PotentiallyImplementedAsEigenConvolution(*instruction)) ||
+      PotentiallyImplementedAsEigenDot(*instruction) ||
+      (instruction->opcode() == HloOpcode::kFusion &&
+       instruction->fusion_kind() != HloInstruction::FusionKind::kLoop) ||
+      ShapeUtil::IsTuple(instruction->shape())) {
+    return 1;
+  }
+  // Consult 'cost_model_' to compute target parallel task count.
+  return cost_model_->GetParallelTaskCount(instruction);
+}
+
+}  // namespace cpu
+}  // namespace xla
tensorflow/compiler/xla/service/cpu/parallel_task_assignment.h (new file, 55 lines)
@@ -0,0 +1,55 @@
+/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+#define THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
+
+#include "tensorflow/compiler/xla/service/hlo_cost_analysis.h"
+#include "tensorflow/compiler/xla/service/hlo_module.h"
+
+namespace xla {
+namespace cpu {
+
+// Simple interface for different parallel cost model implementations.
+class ParallelCostModel {
+ public:
+  virtual ~ParallelCostModel() = default;
+  virtual int64 GetParallelTaskCount(HloInstruction* instruction) = 0;
+};
+
+// ParallelTaskAssignment computes parallel task counts for HLOs in 'module'.
+class ParallelTaskAssignment {
+ public:
+  // 'max_parallelism': the maximum parallel task count per instruction.
+  // 'shape_size': shape size function used by HloCostAnalysis during parallel
+  //               task assignment.
+  // 'module': the containing HloModule.
+  ParallelTaskAssignment(
+      const int64 max_parallelism,
+      const HloCostAnalysis::ShapeSizeFunction& shape_size,
+      HloModule* module);
+  ~ParallelTaskAssignment() {}
+
+  // Computes and returns the target parallel task count for 'instruction'.
+  int64 GetTargetParallelTaskCount(HloInstruction* instruction);
+
+ private:
+  std::unique_ptr<ParallelCostModel> cost_model_;
+};
+
+}  // namespace cpu
+}  // namespace xla
+
+#endif  // THIRD_PARTY_TENSORFLOW_COMPILER_XLA_SERVICE_CPU_PARALLEL_TASK_ASSIGNMENT_H_
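For intuition, the arithmetic of the two cost models above is small enough to mirror in Python. This is an illustrative transcription of DefaultCostModel, not TensorFlow API; the function name and sample numbers are invented:

def default_cost_model_task_count(flops, transcendentals, bytes_accessed,
                                  max_parallelism):
  # Instruction cost in "cycles" is a weighted sum of flops,
  # transcendentals, and bytes accessed (weights 1, 2, 10 as above).
  instruction_cost = 1 * flops + 2 * transcendentals + 10 * bytes_accessed
  # Minimum per-thread cost is 100000 cycles (~100us of work at 2GHz).
  min_cost_per_thread = 100000
  # Clamp the target task count to [1, max_parallelism].
  return min(max_parallelism, max(1, instruction_cost // min_cost_per_thread))

# An elementwise op with 250K flops touching 1MiB:
# cost = 250000 + 10 * 1048576 = 10735760 -> 107 tasks, clamped to 16.
print(default_cost_model_task_count(250000, 0, 1 << 20, 16))  # -> 16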
@@ -106,16 +106,6 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         for state_element
         in self._lstm_cell.zero_state(batch_size=1, dtype=self.dtype)])
 
-  def _transform(self, data):
-    """Normalize data based on input statistics to encourage stable training."""
-    mean, variance = self._input_statistics.overall_feature_moments
-    return (data - mean) / variance
-
-  def _de_transform(self, data):
-    """Transform data back to the input scale."""
-    mean, variance = self._input_statistics.overall_feature_moments
-    return data * variance + mean
-
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Update model state based on observations.
 
@@ -140,7 +130,10 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
     state_from_time, prediction, lstm_state = state
     with tf.control_dependencies(
         [tf.assert_equal(current_times, state_from_time)]):
-      transformed_values = self._transform(current_values)
+      # Subtract the mean and divide by the variance of the series. Slightly
+      # more efficient if done for a whole window (using the normalize_features
+      # argument to SequentialTimeSeriesModel).
+      transformed_values = self._scale_data(current_values)
       # Use mean squared error across features for the loss.
       predictions["loss"] = tf.reduce_mean(
           (prediction - transformed_values) ** 2, axis=-1)
@@ -156,7 +149,7 @@ class _LSTMModel(ts_model.SequentialTimeSeriesModel):
         inputs=previous_observation_or_prediction, state=lstm_state)
     next_prediction = self._predict_from_lstm_output(lstm_output)
     new_state_tuple = (current_times, next_prediction, new_lstm_state)
-    return new_state_tuple, {"mean": self._de_transform(next_prediction)}
+    return new_state_tuple, {"mean": self._scale_back_data(next_prediction)}
 
   def _imputation_step(self, current_times, state):
     """Advance model state across a gap."""
@@ -89,8 +89,6 @@ class ARModel(model.TimeSeriesModel):
     self.hidden_layer_sizes = hidden_layer_sizes
     self.window_size = self.input_window_size + self.output_window_size
     self.loss = loss
-    self.stats_means = None
-    self.stats_sigmas = None
     super(ARModel, self).__init__(
         num_features=num_features)
     assert num_time_buckets > 0
@@ -106,32 +104,6 @@ class ARModel(model.TimeSeriesModel):
     assert len(self._periods) or self.input_window_size
     assert output_window_size > 0
 
-  def scale_data(self, data):
-    """Scale data according to stats."""
-    if self._input_statistics is not None:
-      return (data - self.stats_means) / self.stats_sigmas
-    else:
-      return data
-
-  def scale_back_data(self, data):
-    if self._input_statistics is not None:
-      return (data * self.stats_sigmas) + self.stats_means
-    else:
-      return data
-
-  def scale_back_variance(self, var):
-    if self._input_statistics is not None:
-      return var * self.stats_sigmas * self.stats_sigmas
-    else:
-      return var
-
-  def initialize_graph(self, input_statistics=None):
-    super(ARModel, self).initialize_graph(input_statistics=input_statistics)
-    if self._input_statistics:
-      self.stats_means, variances = (
-          self._input_statistics.overall_feature_moments)
-      self.stats_sigmas = math_ops.sqrt(variances)
-
   def get_start_state(self):
     # State which matches the format we'll return later. Typically this will not
     # be used by the model directly, but the shapes and dtypes should match so
@@ -388,8 +360,8 @@ class ARModel(model.TimeSeriesModel):
       predicted_covariance = array_ops.ones_like(predicted_mean)
 
     # Transform and scale the mean and covariance appropriately.
-    predicted_mean = self.scale_back_data(predicted_mean)
-    predicted_covariance = self.scale_back_variance(predicted_covariance)
+    predicted_mean = self._scale_back_data(predicted_mean)
+    predicted_covariance = self._scale_back_variance(predicted_covariance)
 
     return {"mean": predicted_mean,
             "covariance": predicted_covariance}
@@ -418,7 +390,7 @@ class ARModel(model.TimeSeriesModel):
             times_feature=TrainEvalFeatures.TIMES,
             window_size=self.window_size,
             times_shape=times.get_shape()))
-    values = self.scale_data(values)
+    values = self._scale_data(values)
     if self.input_window_size > 0:
       input_values = values[:, :self.input_window_size, :]
     else:
@@ -435,14 +407,14 @@ class ARModel(model.TimeSeriesModel):
       # (observed - predicted) ** 2.
       # Note that this affects only evaluation; the training loss is unaffected.
       loss = self.loss_op(
-          self.scale_back_data(targets),
-          {"mean": self.scale_back_data(prediction_ops["mean"])})
+          self._scale_back_data(targets),
+          {"mean": self._scale_back_data(prediction_ops["mean"])})
     else:
       loss = self.loss_op(targets, prediction_ops)
 
     # Scale back the prediction.
-    prediction = self.scale_back_data(prediction)
-    covariance = self.scale_back_variance(covariance)
+    prediction = self._scale_back_data(prediction)
+    covariance = self._scale_back_variance(covariance)
 
     return model.ModelOutputs(
         loss=loss,
@@ -565,7 +537,7 @@ class ARModel(model.TimeSeriesModel):
       new_state_times.set_shape((None, self.input_window_size))
       new_state_values = array_ops.concat(
           [previous_state_values,
-           self.scale_data(values)], axis=1)[:, -self.input_window_size:, :]
+           self._scale_data(values)], axis=1)[:, -self.input_window_size:, :]
       new_state_values.set_shape((None, self.input_window_size,
                                   self.num_features))
     else:
@@ -936,8 +936,7 @@ class InputStatisticsFromMiniBatch(object):
       start_time = variable_scope.get_variable(
           name="start_time",
           dtype=dtypes.int64,
-          initializer=init_ops.zeros_initializer(),
-          shape=[],
+          initializer=dtypes.int64.max,
           trainable=False)
       total_observation_count = variable_scope.get_variable(
           name="total_observation_count",
@@ -80,6 +80,8 @@ class TimeSeriesModel(object):
     self.dtype = dtype
     self._input_statistics = None
     self._graph_initialized = False
+    self._stats_means = None
+    self._stats_sigmas = None
 
     # TODO(allenl): Move more of the generic machinery for generating and
     # predicting into TimeSeriesModel, and possibly share it between generate()
@@ -120,6 +122,38 @@ class TimeSeriesModel(object):
     """
     self._graph_initialized = True
     self._input_statistics = input_statistics
+    if self._input_statistics:
+      self._stats_means, variances = (
+          self._input_statistics.overall_feature_moments)
+      self._stats_sigmas = math_ops.sqrt(variances)
+
+  def _scale_data(self, data):
+    """Scale data according to stats (input scale -> model scale)."""
+    if self._input_statistics is not None:
+      return (data - self._stats_means) / self._stats_sigmas
+    else:
+      return data
+
+  def _scale_variance(self, variance):
+    """Scale variances according to stats (input scale -> model scale)."""
+    if self._input_statistics is not None:
+      return variance / self._input_statistics.overall_feature_moments.variance
+    else:
+      return variance
+
+  def _scale_back_data(self, data):
+    """Scale back data according to stats (model scale -> input scale)."""
+    if self._input_statistics is not None:
+      return (data * self._stats_sigmas) + self._stats_means
+    else:
+      return data
+
+  def _scale_back_variance(self, variance):
+    """Scale back variances according to stats (model scale -> input scale)."""
+    if self._input_statistics is not None:
+      return variance * self._input_statistics.overall_feature_moments.variance
+    else:
+      return variance
 
   def _check_graph_initialized(self):
     if not self._graph_initialized:
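The new helpers implement a plain z-score round trip against the overall feature moments. The same arithmetic in NumPy (standalone sketch; the variable names are illustrative):

import numpy as np

mean, variance = 10.0, 4.0           # overall_feature_moments
sigma = np.sqrt(variance)

data = np.array([6.0, 10.0, 14.0])
scaled = (data - mean) / sigma       # _scale_data: input scale -> model scale
restored = scaled * sigma + mean     # _scale_back_data: model -> input scale
assert np.allclose(restored, data)

model_var = 1.0                      # a variance in model scale
input_var = model_var * variance     # _scale_back_variance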
@@ -304,6 +338,7 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
                train_output_names,
                predict_output_names,
                num_features,
+               normalize_features=False,
                dtype=dtypes.float32,
                exogenous_feature_columns=None,
                exogenous_update_condition=None,
@@ -316,6 +351,12 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
       predict_output_names: A list of products/predictions returned from
         _prediction_step.
       num_features: Number of features for the time series
+      normalize_features: Boolean. If True, `values` are passed normalized to
+        the model (via self._scale_data). Scaling is done for the whole window
+        as a batch, which is slightly more efficient than scaling inside the
+        window loop. The model must then define _scale_back_predictions, which
+        may use _scale_back_data or _scale_back_variance to return predictions
+        to the input scale.
       dtype: The floating point datatype to use.
       exogenous_feature_columns: A list of tf.contrib.layers.FeatureColumn
         objects. See `TimeSeriesModel`.
@@ -344,9 +385,25 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._exogenous_update_condition = exogenous_update_condition
     self._train_output_names = train_output_names
     self._predict_output_names = predict_output_names
+    self._normalize_features = normalize_features
     self._static_unrolling_window_size_threshold = (
         static_unrolling_window_size_threshold)
 
+  def _scale_back_predictions(self, predictions):
+    """Return a window of predictions to input scale.
+
+    Args:
+      predictions: A dictionary mapping from prediction names to Tensors.
+    Returns:
+      A dictionary with values corrected for input normalization (e.g. with
+      self._scale_back_mean and possibly self._scale_back_variance). May be a
+      mutated version of the argument.
+    """
+    raise NotImplementedError(
+        "SequentialTimeSeriesModel normalized input data"
+        " (normalize_features=True), but no method was provided to transform "
+        "the predictions back to the input scale.")
+
   @abc.abstractmethod
   def _filtering_step(self, current_times, current_values, state, predictions):
     """Compute a single-step loss for a batch of data.
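A subclass that opts into normalize_features is expected to override this hook. A minimal sketch of a conforming subclass (hypothetical class; it mirrors the StateSpaceModel override that appears later in this diff, and assumes the surrounding module's imports):

class _NormalizedModel(SequentialTimeSeriesModel):  # hypothetical example

  def __init__(self, **kwargs):
    # Opt in: `values` will arrive already passed through _scale_data.
    super(_NormalizedModel, self).__init__(normalize_features=True, **kwargs)

  def _scale_back_predictions(self, predictions):
    # Undo the input normalization on everything handed back to the user.
    predictions["mean"] = self._scale_back_data(predictions["mean"])
    predictions["covariance"] = self._scale_back_variance(
        predictions["covariance"])
    return predictions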
@@ -524,6 +581,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
     self._check_graph_initialized()
     times = math_ops.cast(features[TrainEvalFeatures.TIMES], dtype=dtypes.int64)
     values = math_ops.cast(features[TrainEvalFeatures.VALUES], dtype=self.dtype)
+    if self._normalize_features:
+      values = self._scale_data(values)
     exogenous_regressors = self._process_exogenous_features(
         times=times,
         features={key: value for key, value in features.items()
@@ -556,6 +615,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
       # Since we have window-level additions to the loss, its per-step value is
       # misleading, so we avoid returning it.
       del outputs["loss"]
+    if self._normalize_features:
+      outputs = self._scale_back_predictions(outputs)
     return per_observation_loss, state, outputs
 
   def predict(self, features):
@@ -583,6 +644,8 @@ class SequentialTimeSeriesModel(TimeSeriesModel):
         times=predict_times, state=start_state,
         state_update_fn=_call_prediction_step,
         outputs=self._predict_output_names)
+    if self._normalize_features:
+      predictions = self._scale_back_predictions(predictions)
     return predictions
 
 class _FakeTensorArray(object):
@@ -57,7 +57,9 @@ class AdderStateSpaceModel(state_space_model.StateSpaceModel):
       # TODO(allenl): Better support for multivariate series here.
       initial_value = array_ops.stack([
           math_ops.reduce_mean(
-              self._input_statistics.series_start_moments.mean), 0.
+              self._scale_data(
+                  self._input_statistics.series_start_moments.mean)),
+          0.
       ])
       return initial_value + variable_scope.get_variable(
           name="prior_state_mean",
@@ -232,6 +232,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
                             + filtering_postprocessor_names),
         predict_output_names=["mean", "covariance"],
         num_features=configuration.num_features,
+        normalize_features=True,
         dtype=configuration.dtype,
         exogenous_feature_columns=configuration.exogenous_feature_columns,
         exogenous_update_condition=configuration.exogenous_update_condition,
@@ -309,15 +310,10 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     _, _, priors_from_time = state
     times = ops.convert_to_tensor(times)
     priors_from_time = ops.convert_to_tensor(priors_from_time)
-    with ops.control_dependencies([
-        control_flow_ops.Assert(
-            math_ops.reduce_all(priors_from_time <= times[:, 0]),
-            [priors_from_time, times[:, 0]],
-            summarize=100)
-    ]):
-      times = array_ops.identity(times)
     intra_batch_gaps = array_ops.reshape(times[:, 1:] - times[:, :-1], [-1])
-    starting_gaps = times[:, 0] - priors_from_time
+    # Ignore negative starting gaps, since there will be transient start times
+    # as inputs statistics are computed.
+    starting_gaps = math_ops.maximum(times[:, 0] - priors_from_time, 0)
     # Pre-define transition matrices raised to powers (and their sums) for every
     # gap in this window. This avoids duplicate computation (for example many
     # steps will use the transition matrix raised to the first power) and
@@ -369,20 +365,15 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       Imputed model state corresponding to the `state` argument.
     """
     estimated_state, estimated_state_var, previous_times = state
-    catchup_times = current_times - previous_times
-    non_negative_assertion = control_flow_ops.Assert(
-        math_ops.reduce_all(catchup_times >= 0), [
-            "Negative imputation interval", catchup_times, current_times,
-            previous_times
-        ],
-        summarize=100)
-    with ops.control_dependencies([non_negative_assertion]):
-      transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
-          self._cached_transition_powers_and_sums(catchup_times))
-      estimated_state = self._kalman_filter.predict_state_mean(
-          estimated_state, transition_matrices)
-      estimated_state_var = self._kalman_filter.predict_state_var(
-          estimated_state_var, transition_matrices, transition_noise_sums)
+    # Ignore negative imputation intervals due to transient start time
+    # estimates.
+    catchup_times = math_ops.maximum(current_times - previous_times, 0)
+    transition_matrices, transition_noise_sums = (  # pylint: disable=unbalanced-tuple-unpacking
+        self._cached_transition_powers_and_sums(catchup_times))
+    estimated_state = self._kalman_filter.predict_state_mean(
+        estimated_state, transition_matrices)
+    estimated_state_var = self._kalman_filter.predict_state_var(
+        estimated_state_var, transition_matrices, transition_noise_sums)
     return (estimated_state, estimated_state_var,
             previous_times + catchup_times)
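Both of the preceding hunks replace a hard assertion with clamping: while input statistics are still warming up (note the start_time initializer change to dtypes.int64.max earlier in this diff), priors_from_time or previous_times can transiently exceed the current times, so negative gaps are truncated to zero instead of failing. The same idea in NumPy (illustrative values):

import numpy as np

times0 = np.array([5, 7], dtype=np.int64)
priors_from_time = np.array([3, 9], dtype=np.int64)  # 9: transient estimate

# Old code asserted priors_from_time <= times0; new code just clamps.
starting_gaps = np.maximum(times0 - priors_from_time, 0)  # -> [2, 0]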
@@ -437,6 +428,13 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         outputs=predictions)
     return (filtered_state, predictions)
 
+  def _scale_back_predictions(self, predictions):
+    """Return a window of predictions to input scale."""
+    predictions["mean"] = self._scale_back_data(predictions["mean"])
+    predictions["covariance"] = self._scale_back_variance(
+        predictions["covariance"])
+    return predictions
+
   def _prediction_step(self, current_times, state):
     """Make a prediction based on `state`.
@@ -458,7 +456,7 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     """
     estimated_state, estimated_state_var, previous_times = state
     advanced_to_current_assert = control_flow_ops.Assert(
-        math_ops.reduce_all(math_ops.equal(current_times, previous_times)),
+        math_ops.reduce_all(math_ops.less_equal(current_times, previous_times)),
         ["Attempted to predict without imputation"])
     with ops.control_dependencies([advanced_to_current_assert]):
       observation_model = self.get_broadcasted_observation_model(current_times)
@@ -475,6 +473,9 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
               (self.num_features,)))
       predicted_obs_var.set_shape(current_times.get_shape().concatenate(
           (self.num_features, self.num_features)))
+      # Not scaled back to input-scale, since this also feeds into the
+      # loss. Instead, predictions are scaled back before being returned to the
+      # user in _scale_back_predictions.
       predictions = {
           "mean": predicted_obs,
           "covariance": predicted_obs_var}
@@ -722,7 +723,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
       # Make sure initial latent value uncertainty is at least on the same
       # scale as noise in the data.
       covariance_multiplier = math_ops.reduce_max(
-          self._input_statistics.series_start_moments.variance)
+          self._scale_variance(
+              self._input_statistics.series_start_moments.variance))
       return base_covariance * gen_math_ops.maximum(
           covariance_multiplier, 1.0)
     else:
@@ -920,7 +922,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
         self.get_noise_transform(), dtype=self.dtype)
     state_noise_dimension = state_noise_transform.get_shape()[1].value
     if self._input_statistics is not None:
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
       initial_transition_noise_scale = math_ops.log(
           gen_math_ops.maximum(
               math_ops.reduce_mean(feature_variance) / math_ops.cast(
@@ -945,7 +948,8 @@ class StateSpaceModel(model.SequentialTimeSeriesModel):
     if self._input_statistics is not None:
       # Get variance across the first few values in each batch for each
       # feature, for an initial observation noise (over-)estimate.
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
     else:
       feature_variance = None
     if feature_variance is not None:
@@ -605,6 +605,7 @@ class TimeDependentStateSpaceModel(state_space_model.StateSpaceModel):
     super(TimeDependentStateSpaceModel, self).__init__(
         configuration=state_space_model.StateSpaceModelConfiguration(
             use_observation_noise=False,
+            transition_covariance_initial_log_scale_bias=5.,
            static_unrolling_window_size_threshold=
            static_unrolling_window_size_threshold))
 
@@ -182,7 +182,8 @@ class VARMA(state_space_model.StateSpaceModel):
     # modeled as transition noise in VARMA, we set its initial value based on a
     # slight over-estimate empirical observation noise.
     if self._input_statistics is not None:
-      feature_variance = self._input_statistics.series_start_moments.variance
+      feature_variance = self._scale_variance(
+          self._input_statistics.series_start_moments.variance)
       initial_transition_noise_scale = math_ops.log(
           math_ops.maximum(
               math_ops.reduce_mean(feature_variance), minimum_initial_variance))
@@ -39,8 +39,8 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
 
   // x is from the feed.
   const int batch_size = tensor_size < 0 ? 1 : tensor_size;
-  Output x =
-      RandomNormal(s.WithOpName("x"), {batch_size, 1}, DataType::DT_FLOAT);
+  Output x = RandomNormal(s.WithOpName("x").WithDevice("/CPU:0"),
+                          {batch_size, 1}, DataType::DT_FLOAT);
 
   // Create stages.
   std::vector<Output> last_stage;
@@ -64,16 +64,19 @@ GraphDef CreateGraphDef(int num_stages, int width, int tensor_size,
   }
 
   if (insert_queue) {
-    FIFOQueue queue(s.WithOpName("queue"), {DataType::DT_FLOAT});
-    QueueEnqueue enqueue(s.WithOpName("enqueue"), queue, last_stage);
-    QueueDequeue dequeue(s.WithOpName("dequeue"), queue, {DataType::DT_FLOAT});
-    QueueClose cancel(s.WithOpName("cancel"), queue,
+    FIFOQueue queue(s.WithOpName("queue").WithDevice("/CPU:0"),
+                    {DataType::DT_FLOAT});
+    QueueEnqueue enqueue(s.WithOpName("enqueue").WithDevice("/CPU:0"), queue,
+                         last_stage);
+    QueueDequeue dequeue(s.WithOpName("dequeue").WithDevice("/CPU:0"), queue,
+                         {DataType::DT_FLOAT});
+    QueueClose cancel(s.WithOpName("cancel").WithDevice("/CPU:0"), queue,
                       QueueClose::CancelPendingEnqueues(true));
     last_stage = {dequeue[0]};
   }
 
   // Create output.
-  AddN output(s.WithOpName("y"), last_stage);
+  AddN output(s.WithOpName("y").WithDevice("/CPU:0"), last_stage);
 
   GraphDef def;
   TF_CHECK_OK(s.ToGraphDef(&def));
@@ -438,14 +438,18 @@ def train_and_evaluate(estimator, train_spec, eval_spec):
         '`estimator.config` must have task_type set. This usually means '
         'TF_CONFIG environment is not set correctly.')
 
-  # TODO(xiejw): error out if evaluator index is more than 0.
-
   if config.task_type == 'local':
     raise ValueError(
         '`task.type` in TF_CONFIG cannot be `local`. Leaving `cluster` and '
         '`task` properties in TF_CONFIG absent triggers train and evaluate '
         '`Estimator` locally (non-distributed).')
 
+  if (config.task_type == run_config_lib.TaskType.EVALUATOR and
+      config.task_id > 0):
+    raise ValueError(
+        'For distributed training, there can only be one `evaluator` task '
+        '(with task id 0). Given task id {}'.format(config.task_id))
+
   # For task type foo, call executor.run_foo.
   available_tasks = [x for x in dir(executor) if x.startswith('run_')
                      and x != 'run_local'
@@ -71,6 +71,8 @@ _INVALID_EMPTY_EVAL_RESULT_ERR = (
 _INVALID_EVAL_RESULT_TYPE_ERR = '`Estimator.evaluate` should return dict.'
 _MISSING_GLOBAL_STEP_IN_EVAL_RESULT_ERR = (
     'Internal error: `Estimator.evaluate` result should have `global_step`')
+_INVALID_EVAL_TASK_ID_ERR = (
+    'there can only be one `evaluator` task .*with task id 0')
 
 _TF_CONFIG_FOR_CHIEF = {
     'cluster': {
@@ -128,7 +130,7 @@ _TF_CONFIG_FOR_EVALUATOR = {
     },
     'task': {
         'type': run_config_lib.TaskType.EVALUATOR,
-        'index': 1
+        'index': 0
     }
 }
 
@@ -351,6 +353,20 @@ class TrainAndEvaluteTest(test.TestCase):
             _TF_CONFIG_FOR_EVALUATOR))
     self.assertEqual(1, mock_executor.call_task['evaluator'])
 
+  def test_error_out_if_evaluator_task_id_is_non_zero(self):
+    tf_config = {
+        'cluster': {
+            run_config_lib.TaskType.CHIEF: ['host0:0'],
+        },
+        'task': {
+            'type': run_config_lib.TaskType.EVALUATOR,
+            'index': 1
+        }
+    }
+    with self.assertRaisesRegexp(ValueError, _INVALID_EVAL_TASK_ID_ERR):
+      self._test_run_task_in_distributed_training(
+          run_config=_create_run_config_with_cluster_spec(tf_config))
+
   def test_run_local(self):
     mock_est = test.mock.Mock(spec=estimator_lib.Estimator)
     mock_est.config = run_config_lib.RunConfig()