mirror of
https://github.com/zebrajr/tensorflow.git
synced 2025-12-06 00:19:58 +01:00
Add option for build more python tests in Cmake (#11853)
* Ignore Windows built project * Fix deprecated methods in tf.contrib.python * Fix regex match for Windows build in contrib.keras * Fix Regex match for Windows build in session_bundle * * Fix deprecated methods * Fix regex match for Windows * Fix compatibility issue with Python 3.x * Add missing ops into Windows build for test * Enabled more testcases for Windows build * Clean code and fix typo * Add conditional cmake mode for enabling more unit testcase * Add Cmake mode for major Contrib packages * Add supplementary info in RAEDME for new cmake option * * Update tf_tests after testing with TF 1.3 * Clean code and resolve conflicts * Fix unsafe regex matches and format code * Update exclude list after testing with latest master branch * Fix missing module
This commit is contained in:
parent
98f0e1efec
commit
9f81374c30
1
.gitignore
vendored
1
.gitignore
vendored
|
|
@ -13,4 +13,5 @@ node_modules
|
|||
__pycache__
|
||||
*.swp
|
||||
.vscode/
|
||||
cmake_build/
|
||||
.idea/**
|
||||
|
|
|
|||
|
|
@ -29,6 +29,7 @@ option(tensorflow_BUILD_ALL_KERNELS "Build all OpKernels" ON)
|
|||
option(tensorflow_BUILD_CONTRIB_KERNELS "Build OpKernels from tensorflow/contrib/..." ON)
|
||||
option(tensorflow_BUILD_CC_TESTS "Build cc unit tests " OFF)
|
||||
option(tensorflow_BUILD_PYTHON_TESTS "Build python unit tests " OFF)
|
||||
option(tensorflow_BUILD_MORE_PYTHON_TESTS "Build more python unit tests for contrib packages" OFF)
|
||||
option(tensorflow_BUILD_SHARED_LIB "Build TensorFlow as a shared library" OFF)
|
||||
option(tensorflow_OPTIMIZE_FOR_NATIVE_ARCH "Enable compiler optimizations for the native processor architecture (if available)" ON)
|
||||
option(tensorflow_WIN_CPU_SIMD_OPTIONS "Enables CPU SIMD instructions")
|
||||
|
|
|
|||
|
|
@ -241,6 +241,13 @@ Step-by-step Windows build
|
|||
```
|
||||
ctest -C RelWithDebInfo
|
||||
```
|
||||
* `-Dtensorflow_BUILD_MORE_PYTHON_TESTS=(ON|OFF)`. Defaults to `OFF`. This enables python tests on
|
||||
serveral major packages. This option is only valid if this and tensorflow_BUILD_PYTHON_TESTS are both set as `ON`.
|
||||
After building the python wheel, you need to install the new wheel before running the tests.
|
||||
To execute the tests, use
|
||||
```
|
||||
ctest -C RelWithDebInfo
|
||||
```
|
||||
|
||||
4. Invoke MSBuild to build TensorFlow.
|
||||
|
||||
|
|
|
|||
|
|
@ -76,7 +76,9 @@ if(tensorflow_BUILD_CONTRIB_KERNELS)
|
|||
#"${tensorflow_source_dir}/tensorflow/contrib/ffmpeg/encode_audio_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/generate_vocab_remapping_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/load_and_remap_matrix_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/kernels/zero_initializer_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/checkpoint_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/ops/variable_ops.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/kernels/sparse_feature_cross_kernel.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/layers/ops/sparse_feature_cross_op.cc"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/nccl/kernels/nccl_manager.cc"
|
||||
|
|
|
|||
|
|
@ -156,6 +156,21 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
|||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/*_test.py"
|
||||
)
|
||||
|
||||
if (tensorflow_BUILD_MORE_PYTHON_TESTS)
|
||||
# Adding other major packages
|
||||
file(GLOB_RECURSE tf_test_src_py
|
||||
${tf_test_src_py}
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/legacy_seq2seq/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/linalg/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py"
|
||||
)
|
||||
endif()
|
||||
|
||||
# exclude the ones we don't want
|
||||
set(tf_test_src_py_exclude
|
||||
# Python source line inspection tests are flaky on Windows (b/36375074).
|
||||
|
|
@ -183,6 +198,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
|||
# Loading resources in contrib doesn't seem to work on Windows
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/client/random_forest_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/tensor_forest/python/tensor_forest_test.py"
|
||||
# dask need fix
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/generator_io_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/graph_io_test.py"
|
||||
# Test is flaky on Windows GPU builds (b/38283730).
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/gmm_test.py"
|
||||
)
|
||||
|
|
@ -215,11 +233,8 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
|||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/py_func_test.py"
|
||||
# training tests
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/basic_session_run_hooks_test.py" # Needs tf.contrib fix.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/evaluation_test.py" # Needs tf.contrib fix.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/localhost_cluster_performance_test.py" # Needs portpicker.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/monitored_session_test.py" # Needs tf.contrib fix.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/quantize_training_test.py" # Needs quantization ops to be included in windows.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/saver_large_variable_test.py" # Overflow error.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/supervisor_test.py" # Flaky I/O error on rename.
|
||||
"${tensorflow_source_dir}/tensorflow/python/training/sync_replicas_optimizer_test.py" # Needs portpicker.
|
||||
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/array_ops_test.py" # depends on python/framework/test_ops
|
||||
|
|
@ -233,6 +248,45 @@ if (tensorflow_BUILD_PYTHON_TESTS)
|
|||
"${tensorflow_source_dir}/tensorflow/python/ops/cloud/bigquery_reader_ops_test.py" # No libcurl support
|
||||
# Newly running on Windows since TensorBoard backend move. Fail on Windows and need debug.
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/data/python/kernel_tests/dataset_constructor_op_test.py" # Segfaults on Windows.
|
||||
# Dask.Dataframe bugs on Window Build
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/tests/dataframe/tensorflow_dataframe_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/data_feeder_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/learn_io/io_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/graph_actions_test.py"
|
||||
# Need extra build
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_distribution_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/conditional_transformed_distribution_test.py"
|
||||
# Windows Path
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/framework/python/ops/checkpoint_ops_test.py" #TODO: Fix path
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/models_test.py"
|
||||
# Related to Windows Multiprocessing https://github.com/fchollet/keras/issues/5071
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/engine/training_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/utils/data_utils_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/callbacks_test.py"
|
||||
# Scipy needed
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/preprocessing/image_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sigmoid_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/binomial_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/chi2_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/geometric_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/inverse_gamma_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/logistic_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mixture_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_diag_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_full_covariance_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/mvn_tril_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/negative_binomial_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/poisson_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/quantized_distribution_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_bernoulli_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/relaxed_onehot_categorical_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/transformed_distribution_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/vector_student_t_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/wishart_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/learn/python/learn/estimators/kmeans_test.py"
|
||||
# Failing with TF 1.3 (TODO)
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/estimator_test.py"
|
||||
"${tensorflow_source_dir}/tensorflow/contrib/distributions/python/kernel_tests/bijectors/sinh_arcsinh_test.py"
|
||||
)
|
||||
endif()
|
||||
list(REMOVE_ITEM tf_test_src_py ${tf_test_src_py_exclude})
|
||||
|
|
|
|||
|
|
@ -49,6 +49,7 @@ from tensorflow.contrib.distributions.python.ops.quantized_distribution import *
|
|||
from tensorflow.contrib.distributions.python.ops.relaxed_bernoulli import *
|
||||
from tensorflow.contrib.distributions.python.ops.relaxed_onehot_categorical import *
|
||||
from tensorflow.contrib.distributions.python.ops.sample_stats import *
|
||||
from tensorflow.contrib.distributions.python.ops.test_util import *
|
||||
from tensorflow.contrib.distributions.python.ops.vector_exponential_diag import *
|
||||
from tensorflow.contrib.distributions.python.ops.vector_laplace_diag import *
|
||||
from tensorflow.contrib.distributions.python.ops.wishart import *
|
||||
|
|
|
|||
|
|
@ -562,7 +562,7 @@ def assign_from_checkpoint(model_path, var_list, ignore_missing_vars=False):
|
|||
grouped_vars[ckpt_name].append(var)
|
||||
|
||||
else:
|
||||
for ckpt_name, value in var_list.iteritems():
|
||||
for ckpt_name, value in var_list.items():
|
||||
if isinstance(value, (tuple, list)):
|
||||
grouped_vars[ckpt_name] = value
|
||||
else:
|
||||
|
|
|
|||
|
|
@ -18,6 +18,7 @@ from __future__ import division
|
|||
from __future__ import print_function
|
||||
|
||||
import marshal
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import types as python_types
|
||||
|
|
@ -195,7 +196,10 @@ def func_dump(func):
|
|||
Returns:
|
||||
A tuple `(code, defaults, closure)`.
|
||||
"""
|
||||
code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
|
||||
if os.name == 'nt':
|
||||
code = marshal.dumps(func.__code__).replace(b'\\',b'/').decode('raw_unicode_escape')
|
||||
else:
|
||||
code = marshal.dumps(func.__code__).decode('raw_unicode_escape')
|
||||
defaults = func.__defaults__
|
||||
if func.__closure__:
|
||||
closure = tuple(c.cell_contents for c in func.__closure__)
|
||||
|
|
|
|||
|
|
@ -505,7 +505,7 @@ class EstimatorModelFnTest(test.TestCase):
|
|||
return input_fn_utils.InputFnOps(
|
||||
features, labels, {'examples': serialized_tf_example})
|
||||
|
||||
est.export_savedmodel(est.model_dir + '/export', serving_input_fn)
|
||||
est.export_savedmodel(os.path.join(est.model_dir, 'export'), serving_input_fn)
|
||||
self.assertTrue(self.mock_saver.restore.called)
|
||||
|
||||
|
||||
|
|
@ -955,10 +955,11 @@ class EstimatorTest(test.TestCase):
|
|||
self.assertTrue('input_example_tensor' in graph_ops)
|
||||
self.assertTrue('ParseExample/ParseExample' in graph_ops)
|
||||
self.assertTrue('linear/linear/feature/matmul' in graph_ops)
|
||||
self.assertSameElements(
|
||||
['bogus_lookup', 'feature'],
|
||||
graph.get_collection(
|
||||
constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS))
|
||||
self.assertItemsEqual(
|
||||
['bogus_lookup', 'feature'],
|
||||
[compat.as_str_any(x) for x in graph.get_collection(
|
||||
constants.COLLECTION_DEF_KEY_FOR_INPUT_FEATURE_KEYS)])
|
||||
|
||||
|
||||
# cleanup
|
||||
gfile.DeleteRecursively(tmpdir)
|
||||
|
|
|
|||
|
|
@ -31,6 +31,7 @@ from tensorflow.contrib.session_bundle import exporter
|
|||
from tensorflow.contrib.session_bundle import manifest_pb2
|
||||
from tensorflow.python.client import session
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.platform import gfile
|
||||
|
|
@ -49,9 +50,8 @@ def _training_input_fn():
|
|||
|
||||
|
||||
class ExportTest(test.TestCase):
|
||||
|
||||
def _get_default_signature(self, export_meta_filename):
|
||||
"""Gets the default signature from the export.meta file."""
|
||||
""" Gets the default signature from the export.meta file. """
|
||||
with session.Session():
|
||||
save = saver.import_meta_graph(export_meta_filename)
|
||||
meta_graph_def = save.export_meta_graph()
|
||||
|
|
@ -68,18 +68,19 @@ class ExportTest(test.TestCase):
|
|||
self.assertTrue(gfile.Exists(export_dir))
|
||||
# Only the written checkpoints are exported.
|
||||
self.assertTrue(
|
||||
saver.checkpoint_exists(export_dir + '00000001/export'),
|
||||
saver.checkpoint_exists(os.path.join(export_dir, '00000001', 'export')),
|
||||
'Exported checkpoint expected but not found: %s' %
|
||||
(export_dir + '00000001/export'))
|
||||
os.path.join(export_dir, '00000001', 'export'))
|
||||
self.assertTrue(
|
||||
saver.checkpoint_exists(export_dir + '00000010/export'),
|
||||
saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')),
|
||||
'Exported checkpoint expected but not found: %s' %
|
||||
(export_dir + '00000010/export'))
|
||||
os.path.join(export_dir, '00000010', 'export'))
|
||||
self.assertEquals(
|
||||
six.b(os.path.join(export_dir, '00000010')),
|
||||
export_monitor.last_export_dir)
|
||||
# Validate the signature
|
||||
signature = self._get_default_signature(export_dir + '00000010/export.meta')
|
||||
signature = self._get_default_signature(
|
||||
os.path.join(export_dir, '00000010', 'export.meta'))
|
||||
self.assertTrue(signature.HasField(expected_signature))
|
||||
|
||||
def testExportMonitor_EstimatorProvidesSignature(self):
|
||||
|
|
@ -88,7 +89,7 @@ class ExportTest(test.TestCase):
|
|||
y = 2 * x + 3
|
||||
cont_features = [feature_column.real_valued_column('', dimension=1)]
|
||||
regressor = learn.LinearRegressor(feature_columns=cont_features)
|
||||
export_dir = tempfile.mkdtemp() + 'export/'
|
||||
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
|
||||
export_monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1, export_dir=export_dir, exports_to_keep=2)
|
||||
regressor.fit(x, y, steps=10, monitors=[export_monitor])
|
||||
|
|
@ -99,7 +100,7 @@ class ExportTest(test.TestCase):
|
|||
x = np.random.rand(1000)
|
||||
y = 2 * x + 3
|
||||
cont_features = [feature_column.real_valued_column('', dimension=1)]
|
||||
export_dir = tempfile.mkdtemp() + 'export/'
|
||||
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
|
||||
export_monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=export_dir,
|
||||
|
|
@ -122,7 +123,7 @@ class ExportTest(test.TestCase):
|
|||
input_feature_key = 'my_example_key'
|
||||
monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=tempfile.mkdtemp() + 'export/',
|
||||
export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
|
||||
input_fn=_serving_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
exports_to_keep=2,
|
||||
|
|
@ -140,7 +141,7 @@ class ExportTest(test.TestCase):
|
|||
|
||||
monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=tempfile.mkdtemp() + 'export/',
|
||||
export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
|
||||
input_fn=_serving_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
exports_to_keep=2,
|
||||
|
|
@ -165,7 +166,7 @@ class ExportTest(test.TestCase):
|
|||
|
||||
monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=tempfile.mkdtemp() + 'export/',
|
||||
export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
|
||||
input_fn=_serving_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
exports_to_keep=2,
|
||||
|
|
@ -187,7 +188,7 @@ class ExportTest(test.TestCase):
|
|||
|
||||
monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=tempfile.mkdtemp() + 'export/',
|
||||
export_dir=os.path.join(tempfile.mkdtemp(), 'export'),
|
||||
input_fn=_serving_input_fn,
|
||||
input_feature_key=input_feature_key,
|
||||
exports_to_keep=2,
|
||||
|
|
@ -210,7 +211,7 @@ class ExportTest(test.TestCase):
|
|||
shape=(1,), minval=0.0, maxval=1000.0)
|
||||
}, None
|
||||
|
||||
export_dir = tempfile.mkdtemp() + 'export/'
|
||||
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
|
||||
monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=export_dir,
|
||||
|
|
@ -235,7 +236,7 @@ class ExportTest(test.TestCase):
|
|||
y = 2 * x + 3
|
||||
cont_features = [feature_column.real_valued_column('', dimension=1)]
|
||||
regressor = learn.LinearRegressor(feature_columns=cont_features)
|
||||
export_dir = tempfile.mkdtemp() + 'export/'
|
||||
export_dir = os.path.join(tempfile.mkdtemp(), 'export')
|
||||
export_monitor = learn.monitors.ExportMonitor(
|
||||
every_n_steps=1,
|
||||
export_dir=export_dir,
|
||||
|
|
@ -244,10 +245,13 @@ class ExportTest(test.TestCase):
|
|||
regressor.fit(x, y, steps=10, monitors=[export_monitor])
|
||||
|
||||
self.assertTrue(gfile.Exists(export_dir))
|
||||
self.assertFalse(saver.checkpoint_exists(export_dir + '00000000/export'))
|
||||
self.assertTrue(saver.checkpoint_exists(export_dir + '00000010/export'))
|
||||
with self.assertRaises(errors.NotFoundError):
|
||||
saver.checkpoint_exists(os.path.join(export_dir, '00000000', 'export'))
|
||||
self.assertTrue(
|
||||
saver.checkpoint_exists(os.path.join(export_dir, '00000010', 'export')))
|
||||
# Validate the signature
|
||||
signature = self._get_default_signature(export_dir + '00000010/export.meta')
|
||||
signature = self._get_default_signature(
|
||||
os.path.join(export_dir, '00000010', 'export.meta'))
|
||||
self.assertTrue(signature.HasField('regression_signature'))
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -33,8 +33,13 @@ from tensorflow.python.util import compat
|
|||
def _create_parser(base_dir):
|
||||
# create a simple parser that pulls the export_version from the directory.
|
||||
def parser(path):
|
||||
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
|
||||
compat.as_str_any(path.path))
|
||||
# Modify the path object for RegEx match for Windows Paths
|
||||
if os.name == 'nt':
|
||||
match = re.match("^" + compat.as_str_any(base_dir).replace('\\','/') + "/(\\d+)$",
|
||||
compat.as_str_any(path.path).replace('\\','/'))
|
||||
else:
|
||||
match = re.match("^" + compat.as_str_any(base_dir) + "/(\\d+)$",
|
||||
compat.as_str_any(path.path))
|
||||
if not match:
|
||||
return None
|
||||
return path._replace(export_version=int(match.group(1)))
|
||||
|
|
@ -48,13 +53,13 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
paths = [gc.Path("/foo", 8), gc.Path("/foo", 9), gc.Path("/foo", 10)]
|
||||
newest = gc.largest_export_versions(2)
|
||||
n = newest(paths)
|
||||
self.assertEquals(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
|
||||
self.assertEqual(n, [gc.Path("/foo", 9), gc.Path("/foo", 10)])
|
||||
|
||||
def testLargestExportVersionsDoesNotDeleteZeroFolder(self):
|
||||
paths = [gc.Path("/foo", 0), gc.Path("/foo", 3)]
|
||||
newest = gc.largest_export_versions(2)
|
||||
n = newest(paths)
|
||||
self.assertEquals(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
|
||||
self.assertEqual(n, [gc.Path("/foo", 0), gc.Path("/foo", 3)])
|
||||
|
||||
def testModExportVersion(self):
|
||||
paths = [
|
||||
|
|
@ -62,9 +67,9 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
gc.Path("/foo", 9)
|
||||
]
|
||||
mod = gc.mod_export_version(2)
|
||||
self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
|
||||
self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 6)])
|
||||
mod = gc.mod_export_version(3)
|
||||
self.assertEquals(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
|
||||
self.assertEqual(mod(paths), [gc.Path("/foo", 6), gc.Path("/foo", 9)])
|
||||
|
||||
def testOneOfEveryNExportVersions(self):
|
||||
paths = [
|
||||
|
|
@ -73,7 +78,7 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
gc.Path("/foo", 8), gc.Path("/foo", 33)
|
||||
]
|
||||
one_of = gc.one_of_every_n_export_versions(3)
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
one_of(paths), [
|
||||
gc.Path("/foo", 3), gc.Path("/foo", 6), gc.Path("/foo", 8),
|
||||
gc.Path("/foo", 33)
|
||||
|
|
@ -84,14 +89,14 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
# Test that here.
|
||||
paths = [gc.Path("/foo", 0), gc.Path("/foo", 4), gc.Path("/foo", 5)]
|
||||
one_of = gc.one_of_every_n_export_versions(3)
|
||||
self.assertEquals(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
|
||||
self.assertEqual(one_of(paths), [gc.Path("/foo", 0), gc.Path("/foo", 5)])
|
||||
|
||||
def testUnion(self):
|
||||
paths = []
|
||||
for i in xrange(10):
|
||||
paths.append(gc.Path("/foo", i))
|
||||
f = gc.union(gc.largest_export_versions(3), gc.mod_export_version(3))
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
f(paths), [
|
||||
gc.Path("/foo", 0), gc.Path("/foo", 3), gc.Path("/foo", 6),
|
||||
gc.Path("/foo", 7), gc.Path("/foo", 8), gc.Path("/foo", 9)
|
||||
|
|
@ -103,9 +108,9 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
gc.Path("/foo", 9)
|
||||
]
|
||||
mod = gc.negation(gc.mod_export_version(2))
|
||||
self.assertEquals(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
|
||||
self.assertEqual(mod(paths), [gc.Path("/foo", 5), gc.Path("/foo", 9)])
|
||||
mod = gc.negation(gc.mod_export_version(3))
|
||||
self.assertEquals(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
|
||||
self.assertEqual(mod(paths), [gc.Path("/foo", 4), gc.Path("/foo", 5)])
|
||||
|
||||
def testPathsWithParse(self):
|
||||
base_dir = os.path.join(test.get_temp_dir(), "paths_parse")
|
||||
|
|
@ -115,7 +120,7 @@ class GcTest(test_util.TensorFlowTestCase):
|
|||
# add a base_directory to ignore
|
||||
gfile.MakeDirs(os.path.join(base_dir, "ignore"))
|
||||
|
||||
self.assertEquals(
|
||||
self.assertEqual(
|
||||
gc.get_paths(base_dir, _create_parser(base_dir)),
|
||||
[
|
||||
gc.Path(os.path.join(base_dir, "0"), 0),
|
||||
|
|
|
|||
|
|
@ -301,7 +301,12 @@ class Exporter(object):
|
|||
if exports_to_keep:
|
||||
# create a simple parser that pulls the export_version from the directory.
|
||||
def parser(path):
|
||||
match = re.match("^" + export_dir_base + "/(\\d{8})$", path.path)
|
||||
if os.name == 'nt':
|
||||
match = re.match("^" + export_dir_base.replace('\\','/') + "/(\\d{8})$",
|
||||
path.path.replace('\\','/'))
|
||||
else:
|
||||
match = re.match("^" + export_dir_base + "/(\\d{8})$",
|
||||
path.path)
|
||||
if not match:
|
||||
return None
|
||||
return path._replace(export_version=int(match.group(1)))
|
||||
|
|
|
|||
Loading…
Reference in New Issue
Block a user