commit 2ed1077a8313c783da4098bdde6c69cde3290329 Author: Yangqing Jia Date: Thu Jun 25 16:26:01 2015 -0700 A clean init for Caffe2, removing my earlier hacky commits. diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000000..d7470e5a39a --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +.DS_Store +*.pyc +gen*/ diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000000..cfa718ac64a --- /dev/null +++ b/LICENSE @@ -0,0 +1,30 @@ +Copyright (c) 2015 Yangqing Jia +All Rights Reserved. + +== LICENSE == + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +== DECLARATION == + +Some parts of the caffe2 code is derived from the original Caffe code, which is +created by Yangqing Jia and is now a BSD-licensed open-source project. The Caffe +license is attached as LICENSE.caffe. diff --git a/LICENSE.caffe b/LICENSE.caffe new file mode 100644 index 00000000000..94611a7b31c --- /dev/null +++ b/LICENSE.caffe @@ -0,0 +1,46 @@ +*** begin Caffe license *** +COPYRIGHT + +All contributions by the University of California: +Copyright (c) 2014, The Regents of the University of California (Regents) +All rights reserved. + +All other contributions: +Copyright (c) 2014, the respective contributors +All rights reserved. + +Caffe uses a shared copyright model: each contributor holds copyright over +their contributions to Caffe. The project versioning records all such +contribution and copyright details. If a contributor wants to further mark +their specific copyright on a particular contribution, they should indicate +their copyright solely in the commit message of the change when it is +committed. + +LICENSE + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. 
+ +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +CONTRIBUTION AGREEMENT + +By contributing to the BVLC/caffe repository through pull-request, comment, +or otherwise, the contributor releases their content to the +license and copyright terms herein. +*** end Caffe license *** \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 00000000000..4f0f5a6ecfe --- /dev/null +++ b/Makefile @@ -0,0 +1,21 @@ +# This makefile does nothing but delegating the actual compilation to build.py. + +all: + @python brewery.py build + +clean: + @python brewery.py clean + +reallyclean: + @python brewery.py reallyclean + +test: + @python brewery.py test + +lint: + @find caffe2 -type f -exec python cpplint.py {} \; + +linecount: + @cloc --read-lang-def=caffe.cloc caffe2 pycaffe2 || \ + echo "Cloc is not available on the machine. You can install cloc with " && \ + echo " sudo apt-get install cloc" \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 00000000000..8ab2d9f1846 --- /dev/null +++ b/README.md @@ -0,0 +1,16 @@ +If you are not Yangqing and you don't know what this repository is, you may have +stumbled upon it with some links or forked repositories in the wild. Please, let +me know since I want to make the visibility of this library as small as possible +for now. + +Yangqing +(me@daggerfs.com) + +# Caffe2 + +Caffe2 is a deep learning framework made with expression, speed, and modularity in mind. It is an experimental refactoring of Caffe. + +## License and Citation + +Caffe2 is released under the [BSD 2-Clause license](https://github.com/Yangqing/caffe2/blob/master/LICENSE). + diff --git a/brewery.py b/brewery.py new file mode 100644 index 00000000000..9b48431822b --- /dev/null +++ b/brewery.py @@ -0,0 +1,661 @@ + +import cPickle as pickle +from collections import defaultdict +import multiprocessing +import glob +import hashlib +import os +import shlex +import shutil +import subprocess +import sys +import tempfile +import traceback + +from build_env import Env + +class Colors(object): + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + +def BuildDebug(message, *args): + # Note(Yangqing): if you want to know detailed message about the build, + # uncomment the following line. + print Colors.OKBLUE + 'DEBUG:', message % args, Colors.ENDC + return + +def BuildLog(message, *args): + print Colors.OKGREEN + 'LOG:', message % args, Colors.ENDC + +def BuildWarning(message, *args): + print Colors.WARNING + 'WARNING:', message % args, Colors.ENDC + +def BuildFatal(message, *args): + print Colors.FAIL + 'FATAL:', message % args, Colors.ENDC + print Colors.FAIL + 'Build exiting.' 
+ Colors.ENDC + Brewery.Finalize() + sys.exit(1) + +def BuildFatalIf(command, message, *args): + if command: + BuildFatal(message, *args) + +_single_command_env = os.environ +if 'PYTHONPATH' not in _single_command_env: + _single_command_env['PYTHONPATH'] = '' +_single_command_env['PYTHONPATH'] = ( + Env.GENDIR + ':' + _single_command_env['PYTHONPATH']) + +def RunSingleCommand(command): + BuildDebug(command) + try: + proc = subprocess.Popen(shlex.split(command), stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, env=_single_command_env) + stdout, _ = proc.communicate() + if proc.returncode: + print stdout + return proc.returncode + except: # all exceptions caught here. + e = sys.exc_info()[0] + return str(e) + +def Glob(patterns): + """Globs all files with the given patterns, relative to the path of the BREW + file.""" + files = [] + if type(patterns) is str: + patterns = [patterns] + for pattern in patterns: + full_pattern = os.path.join(Brewery.CWD, pattern) + files += glob.glob(full_pattern) + prefix_len = len(Brewery.CWD) + 1 + return [f[prefix_len:] for f in files if os.path.isfile(f)] + +def RectifyFileName(name): + """Rectifies a build file name to its absolute name.""" + if name.startswith("//"): + # Simply replace the "//" with the root folder. + out_name = name[2:] + else: + # Add the current working directory. + out_name = os.path.join(Brewery.CWD, name) + # check if the name exists. + BuildFatalIf(not os.path.exists(out_name), 'Cannot find file %s' % out_name) + return out_name + +def RectifyFileNames(names): + return [RectifyFileName(n) for n in sorted(names)] + +def RectifyTarget(name): + """Rectifies a build target name.""" + if name.startswith("//"): + return name + elif name.startswith(":"): + return Brewery.TARGET_PREFIX + name + else: + if Brewery.TARGET_PREFIX == '//': + return Brewery.TARGET_PREFIX + name + return Brewery.TARGET_PREFIX + ":" + name + +def RectifyTargets(names): + return [RectifyTarget(n) for n in sorted(names)] + +def MakeGenDirs(rectified_srcs): + for src in rectified_srcs: + dst = os.path.join(Env.GENDIR, src) + try: + os.makedirs(os.path.dirname(dst)) + except OSError as e: + pass + +def CopyToGenDir(rectified_srcs): + MakeGenDirs(rectified_srcs) + for src in rectified_srcs: + shutil.copyfile(src, GenFilename(src)) + +def GenFilename(name, new_ext=None, original_ext=None): + if new_ext: + if original_ext: + new_name = name[:name.rfind(original_ext)] + new_ext + else: + new_name = name[:name.rfind('.') + 1] + new_ext + else: + new_name = name + return os.path.join(Env.GENDIR, new_name) + +def MergeOrderedObjs(dep_lists): + added = set() + output = [] + for dep_list in dep_lists: + for item in dep_list[::-1]: + if item not in added: + added.add(item) + output.insert(0, item) + return output + +class Brewery(object): + # Targets store the dictionary from the target name to the build objects. + _targets = dict() + # Success stores whether a target is successfully built. + _success = defaultdict(bool) + # deps_map is a dictionary mapping each target to its dependents. + _deps_map = dict() + # signature_map is the map that stores the signatures for build targets. + _signatures = defaultdict(str) + _signature_filename = 'brewery.signature' + # Pool is the compute pool that one can use to run a list of commands in + # parallel. 
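+  # (Editor's note: the illustration below is an added aside, not part of the
+  # original commit.) RunInParallel() maps each command string of a command
+  # group onto this pool via RunSingleCommand, and the whole group counts as
+  # failed if any command exits with a nonzero code, roughly:
+  #
+  #   Brewery.RunInParallel([
+  #       'c++ -c foo.cc -o gen/foo.o',  # hypothetical commands; the real
+  #       'c++ -c bar.cc -o gen/bar.o',  # ones are assembled in SetUp()
+  #   ])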
+ Pool = multiprocessing.Pool(Env.CPUS) + #Pool = multiprocessing.Pool(1) + CWD = '' + TARGET_PREFIX = '//' + TMPDIR = '' + + def __init__(self): + """Brewery is a singleton and should not be instantiated.""" + raise NotImplementedError( + 'Build system error: there shall only be one brewery.') + + @classmethod + def InitBrewery(cls): + """Initializes the brewery, e.g. loads the signatures currently built.""" + try: + os.makedirs(Env.GENDIR) + except OSError as e: + pass + cls.TMPDIR = tempfile.mkdtemp() + if os.path.exists(os.path.join(Env.GENDIR, cls._signature_filename)): + BuildDebug('Loading the signature file.') + cls._signatures = pickle.load( + open(os.path.join(Env.GENDIR, cls._signature_filename))) + cls.FindAndParseBuildFiles() + + @classmethod + def Finalize(cls): + """Finalizes the brew process.""" + if os.path.exists(Env.GENDIR): + BuildDebug('Saving the signature file.') + pickle.dump(cls._signatures, + open(os.path.join(Env.GENDIR, cls._signature_filename), 'w')) + else: + BuildDebug('No gendir present. Exiting.') + shutil.rmtree(cls.TMPDIR) + + @classmethod + def Get(cls, name): + return cls._targets[name] + + @classmethod + def FindAndParseBuildFiles(cls): + """Find and parse all the BREW files in the subfolders.""" + build_files = [os.path.join(d[2:], f) + for (d, _, files) in os.walk('.') if not d.startswith(Env.GENDIR) + for f in files if f.endswith('BREW')] + for build_file in build_files: + # Set the current working directory of the environment, and parse the build + # file. + BuildDebug("Parsing %s" % build_file) + cls.SetCwd(os.path.dirname(build_file)) + execfile(build_file) + cls.SetCwd('') + return + + @classmethod + def SetCwd(cls, cwd): + if cwd and not os.path.isdir(cwd): + # cwd should either be empty, or is a directory. + raise RuntimeError('Setting an invalid cwd: %s' % cwd) + cls.CWD = cwd + cls.TARGET_PREFIX = '//' + cwd + + @classmethod + def RunInParallel(cls, commands): + if any(cls.Pool.map(RunSingleCommand, commands)): + BuildWarning('Command failed.') + return False + else: + return True + + @classmethod + def Register(cls, name, target): + BuildFatalIf(name in cls._targets, + "%s already in build target.", name) + BuildDebug("Registered build target %s, deps %s", name, str(target.deps)) + cls._targets[name] = target + cls._deps_map[name] = target.deps + + @classmethod + def _GetExecutionChain(cls, targets): + """Gets the execution chain.""" + # First, verify all dependencies. + for t in cls._targets: + for d in cls._deps_map[t]: + BuildFatalIf(d not in cls._targets, + "Dependency %s for target %s does not exist.", d, t) + if len(targets) == 0: + targets = cls._targets + else: + # Get all targets that we need to build. + seen_targets = set(targets) + idx = 0 + while idx < len(targets): + for d in cls._deps_map[targets[idx]]: + if d not in seen_targets: + seen_targets.add(d) + targets.append(d) + idx += 1 + # Now, create a topological order. 
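+    # (Editor's note: illustrative walk-through, not part of the original
+    # commit.) This is essentially Kahn's algorithm. With hypothetical
+    # targets where //a:bin depends on //b:lib and //b:lib depends on
+    # //c:proto:
+    #   inverse_deps_map = {'//b:lib': ['//a:bin'], '//c:proto': ['//b:lib']}
+    #   deps_count       = {'//a:bin': 1, '//b:lib': 1, '//c:proto': 0}
+    # so the frontier starts as {'//c:proto'} and the computed build order
+    # is ['//c:proto', '//b:lib', '//a:bin'].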
+ inverse_deps_map = defaultdict(list) + # Get the graph of all targets + for t in targets: + for d in cls._deps_map[t]: + inverse_deps_map[d].append(t) + deps_count = dict((t, len(cls._deps_map[t])) for t in targets) + #BuildDebug("deps count: %s", str(deps_count)) + frontier = set(t for t in deps_count if deps_count[t] == 0) + build_order = [] + while frontier: + current = frontier.pop() + #BuildDebug("processing %s", current) + build_order.append(current) + for t in inverse_deps_map[current]: + deps_count[t] -= 1 + if deps_count[t] == 0: + #BuildDebug('Add to frontier: %s', t) + frontier.add(t) + # If this does not cover all targets, the graph is not a DAG. + BuildFatalIf(len(build_order) != len(targets), + "There are cycles in the dependency graph!") + BuildDebug('Build order: %s', str(build_order)) + return build_order + + @classmethod + def Signature(cls, target): + # Returns the builtsignature of the current target. + return cls._signatures[target] + + @classmethod + def Success(cls, target): + return cls._success[target] + + @classmethod + def ClearSignature(cls, including_third_party=False): + if including_third_party: + cls._signatures = defaultdict(str) + else: + keys = cls._signatures.keys() + for k in keys: + if not k.startswith('//third_party'): + del cls._signatures[k] + + @classmethod + def Build(cls, targets): + """Build all the targets, using their topological order.""" + BuildDebug("Start building.") + build_order = cls._GetExecutionChain(targets) + for t in build_order: + BuildLog("Building %s", t) + cls._success[t], changed, new_signature = ( + cls._targets[t].SetUpAndBuild(cls._signatures[t])) + if cls._success[t]: + cls._signatures[t] = new_signature + # Finally, print a summary of the build results. + succeeded = [key for key in cls._success if cls._success[key]] + BuildDebug("Successfully built %d targets." 
% len(succeeded)) + #for key in cls._success: + # if cls._success[key]: + # BuildDebug(key) + failed = [key for key in cls._success if not cls._success[key]] + if len(failed) > 0: + BuildWarning("Failed to build:") + for key in failed: + BuildWarning(key) + + @classmethod + def Draw(cls): + import pydot + graph = pydot.Dot("brewery", rankdir="LR") + nodes = {} + node_style = {'shape': 'box', 'color': '#0F9D58', 'style': 'filled', + 'fontcolor': '#FFFFFF'} + for target_name in cls._targets: + nodes[target_name] = pydot.Node('"' + target_name + '"', **node_style) + graph.add_node(nodes[target_name]) + for target_name in cls._deps_map: + for dep_name in cls._deps_map[target_name]: + graph.add_edge(pydot.Edge(nodes[dep_name], nodes[target_name])) + graph.write(graph.get_name() + '.dot', format='raw') + with open(graph.get_name() + '.pdf', 'w') as fid: + subprocess.call(['dot', '-Tpdf', graph.get_name() + '.dot'], stdout=fid) + +class BuildTarget(object): + """A build target that can be executed with the Build() function.""" + def __init__(self, name, srcs, other_files=[], deps=[]): + self.name = RectifyTarget(name) + self.srcs = RectifyFileNames(srcs) + self.files = sorted(self.srcs + other_files) + self.deps = sorted(RectifyTargets(deps)) + self.command_groups = [] + Brewery.Register(self.name, self) + + def GetSignature(self): + """Generate the signature of the build object.""" + src_digest = ''.join([hashlib.sha256(open(f, 'rb').read()).hexdigest() + for f in self.files]) + dep_digest = ''.join([Brewery.Signature(d) for d in self.deps]) + return hashlib.sha256(src_digest + dep_digest).hexdigest() + + def SetUpAndBuild(self, built_signature): + self.SetUp() + signature = self.GetSignature() + if not all(Brewery.Success(d) for d in self.deps): + BuildWarning("Not all dependencies have succeeded. Skipping build.") + return False, True, signature + if signature != built_signature: + success = self.Build() + return success, True, signature + return True, False, signature + + def SetUp(self): + """Set up the build object's variables. + + This will always run even if the target has already been built. Anything + that further dependencies will need should be implemented here. + + If your target just emits a set of shell commands, in SetUp() you can set + self.command_groups and use the default Build function, which basically + sends the command groups to a execution pool. + """ + BuildFatal('Not implemented.') + + def Build(self): + """Builds the target.""" + success = True + for command_group in self.command_groups: + success &= Brewery.RunInParallel(command_group) + if not success: + return False + return True + +class proto_library(BuildTarget): + """Builds a protobuffer library. + + A protobuffer library builds a set of protobuffer source files to its cc and + python source files, as well as the static library named "libname.a". + """ + def __init__(self, name, srcs, deps=[]): + BuildTarget.__init__(self, name, srcs, deps=deps) + + def SetUp(self): + MakeGenDirs(self.srcs) + # proto_library depends on protoc, so it would need to add that to the + # includes folder. 
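+    # (Editor's note: illustrative example, not part of the original commit.)
+    # For srcs = ['caffe/proto/caffe.proto'] this SetUp() produces roughly:
+    #   protoc -I. --cpp_out gen --python_out gen caffe/proto/caffe.proto
+    #   c++ <CFLAGS> <INCLUDES> -c gen/caffe/proto/caffe.pb.cc \
+    #       -o gen/caffe/proto/caffe.pb.o
+    # and the resulting .pb.o files are exposed to dependent targets through
+    # self.cc_obj_files.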
+ pbcc_files = [GenFilename(filename, 'pb.cc') for filename in self.srcs] + pbo_files = [GenFilename(filename, 'pb.o') for filename in self.srcs] + proto_commands = [ + ' '.join([Env.PROTOC_BINARY, '-I.', '--cpp_out', Env.GENDIR, + '--python_out', Env.GENDIR, filename]) + for filename in self.srcs] + cpp_commands = [ + ' '.join([Env.CC, Env.CFLAGS, Env.INCLUDES, '-c', pbcc, '-o', pbo]) + for pbcc, pbo in zip(pbcc_files, pbo_files)] + self.cc_obj_files = pbo_files + self.cc_obj_files += MergeOrderedObjs( + [Brewery.Get(dep).cc_obj_files for dep in self.deps]) + self.command_groups = [proto_commands, cpp_commands] + + +class cc_target(BuildTarget): + def __init__(self, name, srcs, hdrs=[], deps=[], cflags=[], external_libs=[], + build_binary=False, is_test=False, whole_archive=False, + shared=False): + self.hdrs = RectifyFileNames(hdrs) + self.cflags = cflags + self.external_libs = [ + '-l' + s if not s.startswith('-') else s for s in external_libs] + self.build_binary = build_binary + self.is_test = is_test + self.whole_archive = whole_archive + self.shared = shared + BuildTarget.__init__(self, name, srcs, self.hdrs, deps=deps) + + def OutputName(self, is_library=False, is_shared=False): + name_split = self.name.split(':') + if is_library: + if is_shared: + return os.path.join( + Env.GENDIR, name_split[0][2:], + 'lib' + name_split[1] + Env.SHARED_LIB_EXT) + else: + return os.path.join( + Env.GENDIR, name_split[0][2:], 'lib' + name_split[1] + '.a') + else: + return os.path.join(Env.GENDIR, name_split[0][2:], name_split[1]) + + def SetUp(self): + MakeGenDirs(self.srcs) + CopyToGenDir(self.hdrs) + obj_files = [GenFilename(src, 'o') for src in self.srcs] + cpp_commands = [ + ' '.join([Env.CC, Env.CFLAGS, Env.INCLUDES, ' '.join(self.cflags), + '-c', src, '-o', obj]) + for src, obj in zip(self.srcs, obj_files)] + archive_file = self.OutputName(is_library=True) + # Create the archive + link_commands = [ + ' '.join([Env.LINK_STATIC, archive_file] + obj_files)] + if self.whole_archive: + archive_file = Env.WHOLE_ARCHIVE_TEMPLATE % archive_file + self.cc_obj_files = MergeOrderedObjs( + [Brewery.Get(dep).cc_obj_files for dep in self.deps] + + [self.external_libs]) + self.cc_obj_files.insert(0, archive_file) + if self.build_binary: + link_binary_commands = [ + ' '.join([Env.LINK_BINARY, self.OutputName()] + self.cc_obj_files + + [Env.LINKFLAGS])] + self.command_groups = [cpp_commands, link_commands, link_binary_commands] + elif self.shared: + link_shared_commands = [' '.join( + [Env.LINK_SHARED, self.OutputName(is_library=True, is_shared=True)] + + obj_files + self.cc_obj_files[1:] + [Env.LINKFLAGS])] + self.command_groups = [cpp_commands, link_commands, link_shared_commands] + else: + self.command_groups = [cpp_commands, link_commands] + if self.is_test: + # Add test command + self.command_groups.append([ + ' '.join([self.OutputName(), '--caffe_test_root', + os.path.abspath(Env.GENDIR), + '--gtest_filter=-*.LARGE_*'])]) + + +def cc_library(*args, **kwargs): + return cc_target(*args, **kwargs) + +def cc_binary(*args, **kwargs): + return cc_target(*args, build_binary=True, **kwargs) + +def cc_test(*args, **kwargs): + if 'cflags' not in kwargs: + kwargs['cflags'] = [] + kwargs['cflags'].append("-DGTEST_USE_OWN_TR1_TUPLE=1") + return cc_target( + *args, build_binary=True, is_test=True, whole_archive=True, **kwargs) + + +class cuda_library(BuildTarget): + def __init__(self, name, srcs, hdrs=[], deps=[], cflags=[], + whole_archive=False): + self.hdrs = RectifyFileNames(hdrs) + self.cflags = cflags + 
self.whole_archive = whole_archive + BuildTarget.__init__(self, name, srcs, self.hdrs, deps=deps) + + def OutputName(self, is_library=False): + name_split = self.name.split(':') + if is_library: + return os.path.join( + Env.GENDIR, name_split[0][2:], 'lib' + name_split[1] + '.a') + else: + return os.path.join(Env.GENDIR, name_split[0][2:], name_split[1]) + + def SetUp(self): + MakeGenDirs(self.srcs) + CopyToGenDir(self.hdrs) + obj_files = [GenFilename(src, 'cuo') for src in self.srcs] + cpp_commands = [ + ' '.join([Env.NVCC, Env.NVCC_CFLAGS, Env.INCLUDES, + ' '.join(self.cflags), '-c', src, '-o', obj]) + for src, obj in zip(self.srcs, obj_files)] + archive_file = self.OutputName(is_library=True) + # Create the archive + link_commands = [ + ' '.join([Env.LINK_STATIC, archive_file] + + obj_files)] + if self.whole_archive: + archive_file = Env.WHOLE_ARCHIVE_TEMPLATE % archive_file + self.cc_obj_files = MergeOrderedObjs( + [Brewery.Get(dep).cc_obj_files for dep in self.deps]) + # We will need to add nvidia link targets as well + self.cc_obj_files.append(Env.NVCC_LINKS) + self.cc_obj_files.insert(0, archive_file) + self.command_groups = [cpp_commands, link_commands] + + +class filegroup(BuildTarget): + def __init__(self, name, srcs, deps=[]): + self.cc_obj_files = [] + BuildTarget.__init__(self, name, srcs, deps=deps) + + def SetUp(self): + CopyToGenDir(self.srcs) + +def py_library(*args, **kwargs): + return filegroup(*args, **kwargs) + +def cc_headers(*args, **kwargs): + return filegroup(*args, **kwargs) + +class py_test(BuildTarget): + def __init__(self, name, srcs, deps=[]): + self.cc_obj_files = [] + BuildTarget.__init__(self, name, srcs, deps=deps) + + def SetUp(self): + CopyToGenDir(self.srcs) + if len(self.srcs) > 1: + raise RuntimeError('py_test should only take one python source file.') + # Add test command + self.command_groups = [ + ['python %s' % GenFilename(self.srcs[0])]] + + +class cc_thirdparty_target(BuildTarget): + """thirdparty_target should only be used in third_party to build things with + a pre-defined script. Note that this will also set the following values: + cc_includes: the include folder needed for compiling dependent targets. + cc_obj_files: the object files produced by the target. + + When building, this script will copy all stuff to a temporary directory, so + that the original source tree is not affected. 
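+
+  (Editor's note: the example below is an added illustration, not part of the
+  original commit.) The commands run in a single shell, joined by ' && ', with
+  SRCDIR, DSTDIR and CPUS already defined, so a third_party BREW entry might
+  pass something like:
+
+      commands = [
+          './configure --prefix=$DSTDIR',   # hypothetical configure flags
+          'make -j $CPUS',
+          'make install',
+      ]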
+ """ + def __init__(self, name, srcs, commands, cc_obj_files, deps=[]): + self.cwd = Brewery.CWD + self.build_dir = os.path.join(Brewery.TMPDIR, Brewery.CWD) + self.commands = [ + 'SRCDIR=%s' % self.build_dir, + 'DSTDIR=%s' % os.path.join(os.path.abspath(Env.GENDIR), "third_party"), + 'CPUS=%d' % Env.CPUS, + 'cd %s' % self.build_dir, + ] + commands + self.cc_obj_files = [ + os.path.join(Env.GENDIR, "third_party", f) + for f in cc_obj_files if not f.startswith('-l')] + [ + f for f in cc_obj_files if f.startswith('-l')] + BuildTarget.__init__(self, name, srcs, deps=deps) + + def SetUp(self): + self.cc_obj_files += MergeOrderedObjs( + [Brewery.Get(dep).cc_obj_files for dep in self.deps]) + + def Build(self): + # First, copy all things to the temp directory + shutil.copytree(self.cwd, self.build_dir) + BuildDebug("script: %s" % str(self.commands)) + + proc = subprocess.Popen(' && '.join(self.commands), stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, shell=True) + stdout, _ = proc.communicate() + if proc.returncode: + BuildWarning("Script failed.") + print stdout + return False + return True + +class shell_script(BuildTarget): + """Shell scripts are directly run to generate data files. It is run from the + root of the gendir. + """ + def __init__(self, name, srcs, commands, deps=[]): + self.cwd = Brewery.CWD + self.commands = [ + 'GENDIR=%s' % os.path.abspath(Env.GENDIR), + 'CWD=%s' % self.cwd, + 'cd %s' % os.path.abspath(Env.GENDIR), + ] + commands + BuildTarget.__init__(self, name, srcs, deps=deps) + + def SetUp(self): + """A shell script should produce no cc_obj_files. This is here just so that + a cc object can use shell_script as a data dependency. + """ + CopyToGenDir(self.srcs) + self.cc_obj_files = [] + + def Build(self): + BuildDebug("script: %s" % str(self.commands)) + proc = subprocess.Popen(' && '.join(self.commands), stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, shell=True) + stdout, _ = proc.communicate() + if proc.returncode: + BuildWarning("Script failed.") + print stdout + return False + return True + +################################################################################ +# Below are functions during the main entry. +################################################################################ + +def main(argv): + """The main entry of the build script.""" + BuildLog('Welcome to Caffe2. Running command: %s' % str(argv)) + Brewery.InitBrewery() + if len(sys.argv) > 1: + if sys.argv[1] == 'clean': + for folder in ['caffe2', 'pycaffe2']: + os.system('rm -rf ' + os.path.join(Env.GENDIR, folder)) + Brewery.ClearSignature() + elif sys.argv[1] == 'reallyclean': + os.system('rm -rf ' + Env.GENDIR) + BuildLog('Finished cleaning.') + elif sys.argv[1] == 'build': + # Build all targets. + targets = sys.argv[2:] + Brewery.Build(targets) + elif sys.argv[1] == 'draw': + # Draws the dependency graph. + Brewery.Draw() + else: + BuildFatal('Unknown command: %s' % sys.argv[1]) + else: + BuildLog('Finished parsing all build files without error.') + Brewery.Finalize() + +if __name__ == "__main__": + main(sys.argv) diff --git a/build_env.py b/build_env.py new file mode 100644 index 00000000000..e997cfac5c5 --- /dev/null +++ b/build_env.py @@ -0,0 +1,156 @@ +""" build_env defines the general environment that we use to build. 
+""" + +import multiprocessing +import os +import subprocess +import sys + +def _GetSubprocessOutput(commands): + try: + proc = subprocess.Popen(commands, stdout=subprocess.PIPE) + out, err = proc.communicate() + except OSError as err: + print 'Cannot run command', commands, '. Return empty output.' + return '' + return out.strip() + +def _GetCompilerType(CC): + # determine compiler type. + _COMPILER_VERSION_STR = _GetSubprocessOutput([CC, '--version']) + if 'clang' in _COMPILER_VERSION_STR: + return 'clang' + elif ('g++' in _COMPILER_VERSION_STR or + 'Free Software Foundation' in _COMPILER_VERSION_STR): + return 'g++' + else: + raise RuntimeError('Cannot determine C++ compiler type.') + + +class Env(object): + """Env is the class that stores all the build variables.""" + # Define the compile binary commands. + CC = 'c++' + MPICC = 'mpic++' + LINK_BINARY = CC + ' -o' + LINK_SHARED = CC + ' -shared -o' + LINK_STATIC = 'ar rcs' + # Protobuf constants + PROTOC_BINARY = "protoc" + + if sys.platform == 'darwin': + # For some reason, python on mac still recognizes the .so extensions... + # So we will use .so here still. + SHARED_LIB_EXT = '.so' + elif sys.platform.startswith('linux'): + SHARED_LIB_EXT = '.so' + else: + raise RuntimeError('Unknown system platform.') + + COMPILER_TYPE = _GetCompilerType(CC) + + #determine mpi include and mpi link flags. + MPI_INCLUDES = _GetSubprocessOutput([MPICC, '--showme:incdirs']).split(' ') + MPI_LIBDIRS = _GetSubprocessOutput([MPICC, '--showme:libdirs']).split(' ') + MPI_LIBS = _GetSubprocessOutput([MPICC, '--showme:libs']).split(' ') + if len(MPI_INCLUDES) == 1 and MPI_INCLUDES[0] == '': + print ('MPI not found, so some libraries and binaries that use MPI will ' + 'not compile correctly. If you would like to use those, you can ' + 'install MPI on your machine. The easiest way to install on ubuntu ' + 'is via apt-get, and on mac via homebrew.') + # Set all values above to empty lists, so at least others will compile. + MPI_INCLUDES = [] + MPI_LIBDIRS = [] + MPI_LIBS = [] + + # Determine the CUDA directory. + if os.path.exists('/usr/local/cuda'): + CUDA_DIR = '/usr/local/cuda' + else: + raise RuntimeError('Cannot find Cuda directory.') + NVCC = os.path.join(CUDA_DIR, 'bin', 'nvcc') + NVCC_INCLUDES = [os.path.join(CUDA_DIR, 'include')] + + # Determine the NVCC link flags. + if COMPILER_TYPE == 'clang': + NVCC_LINKS = ('-rpath %s -L%s' + % (os.path.join(CUDA_DIR, 'lib'), os.path.join(CUDA_DIR, 'lib'))) + elif COMPILER_TYPE == 'g++': + NVCC_LINKS = ('-Wl,-rpath=%s -L%s' + % (os.path.join(CUDA_DIR, 'lib64'), os.path.join(CUDA_DIR, 'lib64'))) + else: + raise RuntimeError('Unknown compiler type to set nvcc link flags.') + NVCC_LINKS += ' -l' + ' -l'.join([ + 'cublas_static', 'curand_static', 'cuda', 'cudart_static', 'culibos']) + if sys.platform.startswith('linux'): + NVCC_LINKS += ' -l' + ' -l'.join(['rt', 'dl']) + + # NVCC C flags. + NVCC_CFLAGS = ' '.join([ + # add cflags here. + '-Xcompiler -fPIC', + '-O2', + '-std=c++11', + '-gencode=arch=compute_30,code=sm_30', + ]) + + # Determine how the compiler deals with whole archives. + if COMPILER_TYPE == 'clang': + WHOLE_ARCHIVE_TEMPLATE = '-Wl,-force_load,%s' + elif COMPILER_TYPE == 'g++': + WHOLE_ARCHIVE_TEMPLATE = '-Wl,--whole-archive %s -Wl,--no-whole-archive' + else: + raise RuntimeError('Unknown compiler type to set whole-archive template.') + + # General cflags that should be added in all cc arguments. + CFLAGS = ' '.join([ + # add cflags here. 
+ '-fPIC', + '-DPIC', + #'-O0', + '-O2', + #'-pg', + '-DNDEBUG', + '-msse', + '-mavx', + '-ffast-math', + '-std=c++11', + '-W', + '-Wall', + '-Wno-unused-parameter', + '-Wno-sign-compare', + #'-Wno-c++11-extensions', + ]) + + GENDIR = 'gen' + # General include folders. + INCLUDES = NVCC_INCLUDES + MPI_INCLUDES + [ + GENDIR, + os.path.join(GENDIR, 'third_party'), + os.path.join(GENDIR, 'third_party/include'), + '/usr/local/include', + ] + INCLUDES = ' '.join(['-I' + s for s in INCLUDES]) + # Python + INCLUDES += ' ' + _GetSubprocessOutput(['python-config', '--includes']) + # General lib folders. + LIBDIRS = MPI_LIBDIRS + [ + '/usr/local/lib', + ] + LIBDIRS = ' '.join(['-L' + s for s in LIBDIRS]) + # General link flags for binary targets + LIBS = [] + LIBS = ' '.join(['-l' + s for s in LIBS]) + LINKFLAGS = ' '.join([ + # Add link flags here + '-pthread', + #'-pg', + ]) + ' ' + LIBDIRS + ' ' + LIBS + PYTHON_LIBS = [_GetSubprocessOutput(['python-config', '--ldflags'])] + + CPUS = multiprocessing.cpu_count() + + def __init__(self): + """ENV is a singleton and should not be instantiated.""" + raise NotImplementedError( + 'Build system error: ENV should not be instantiated.') diff --git a/caffe.cloc b/caffe.cloc new file mode 100644 index 00000000000..a36ab619113 --- /dev/null +++ b/caffe.cloc @@ -0,0 +1,53 @@ +Bourne Shell + filter remove_matches ^\s*# + filter remove_inline #.*$ + extension sh + script_exe sh +C + filter remove_matches ^\s*// + filter call_regexp_common C + filter remove_inline //.*$ + extension c + extension ec + extension pgc +C++ + filter remove_matches ^\s*// + filter remove_inline //.*$ + filter call_regexp_common C + extension C + extension cc + extension cpp + extension cxx + extension pcc +C/C++ Header + filter remove_matches ^\s*// + filter call_regexp_common C + filter remove_inline //.*$ + extension H + extension h + extension hh + extension hpp +CUDA + filter remove_matches ^\s*// + filter remove_inline //.*$ + filter call_regexp_common C + extension cu +Python + filter remove_matches ^\s*# + filter docstring_to_C + filter call_regexp_common C + filter remove_inline #.*$ + extension py +make + filter remove_matches ^\s*# + filter remove_inline #.*$ + extension Gnumakefile + extension Makefile + extension am + extension gnumakefile + extension makefile + filename Gnumakefile + filename Makefile + filename gnumakefile + filename makefile + script_exe make diff --git a/caffe/BREW b/caffe/BREW new file mode 100644 index 00000000000..9a7cd79aa09 --- /dev/null +++ b/caffe/BREW @@ -0,0 +1,4 @@ +filegroup( + name = "caffe_python", + srcs = ["__init__.py"], +) \ No newline at end of file diff --git a/caffe/__init__.py b/caffe/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/caffe/proto/BREW b/caffe/proto/BREW new file mode 100644 index 00000000000..eac4e2faf77 --- /dev/null +++ b/caffe/proto/BREW @@ -0,0 +1,17 @@ +# Build file for the old caffe protocol buffers. 
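+# (Editor's note: the lines below are an added aside, not part of the original
+# commit.) Building this target, e.g. with
+#   python brewery.py build //caffe/proto:caffe_proto
+# runs protoc over caffe.proto and leaves caffe.pb.h, caffe.pb.cc, caffe.pb.o
+# and the generated caffe_pb2.py under gen/caffe/proto/.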
+ +proto_library( + name = 'caffe_proto', + srcs = ['caffe.proto'], + deps = [ + "//third_party/google:protobuf", + ] +) + +filegroup( + name = "caffe_proto_py", + srcs = ["__init__.py"], + deps = [ + "//caffe:caffe_python", + ] +) \ No newline at end of file diff --git a/caffe/proto/__init__.py b/caffe/proto/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/caffe/proto/caffe.proto b/caffe/proto/caffe.proto new file mode 100644 index 00000000000..5b21cf20028 --- /dev/null +++ b/caffe/proto/caffe.proto @@ -0,0 +1,967 @@ +syntax = "proto2"; + +package caffe; + +// Specifies the shape (dimensions) of a Blob. +message BlobShape { + repeated int64 dim = 1 [packed = true]; +} + +message BlobProto { + optional BlobShape shape = 7; + repeated float data = 5 [packed = true]; + repeated float diff = 6 [packed = true]; + + // 4D dimensions -- deprecated. Use "shape" instead. + optional int32 num = 1 [default = 0]; + optional int32 channels = 2 [default = 0]; + optional int32 height = 3 [default = 0]; + optional int32 width = 4 [default = 0]; +} + +// The BlobProtoVector is simply a way to pass multiple blobproto instances +// around. +message BlobProtoVector { + repeated BlobProto blobs = 1; +} + +message Datum { + optional int32 channels = 1; + optional int32 height = 2; + optional int32 width = 3; + // the actual image data, in bytes + optional bytes data = 4; + optional int32 label = 5; + // Optionally, the datum could also hold float data. + repeated float float_data = 6; + // If true data contains an encoded image that need to be decoded + optional bool encoded = 7 [default = false]; +} + +message FillerParameter { + // The filler type. + optional string type = 1 [default = 'constant']; + optional float value = 2 [default = 0]; // the value in constant filler + optional float min = 3 [default = 0]; // the min value in uniform filler + optional float max = 4 [default = 1]; // the max value in uniform filler + optional float mean = 5 [default = 0]; // the mean value in Gaussian filler + optional float std = 6 [default = 1]; // the std value in Gaussian filler + // The expected number of non-zero output weights for a given input in + // Gaussian filler -- the default -1 means don't perform sparsification. + optional int32 sparse = 7 [default = -1]; +} + +message NetParameter { + optional string name = 1; // consider giving the network a name + // The input blobs to the network. + repeated string input = 3; + // The shape of the input blobs. + repeated BlobShape input_shape = 8; + + // 4D input dimensions -- deprecated. Use "shape" instead. + // If specified, for each input blob there should be four + // values specifying the num, channels, height and width of the input blob. + // Thus, there should be a total of (4 * #input) numbers. + repeated int32 input_dim = 4; + + // Whether the network will force every layer to carry out backward operation. + // If set False, then whether to carry out backward is determined + // automatically according to the net structure and learning rates. + optional bool force_backward = 5 [default = false]; + // The current "state" of the network, including the phase, level, and stage. + // Some layers may be included/excluded depending on this state and the states + // specified in the layers' include and exclude fields. + optional NetState state = 6; + + // Print debugging information about results while running Net::Forward, + // Net::Backward, and Net::Update. 
+ optional bool debug_info = 7 [default = false]; + + // The layers that make up the net. Each of their configurations, including + // connectivity and behavior, is specified as a LayerParameter. + repeated LayerParameter layer = 100; // ID 100 so layers are printed last. + + // DEPRECATED: use 'layer' instead. + repeated V1LayerParameter layers = 2; +} + +// NOTE +// Update the next available ID when you add a new SolverParameter field. +// +// SolverParameter next available ID: 36 (last added: clip_gradients) +message SolverParameter { + ////////////////////////////////////////////////////////////////////////////// + // Specifying the train and test networks + // + // Exactly one train net must be specified using one of the following fields: + // train_net_param, train_net, net_param, net + // One or more test nets may be specified using any of the following fields: + // test_net_param, test_net, net_param, net + // If more than one test net field is specified (e.g., both net and + // test_net are specified), they will be evaluated in the field order given + // above: (1) test_net_param, (2) test_net, (3) net_param/net. + // A test_iter must be specified for each test_net. + // A test_level and/or a test_stage may also be specified for each test_net. + ////////////////////////////////////////////////////////////////////////////// + + // Proto filename for the train net, possibly combined with one or more + // test nets. + optional string net = 24; + // Inline train net param, possibly combined with one or more test nets. + optional NetParameter net_param = 25; + + optional string train_net = 1; // Proto filename for the train net. + repeated string test_net = 2; // Proto filenames for the test nets. + optional NetParameter train_net_param = 21; // Inline train net params. + repeated NetParameter test_net_param = 22; // Inline test net params. + + // The states for the train/test nets. Must be unspecified or + // specified once per net. + // + // By default, all states will have solver = true; + // train_state will have phase = TRAIN, + // and all test_state's will have phase = TEST. + // Other defaults are set according to the NetState defaults. + optional NetState train_state = 26; + repeated NetState test_state = 27; + + // The number of iterations for each test net. + repeated int32 test_iter = 3; + + // The number of iterations between two testing phases. + optional int32 test_interval = 4 [default = 0]; + optional bool test_compute_loss = 19 [default = false]; + // If true, run an initial test pass before the first iteration, + // ensuring memory availability and printing the starting value of the loss. + optional bool test_initialization = 32 [default = true]; + optional float base_lr = 5; // The base learning rate + // the number of iterations between displaying info. If display = 0, no info + // will be displayed. + optional int32 display = 6; + // Display the loss averaged over the last average_loss iterations + optional int32 average_loss = 33 [default = 1]; + optional int32 max_iter = 7; // the maximum number of iterations + optional string lr_policy = 8; // The learning rate decay policy. + optional float gamma = 9; // The parameter to compute the learning rate. + optional float power = 10; // The parameter to compute the learning rate. + optional float momentum = 11; // The momentum value. + optional float weight_decay = 12; // The weight decay. 
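+  // (Editor's note: illustrative example, not part of the original file.)
+  // A minimal solver definition written against this message might look like:
+  //   net: "train_val.prototxt"   # hypothetical file name
+  //   base_lr: 0.01
+  //   lr_policy: "step"
+  //   gamma: 0.1
+  //   stepsize: 10000
+  //   momentum: 0.9
+  //   weight_decay: 0.0005
+  //   max_iter: 45000
+  //   snapshot: 5000
+  //   snapshot_prefix: "mymodel"  # hypothetical prefix
+  //   solver_mode: GPU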
+ // regularization types supported: L1 and L2 + // controlled by weight_decay + optional string regularization_type = 29 [default = "L2"]; + // the stepsize for learning rate policy "step" + optional int32 stepsize = 13; + // the stepsize for learning rate policy "multistep" + repeated int32 stepvalue = 34; + + // Set clip_gradients to >= 0 to clip parameter gradients to that L2 norm, + // whenever their actual L2 norm is larger. + optional float clip_gradients = 35 [default = -1]; + + optional int32 snapshot = 14 [default = 0]; // The snapshot interval + optional string snapshot_prefix = 15; // The prefix for the snapshot. + // whether to snapshot diff in the results or not. Snapshotting diff will help + // debugging but the final protocol buffer size will be much larger. + optional bool snapshot_diff = 16 [default = false]; + // the mode solver will use: 0 for CPU and 1 for GPU. Use GPU in default. + enum SolverMode { + CPU = 0; + GPU = 1; + } + optional SolverMode solver_mode = 17 [default = GPU]; + // the device_id will that be used in GPU mode. Use device_id = 0 in default. + optional int32 device_id = 18 [default = 0]; + // If non-negative, the seed with which the Solver will initialize the Caffe + // random number generator -- useful for reproducible results. Otherwise, + // (and by default) initialize using a seed derived from the system clock. + optional int64 random_seed = 20 [default = -1]; + + // Solver type + enum SolverType { + SGD = 0; + NESTEROV = 1; + ADAGRAD = 2; + } + optional SolverType solver_type = 30 [default = SGD]; + // numerical stability for AdaGrad + optional float delta = 31 [default = 1e-8]; + + // If true, print information about the state of the net that may help with + // debugging learning problems. + optional bool debug_info = 23 [default = false]; + + // If false, don't save a snapshot after training finishes. + optional bool snapshot_after_train = 28 [default = true]; +} + +// A message that stores the solver snapshots +message SolverState { + optional int32 iter = 1; // The current iteration + optional string learned_net = 2; // The file that stores the learned net. + repeated BlobProto history = 3; // The history for sgd solvers + optional int32 current_step = 4 [default = 0]; // The current step for learning rate +} + +enum Phase { + TRAIN = 0; + TEST = 1; +} + +message NetState { + optional Phase phase = 1 [default = TEST]; + optional int32 level = 2 [default = 0]; + repeated string stage = 3; +} + +message NetStateRule { + // Set phase to require the NetState have a particular phase (TRAIN or TEST) + // to meet this rule. + optional Phase phase = 1; + + // Set the minimum and/or maximum levels in which the layer should be used. + // Leave undefined to meet the rule regardless of level. + optional int32 min_level = 2; + optional int32 max_level = 3; + + // Customizable sets of stages to include or exclude. + // The net must have ALL of the specified stages and NONE of the specified + // "not_stage"s to meet the rule. + // (Use multiple NetStateRules to specify conjunctions of stages.) + repeated string stage = 4; + repeated string not_stage = 5; +} + +// Specifies training parameters (multipliers on global learning constants, +// and the name and other settings used for weight sharing). +message ParamSpec { + // The names of the parameter blobs -- useful for sharing parameters among + // layers, but never required otherwise. To share a parameter between two + // layers, give it a (non-empty) name. 
+ optional string name = 1; + + // Whether to require shared weights to have the same shape, or just the same + // count -- defaults to STRICT if unspecified. + optional DimCheckMode share_mode = 2; + enum DimCheckMode { + // STRICT (default) requires that num, channels, height, width each match. + STRICT = 0; + // PERMISSIVE requires only the count (num*channels*height*width) to match. + PERMISSIVE = 1; + } + + // The multiplier on the global learning rate for this parameter. + optional float lr_mult = 3 [default = 1.0]; + + // The multiplier on the global weight decay for this parameter. + optional float decay_mult = 4 [default = 1.0]; +} + +// NOTE +// Update the next available ID when you add a new LayerParameter field. +// +// LayerParameter next available layer-specific ID: 132 (last added: prelu_param) +message LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the layer type + repeated string bottom = 3; // the name of each bottom blob + repeated string top = 4; // the name of each top blob + + // The train / test phase for computation. + optional Phase phase = 10; + + // The amount of weight to assign each top blob in the objective. + // Each layer assigns a default value, usually of either 0 or 1, + // to each top blob. + repeated float loss_weight = 5; + + // Specifies training parameters (multipliers on global learning constants, + // and the name and other settings used for weight sharing). + repeated ParamSpec param = 6; + + // The blobs containing the numeric parameters of the layer. + repeated BlobProto blobs = 7; + + // Rules controlling whether and when a layer is included in the network, + // based on the current NetState. You may specify a non-zero number of rules + // to include OR exclude, but not both. If no include or exclude rules are + // specified, the layer is always included. If the current NetState meets + // ANY (i.e., one or more) of the specified rules, the layer is + // included/excluded. + repeated NetStateRule include = 8; + repeated NetStateRule exclude = 9; + + // Parameters for data pre-processing. + optional TransformationParameter transform_param = 100; + + // Parameters shared by loss layers. + optional LossParameter loss_param = 101; + + // Layer type-specific parameters. + // + // Note: certain layers may have more than one computational engine + // for their implementation. These layers include an Engine type and + // engine parameter for selecting the implementation. + // The default for the engine is set by the ENGINE switch at compile-time. 
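+  // (Editor's note: illustrative example, not part of the original file.)
+  // A convolution layer written against this message might look like:
+  //   layer {
+  //     name: "conv1"
+  //     type: "Convolution"
+  //     bottom: "data"
+  //     top: "conv1"
+  //     param { lr_mult: 1 }
+  //     param { lr_mult: 2 }
+  //     convolution_param {
+  //       num_output: 20
+  //       kernel_size: 5
+  //       stride: 1
+  //       weight_filler { type: "gaussian" std: 0.01 }
+  //     }
+  //   }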
+ optional AccuracyParameter accuracy_param = 102; + optional ArgMaxParameter argmax_param = 103; + optional ConcatParameter concat_param = 104; + optional ContrastiveLossParameter contrastive_loss_param = 105; + optional ConvolutionParameter convolution_param = 106; + optional DataParameter data_param = 107; + optional DropoutParameter dropout_param = 108; + optional DummyDataParameter dummy_data_param = 109; + optional EltwiseParameter eltwise_param = 110; + optional ExpParameter exp_param = 111; + optional HDF5DataParameter hdf5_data_param = 112; + optional HDF5OutputParameter hdf5_output_param = 113; + optional HingeLossParameter hinge_loss_param = 114; + optional ImageDataParameter image_data_param = 115; + optional InfogainLossParameter infogain_loss_param = 116; + optional InnerProductParameter inner_product_param = 117; + optional LRNParameter lrn_param = 118; + optional MemoryDataParameter memory_data_param = 119; + optional MVNParameter mvn_param = 120; + optional PoolingParameter pooling_param = 121; + optional PowerParameter power_param = 122; + optional PReLUParameter prelu_param = 131; + optional PythonParameter python_param = 130; + optional ReLUParameter relu_param = 123; + optional SigmoidParameter sigmoid_param = 124; + optional SoftmaxParameter softmax_param = 125; + optional SliceParameter slice_param = 126; + optional TanHParameter tanh_param = 127; + optional ThresholdParameter threshold_param = 128; + optional WindowDataParameter window_data_param = 129; +} + +// Message that stores parameters used to apply transformation +// to the data layer's data +message TransformationParameter { + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 1 [default = 1]; + // Specify if we want to randomly mirror data. + optional bool mirror = 2 [default = false]; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 3 [default = 0]; + // mean_file and mean_value cannot be specified at the same time + optional string mean_file = 4; + // if specified can be repeated once (would substract it from all the channels) + // or can be repeated the same number of times as channels + // (would subtract them from the corresponding channel) + repeated float mean_value = 5; +} + +// Message that stores parameters shared by loss layers +message LossParameter { + // If specified, ignore instances with the given label. + optional int32 ignore_label = 1; + // If true, normalize each batch across all instances (including spatial + // dimesions, but not ignored instances); else, divide by batch size only. + optional bool normalize = 2 [default = true]; +} + +// Message that stores parameters used by AccuracyLayer +message AccuracyParameter { + // When computing accuracy, count as correct by comparing the true label to + // the top k scoring classes. By default, only compare to the top scoring + // class (i.e. argmax). + optional uint32 top_k = 1 [default = 1]; + + // The "label" axis of the prediction blob, whose argmax corresponds to the + // predicted label -- may be negative to index from the end (e.g., -1 for the + // last axis). For example, if axis == 1 and the predictions are + // (N x C x H x W), the label blob is expected to contain N*H*W ground truth + // labels with integer values in {0, 1, ..., C-1}. + optional int32 axis = 2 [default = 1]; + + // If specified, ignore instances with the given label. 
+ optional int32 ignore_label = 3; +} + +// Message that stores parameters used by ArgMaxLayer +message ArgMaxParameter { + // If true produce pairs (argmax, maxval) + optional bool out_max_val = 1 [default = false]; + optional uint32 top_k = 2 [default = 1]; +} + +// Message that stores parameters used by ConcatLayer +message ConcatParameter { + // The axis along which to concatenate -- may be negative to index from the + // end (e.g., -1 for the last axis). Other axes must have the + // same dimension for all the bottom blobs. + // By default, ConcatLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 2 [default = 1]; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 concat_dim = 1 [default = 1]; +} + +// Message that stores parameters used by ContrastiveLossLayer +message ContrastiveLossParameter { + //margin for dissimilar pair + optional float margin = 1 [default = 1.0]; +} + +// Message that stores parameters used by ConvolutionLayer +message ConvolutionParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. + optional uint32 pad = 3 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 4; // The kernel size (square) + optional uint32 kernel_h = 11; // The kernel height + optional uint32 kernel_w = 12; // The kernel width + optional uint32 group = 5 [default = 1]; // The group size for group conv + optional uint32 stride = 6 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 13; // The stride height + optional uint32 stride_w = 14; // The stride width + optional FillerParameter weight_filler = 7; // The filler for the weight + optional FillerParameter bias_filler = 8; // The filler for the bias + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 15 [default = DEFAULT]; +} + +// Message that stores parameters used by DataLayer +message DataParameter { + enum DB { + LEVELDB = 0; + LMDB = 1; + } + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. + optional uint32 rand_skip = 7 [default = 0]; + optional DB backend = 8 [default = LEVELDB]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. 
+ optional bool mirror = 6 [default = false]; + // Force the encoded image to have 3 color channels + optional bool force_encoded_color = 9 [default = false]; +} + +// Message that stores parameters used by DropoutLayer +message DropoutParameter { + optional float dropout_ratio = 1 [default = 0.5]; // dropout ratio +} + +// Message that stores parameters used by DummyDataLayer. +// DummyDataLayer fills any number of arbitrarily shaped blobs with random +// (or constant) data generated by "Fillers" (see "message FillerParameter"). +message DummyDataParameter { + // This layer produces N >= 1 top blobs. DummyDataParameter must specify 1 or N + // shape fields, and 0, 1 or N data_fillers. + // + // If 0 data_fillers are specified, ConstantFiller with a value of 0 is used. + // If 1 data_filler is specified, it is applied to all top blobs. If N are + // specified, the ith is applied to the ith top blob. + repeated FillerParameter data_filler = 1; + repeated BlobShape shape = 6; + + // 4D dimensions -- deprecated. Use "shape" instead. + repeated uint32 num = 2; + repeated uint32 channels = 3; + repeated uint32 height = 4; + repeated uint32 width = 5; +} + +// Message that stores parameters used by EltwiseLayer +message EltwiseParameter { + enum EltwiseOp { + PROD = 0; + SUM = 1; + MAX = 2; + } + optional EltwiseOp operation = 1 [default = SUM]; // element-wise operation + repeated float coeff = 2; // blob-wise coefficient for SUM operation + + // Whether to use an asymptotically slower (for >2 inputs) but stabler method + // of computing the gradient for the PROD operation. (No effect for SUM op.) + optional bool stable_prod_grad = 3 [default = true]; +} + +// Message that stores parameters used by ExpLayer +message ExpParameter { + // ExpLayer computes outputs y = base ^ (shift + scale * x), for base > 0. + // Or if base is set to the default (-1), base is set to e, + // so y = exp(shift + scale * x). + optional float base = 1 [default = -1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by HDF5DataLayer +message HDF5DataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 2; + + // Specify whether to shuffle the data. + // If shuffle == true, the ordering of the HDF5 files is shuffled, + // and the ordering of data within any given HDF5 file is shuffled, + // but data between different files are not interleaved; all of a file's + // data are output (in a random order) before moving onto another file. + optional bool shuffle = 3 [default = false]; +} + +// Message that stores parameters used by HDF5OutputLayer +message HDF5OutputParameter { + optional string file_name = 1; +} + +message HingeLossParameter { + enum Norm { + L1 = 1; + L2 = 2; + } + // Specify the Norm to use L1 or L2 + optional Norm norm = 1 [default = L1]; +} + +// Message that stores parameters used by ImageDataLayer +message ImageDataParameter { + // Specify the data source. + optional string source = 1; + // Specify the batch size. + optional uint32 batch_size = 4; + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. 
+ optional uint32 rand_skip = 7 [default = 0]; + // Whether or not ImageLayer should shuffle the list of files at every epoch. + optional bool shuffle = 8 [default = false]; + // It will also resize images if new_height or new_width are not zero. + optional uint32 new_height = 9 [default = 0]; + optional uint32 new_width = 10 [default = 0]; + // Specify if the images are color or gray + optional bool is_color = 11 [default = true]; + // DEPRECATED. See TransformationParameter. For data pre-processing, we can do + // simple scaling and subtracting the data mean, if provided. Note that the + // mean subtraction is always carried out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // DEPRECATED. See TransformationParameter. Specify if we would like to randomly + // crop an image. + optional uint32 crop_size = 5 [default = 0]; + // DEPRECATED. See TransformationParameter. Specify if we want to randomly mirror + // data. + optional bool mirror = 6 [default = false]; + optional string root_folder = 12 [default = ""]; +} + +// Message that stores parameters InfogainLossLayer +message InfogainLossParameter { + // Specify the infogain matrix source. + optional string source = 1; +} + +// Message that stores parameters used by InnerProductLayer +message InnerProductParameter { + optional uint32 num_output = 1; // The number of outputs for the layer + optional bool bias_term = 2 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 3; // The filler for the weight + optional FillerParameter bias_filler = 4; // The filler for the bias + + // The first axis to be lumped into a single inner product computation; + // all preceding axes are retained in the output. + // May be negative to index from the end (e.g., -1 for the last axis). + optional int32 axis = 5 [default = 1]; +} + +// Message that stores parameters used by LRNLayer +message LRNParameter { + optional uint32 local_size = 1 [default = 5]; + optional float alpha = 2 [default = 1.]; + optional float beta = 3 [default = 0.75]; + enum NormRegion { + ACROSS_CHANNELS = 0; + WITHIN_CHANNEL = 1; + } + optional NormRegion norm_region = 4 [default = ACROSS_CHANNELS]; + optional float k = 5 [default = 1.]; +} + +// Message that stores parameters used by MemoryDataLayer +message MemoryDataParameter { + optional uint32 batch_size = 1; + optional uint32 channels = 2; + optional uint32 height = 3; + optional uint32 width = 4; +} + +// Message that stores parameters used by MVNLayer +message MVNParameter { + // This parameter can be set to false to normalize mean only + optional bool normalize_variance = 1 [default = true]; + + // This parameter can be set to true to perform DNN-like MVN + optional bool across_channels = 2 [default = false]; +} + +// Message that stores parameters used by PoolingLayer +message PoolingParameter { + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 1 [default = MAX]; // The pooling method + // Pad, kernel size, and stride are all given as a single value for equal + // dimensions in height and width or as Y, X pairs. 
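+  // Rough worked example (assuming the usual Caffe pooling arithmetic, which
+  // rounds up): a 32x32 input with kernel_size = 3, stride = 2 and pad = 0
+  // gives a ceil((32 + 2*0 - 3) / 2) + 1 = 16 x 16 pooled output.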
+ optional uint32 pad = 4 [default = 0]; // The padding size (equal in Y, X) + optional uint32 pad_h = 9 [default = 0]; // The padding height + optional uint32 pad_w = 10 [default = 0]; // The padding width + optional uint32 kernel_size = 2; // The kernel size (square) + optional uint32 kernel_h = 5; // The kernel height + optional uint32 kernel_w = 6; // The kernel width + optional uint32 stride = 3 [default = 1]; // The stride (equal in Y, X) + optional uint32 stride_h = 7; // The stride height + optional uint32 stride_w = 8; // The stride width + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 11 [default = DEFAULT]; + // If global_pooling then it will pool over the size of the bottom by doing + // kernel_h = bottom->height and kernel_w = bottom->width + optional bool global_pooling = 12 [default = false]; +} + +// Message that stores parameters used by PowerLayer +message PowerParameter { + // PowerLayer computes outputs y = (shift + scale * x) ^ power. + optional float power = 1 [default = 1.0]; + optional float scale = 2 [default = 1.0]; + optional float shift = 3 [default = 0.0]; +} + +// Message that stores parameters used by PythonLayer +message PythonParameter { + optional string module = 1; + optional string layer = 2; +} + +// Message that stores parameters used by ReLULayer +message ReLUParameter { + // Allow non-zero slope for negative inputs to speed up optimization + // Described in: + // Maas, A. L., Hannun, A. Y., & Ng, A. Y. (2013). Rectifier nonlinearities + // improve neural network acoustic models. In ICML Workshop on Deep Learning + // for Audio, Speech, and Language Processing. + optional float negative_slope = 1 [default = 0]; + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 2 [default = DEFAULT]; +} + +// Message that stores parameters used by SigmoidLayer +message SigmoidParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by SliceLayer +message SliceParameter { + // The axis along which to slice -- may be negative to index from the end + // (e.g., -1 for the last axis). + // By default, SliceLayer concatenates blobs along the "channels" axis (1). + optional int32 axis = 3 [default = 1]; + repeated uint32 slice_point = 2; + + // DEPRECATED: alias for "axis" -- does not support negative indexing. + optional uint32 slice_dim = 1 [default = 1]; +} + +// Message that stores parameters used by SoftmaxLayer, SoftmaxWithLossLayer +message SoftmaxParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; + + // The axis along which to perform the softmax -- may be negative to index + // from the end (e.g., -1 for the last axis). + // Any other axes will be evaluated as independent softmaxes. + optional int32 axis = 2 [default = 1]; +} + +// Message that stores parameters used by TanHLayer +message TanHParameter { + enum Engine { + DEFAULT = 0; + CAFFE = 1; + CUDNN = 2; + } + optional Engine engine = 1 [default = DEFAULT]; +} + +// Message that stores parameters used by ThresholdLayer +message ThresholdParameter { + optional float threshold = 1 [default = 0]; // Strictly positive values +} + +// Message that stores parameters used by WindowDataLayer +message WindowDataParameter { + // Specify the data source. 
+ optional string source = 1; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 2 [default = 1]; + optional string mean_file = 3; + // Specify the batch size. + optional uint32 batch_size = 4; + // Specify if we would like to randomly crop an image. + optional uint32 crop_size = 5 [default = 0]; + // Specify if we want to randomly mirror data. + optional bool mirror = 6 [default = false]; + // Foreground (object) overlap threshold + optional float fg_threshold = 7 [default = 0.5]; + // Background (non-object) overlap threshold + optional float bg_threshold = 8 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float fg_fraction = 9 [default = 0.25]; + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 context_pad = 10 [default = 0]; + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string crop_mode = 11 [default = "warp"]; + // cache_images: will load all images in memory for faster access + optional bool cache_images = 12 [default = false]; + // append root_folder to locate images + optional string root_folder = 13 [default = ""]; +} + +// DEPRECATED: use LayerParameter. +message V1LayerParameter { + repeated string bottom = 2; + repeated string top = 3; + optional string name = 4; + repeated NetStateRule include = 32; + repeated NetStateRule exclude = 33; + enum LayerType { + NONE = 0; + ABSVAL = 35; + ACCURACY = 1; + ARGMAX = 30; + BNLL = 2; + CONCAT = 3; + CONTRASTIVE_LOSS = 37; + CONVOLUTION = 4; + DATA = 5; + DECONVOLUTION = 39; + DROPOUT = 6; + DUMMY_DATA = 32; + EUCLIDEAN_LOSS = 7; + ELTWISE = 25; + EXP = 38; + FLATTEN = 8; + HDF5_DATA = 9; + HDF5_OUTPUT = 10; + HINGE_LOSS = 28; + IM2COL = 11; + IMAGE_DATA = 12; + INFOGAIN_LOSS = 13; + INNER_PRODUCT = 14; + LRN = 15; + MEMORY_DATA = 29; + MULTINOMIAL_LOGISTIC_LOSS = 16; + MVN = 34; + POOLING = 17; + POWER = 26; + RELU = 18; + SIGMOID = 19; + SIGMOID_CROSS_ENTROPY_LOSS = 27; + SILENCE = 36; + SOFTMAX = 20; + SOFTMAX_LOSS = 21; + SPLIT = 22; + SLICE = 33; + TANH = 23; + WINDOW_DATA = 24; + THRESHOLD = 31; + } + optional LayerType type = 5; + repeated BlobProto blobs = 6; + repeated string param = 1001; + repeated DimCheckMode blob_share_mode = 1002; + enum DimCheckMode { + STRICT = 0; + PERMISSIVE = 1; + } + repeated float blobs_lr = 7; + repeated float weight_decay = 8; + repeated float loss_weight = 35; + optional AccuracyParameter accuracy_param = 27; + optional ArgMaxParameter argmax_param = 23; + optional ConcatParameter concat_param = 9; + optional ContrastiveLossParameter contrastive_loss_param = 40; + optional ConvolutionParameter convolution_param = 10; + optional DataParameter data_param = 11; + optional DropoutParameter dropout_param = 12; + optional DummyDataParameter dummy_data_param = 26; + optional EltwiseParameter eltwise_param = 24; + optional ExpParameter exp_param = 41; + optional HDF5DataParameter hdf5_data_param = 13; + optional HDF5OutputParameter hdf5_output_param = 14; + optional HingeLossParameter hinge_loss_param = 29; + optional ImageDataParameter image_data_param = 15; + optional InfogainLossParameter infogain_loss_param = 16; + optional InnerProductParameter inner_product_param = 17; + optional LRNParameter lrn_param = 18; 
+ optional MemoryDataParameter memory_data_param = 22; + optional MVNParameter mvn_param = 34; + optional PoolingParameter pooling_param = 19; + optional PowerParameter power_param = 21; + optional ReLUParameter relu_param = 30; + optional SigmoidParameter sigmoid_param = 38; + optional SoftmaxParameter softmax_param = 39; + optional SliceParameter slice_param = 31; + optional TanHParameter tanh_param = 37; + optional ThresholdParameter threshold_param = 25; + optional WindowDataParameter window_data_param = 20; + optional TransformationParameter transform_param = 36; + optional LossParameter loss_param = 42; + optional V0LayerParameter layer = 1; +} + +// DEPRECATED: V0LayerParameter is the old way of specifying layer parameters +// in Caffe. We keep this message type around for legacy support. +message V0LayerParameter { + optional string name = 1; // the layer name + optional string type = 2; // the string to specify the layer type + + // Parameters to specify layers with inner products. + optional uint32 num_output = 3; // The number of outputs for the layer + optional bool biasterm = 4 [default = true]; // whether to have bias terms + optional FillerParameter weight_filler = 5; // The filler for the weight + optional FillerParameter bias_filler = 6; // The filler for the bias + + optional uint32 pad = 7 [default = 0]; // The padding size + optional uint32 kernelsize = 8; // The kernel size + optional uint32 group = 9 [default = 1]; // The group size for group conv + optional uint32 stride = 10 [default = 1]; // The stride + enum PoolMethod { + MAX = 0; + AVE = 1; + STOCHASTIC = 2; + } + optional PoolMethod pool = 11 [default = MAX]; // The pooling method + optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio + + optional uint32 local_size = 13 [default = 5]; // for local response norm + optional float alpha = 14 [default = 1.]; // for local response norm + optional float beta = 15 [default = 0.75]; // for local response norm + optional float k = 22 [default = 1.]; + + // For data layers, specify the data source + optional string source = 16; + // For data pre-processing, we can do simple scaling and subtracting the + // data mean, if provided. Note that the mean subtraction is always carried + // out before scaling. + optional float scale = 17 [default = 1]; + optional string meanfile = 18; + // For data layers, specify the batch size. + optional uint32 batchsize = 19; + // For data layers, specify if we would like to randomly crop an image. + optional uint32 cropsize = 20 [default = 0]; + // For data layers, specify if we want to randomly mirror data. + optional bool mirror = 21 [default = false]; + + // The blobs containing the numeric parameters of the layer + repeated BlobProto blobs = 50; + // The ratio that is multiplied on the global learning rate. If you want to + // set the learning ratio for one blob, you need to set it for all blobs. + repeated float blobs_lr = 51; + // The weight decay that is multiplied on the global weight decay. + repeated float weight_decay = 52; + + // The rand_skip variable is for the data layer to skip a few data points + // to avoid all asynchronous sgd clients to start at the same point. The skip + // point would be set as rand_skip * rand(0,1). Note that rand_skip should not + // be larger than the number of keys in the database. 
+ optional uint32 rand_skip = 53 [default = 0]; + + // Fields related to detection (det_*) + // foreground (object) overlap threshold + optional float det_fg_threshold = 54 [default = 0.5]; + // background (non-object) overlap threshold + optional float det_bg_threshold = 55 [default = 0.5]; + // Fraction of batch that should be foreground objects + optional float det_fg_fraction = 56 [default = 0.25]; + + // optional bool OBSOLETE_can_clobber = 57 [default = true]; + + // Amount of contextual padding to add around a window + // (used only by the window_data_layer) + optional uint32 det_context_pad = 58 [default = 0]; + + // Mode for cropping out a detection window + // warp: cropped window is warped to a fixed size and aspect ratio + // square: the tightest square around the window is cropped + optional string det_crop_mode = 59 [default = "warp"]; + + // For ReshapeLayer, one needs to specify the new dimensions. + optional int32 new_num = 60 [default = 0]; + optional int32 new_channels = 61 [default = 0]; + optional int32 new_height = 62 [default = 0]; + optional int32 new_width = 63 [default = 0]; + + // Whether or not ImageLayer should shuffle the list of files at every epoch. + // It will also resize images if new_height or new_width are not zero. + optional bool shuffle_images = 64 [default = false]; + + // For ConcatLayer, one needs to specify the dimension for concatenation, and + // the other dimensions must be the same for all the bottom blobs. + // By default it will concatenate blobs along the channels dimension. + optional uint32 concat_dim = 65 [default = 1]; + + optional HDF5OutputParameter hdf5_output_param = 1001; +} + +// Message that stores parameters used by PReLULayer +message PReLUParameter { + // Parametric ReLU described in K. He et al, Delving Deep into Rectifiers: + // Surpassing Human-Level Performance on ImageNet Classification, 2015. + + // Initial value of a_i. Default is a_i=0.25 for all i. + optional FillerParameter filler = 1; + // Whether or not slope paramters are shared across channels. + optional bool channel_shared = 2 [default = false]; +} diff --git a/caffe2/BREW b/caffe2/BREW new file mode 100644 index 00000000000..10323bff909 --- /dev/null +++ b/caffe2/BREW @@ -0,0 +1,4 @@ +filegroup( + name = "caffe2_python", + srcs = ["__init__.py"], +) \ No newline at end of file diff --git a/caffe2/__init__.py b/caffe2/__init__.py new file mode 100644 index 00000000000..eb2c8353acc --- /dev/null +++ b/caffe2/__init__.py @@ -0,0 +1,5 @@ +""" +Caffe2: A General Tool for Neural Networks. 
+""" + +__author__ = 'Yangqing Jia' diff --git a/caffe2/binaries/BREW b/caffe2/binaries/BREW new file mode 100644 index 00000000000..c4daa5c174c --- /dev/null +++ b/caffe2/binaries/BREW @@ -0,0 +1,204 @@ +cc_binary( + name = "convert_db", + srcs = [ + "convert_db.cc", + ], + deps = [ + "//caffe2/db:db", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +cc_binary( + name = "make_cifar_db", + srcs = [ + "make_cifar_db.cc", + ], + deps = [ + "//caffe2/db:db", + "//caffe2/proto:caffe2_proto", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +cc_binary( + name = "make_image_db", + srcs = [ + "make_image_db.cc", + ], + deps = [ + "//caffe2/db:db", + "//caffe2/proto:caffe2_proto", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], + external_libs = [ + "opencv_core", + "opencv_highgui", + "opencv_imgproc", + ], +) + +cc_binary( + name = "convert_encoded_to_raw_leveldb", + srcs = [ + "convert_encoded_to_raw_leveldb.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/proto:caffe2_proto", + "//third_party/leveldb:leveldb", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], + external_libs = [ + "opencv_core", + "opencv_highgui", + "opencv_imgproc", + ], +) + + +cc_binary( + name = "make_mnist_db", + srcs = [ + "make_mnist_db.cc", + ], + deps = [ + "//caffe2/db:db", + "//caffe2/proto:caffe2_proto", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +cc_binary( + name = "print_registered_core_operators", + srcs = [ + "print_registered_core_operators.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/image:image_ops", + "//caffe2/image:image_ops_gpu", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + ], +) + +cc_binary( + name = "run_client", + srcs = [ + "run_client.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/image:image_ops", + "//caffe2/image:image_ops_gpu", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/utils:proto_utils", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +# run_client_minimal is the binary that links in the operators that have no +# external dependencies at all. +cc_binary( + name = "run_client_minimal", + srcs = [ + "run_client.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/utils:proto_utils", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + + +cc_binary( + name = "run_plan", + srcs = [ + "run_plan.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/image:image_ops", + "//caffe2/image:image_ops_gpu", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/utils:proto_utils", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +# run_plan_minimal is the binary that links in the operators that have no +# external dependencies at all. 
+cc_binary( + name = "run_plan_minimal", + srcs = [ + "run_plan.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/utils:proto_utils", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + + +cc_binary( + name = "run_plan_mpi", + srcs = [ + "run_plan_mpi.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/image:image_ops", + "//caffe2/image:image_ops_gpu", + "//caffe2/mpi:mpi_ops", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/utils:proto_utils", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) + +cc_binary( + name = "inspect_gpus", + srcs = [ + "inspect_gpus.cc", + ], + deps = [ + "//caffe2/core:core_gpu", + "//third_party/glog:glog", + ], +) + +cc_binary( + name = "split_db", + srcs = [ + "split_db.cc", + ], + deps = [ + "//caffe2/db:db", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], +) \ No newline at end of file diff --git a/caffe2/binaries/convert_db.cc b/caffe2/binaries/convert_db.cc new file mode 100644 index 00000000000..401943090a7 --- /dev/null +++ b/caffe2/binaries/convert_db.cc @@ -0,0 +1,38 @@ +#include "caffe2/core/db.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(input_db, "", "The input db."); +DEFINE_string(input_db_type, "", "The input db type."); +DEFINE_string(output_db, "", "The output db."); +DEFINE_string(output_db_type, "", "The output db type."); +DEFINE_int32(batch_size, 1000, "The write batch size."); + +using caffe2::db::Cursor; +using caffe2::db::DB; +using caffe2::db::Transaction; + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage( + "This script converts databases between different formats."); + google::ParseCommandLineFlags(&argc, &argv, true); + + std::unique_ptr in_db(caffe2::db::CreateDB( + FLAGS_input_db_type, FLAGS_input_db, caffe2::db::READ)); + std::unique_ptr out_db(caffe2::db::CreateDB( + FLAGS_output_db_type, FLAGS_output_db, caffe2::db::NEW)); + std::unique_ptr cursor(in_db->NewCursor()); + std::unique_ptr transaction(out_db->NewTransaction()); + int count = 0; + for (; cursor->Valid(); cursor->Next()) { + transaction->Put(cursor->key(), cursor->value()); + if (++count % FLAGS_batch_size == 0) { + transaction->Commit(); + LOG(INFO) << "Converted " << count << " items so far."; + } + } + LOG(INFO) << "A total of " << count << " items processed."; + return 0; +} diff --git a/caffe2/binaries/convert_encoded_to_raw_leveldb.cc b/caffe2/binaries/convert_encoded_to_raw_leveldb.cc new file mode 100644 index 00000000000..54607f15157 --- /dev/null +++ b/caffe2/binaries/convert_encoded_to_raw_leveldb.cc @@ -0,0 +1,139 @@ +// This script converts an image dataset to leveldb. +// +// FLAGS_input_folder is the root folder that holds all the images, and +// FLAGS_list_file should be a list of files as well as their labels, in the +// format as +// subfolder1/file1.JPEG 7 +// .... 
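+//
+// Illustrative invocation (hypothetical paths), using the flags defined below:
+//   convert_encoded_to_raw_leveldb --input_db_name=/path/to/encoded_leveldb \
+//     --output_db_name=/path/to/raw_leveldb --scale=256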
+ +#include + +#include +#include // NOLINT(readability/streams) +#include +#include + +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "leveldb/db.h" +#include "leveldb/write_batch.h" + +DEFINE_string(input_db_name, "", "The input image file name."); +DEFINE_string(output_db_name, "", "The output training leveldb name."); +DEFINE_bool(color, true, "If set, load images in color."); +DEFINE_int32(scale, 256, + "If FLAGS_raw is set, scale all the images' shorter edge to the given " + "value."); +DEFINE_bool(warp, false, "If warp is set, warp the images to square."); + + +namespace caffe2 { + +using std::string; +using std::unique_ptr; + +void ConvertToRawDataset( + const string& input_db_name, const string& output_db_name) { + // input leveldb + std::unique_ptr input_db; + LOG(INFO) << "Opening input leveldb " << input_db_name; + { + leveldb::Options options; + options.create_if_missing = false; + leveldb::DB* db_temp; + leveldb::Status status = leveldb::DB::Open( + options, input_db_name, &db_temp); + CHECK(status.ok()) << "Failed to open leveldb " << input_db_name << "."; + input_db.reset(db_temp); + } + + // output leveldb + std::unique_ptr output_db; + std::unique_ptr batch; + LOG(INFO) << "Opening leveldb " << output_db_name; + { + leveldb::Options options; + options.error_if_exists = true; + options.create_if_missing = true; + options.write_buffer_size = 268435456; + leveldb::DB* db_temp; + leveldb::Status status = leveldb::DB::Open( + options, output_db_name, &db_temp); + CHECK(status.ok()) << "Failed to open leveldb " << output_db_name + << ". Is it already existing?"; + output_db.reset(db_temp); + } + batch.reset(new leveldb::WriteBatch()); + + TensorProtos input_protos; + TensorProtos output_protos; + TensorProto* data = output_protos.add_protos(); + TensorProto* label = output_protos.add_protos(); + data->set_data_type(TensorProto::BYTE); + data->add_dims(0); + data->add_dims(0); + if (FLAGS_color) { + data->add_dims(3); + } + string value; + + unique_ptr iter; + iter.reset(input_db->NewIterator(leveldb::ReadOptions())); + iter->SeekToFirst(); + int count = 0; + for (; iter->Valid(); iter->Next()) { + CHECK(input_protos.ParseFromString(iter->value().ToString())); + label->CopyFrom(input_protos.protos(1)); + const string& encoded_image = input_protos.protos(0).string_data(0); + int encoded_size = encoded_image.size(); + cv::Mat img = cv::imdecode( + cv::Mat(1, &encoded_size, CV_8UC1, + const_cast(encoded_image.data())), + FLAGS_color ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE); + cv::Mat resized_img; + int scaled_width, scaled_height; + if (FLAGS_warp) { + scaled_width = FLAGS_scale; + scaled_height = FLAGS_scale; + } else if (img.rows > img.cols) { + scaled_width = FLAGS_scale; + scaled_height = static_cast(img.rows) * FLAGS_scale / img.cols; + } else { + scaled_height = FLAGS_scale; + scaled_width = static_cast(img.cols) * FLAGS_scale / img.rows; + } + cv::resize(img, resized_img, cv::Size(scaled_width, scaled_height), 0, 0, + cv::INTER_LINEAR); + data->set_dims(0, scaled_height); + data->set_dims(1, scaled_width); + DCHECK(resized_img.isContinuous()); + data->set_byte_data(resized_img.ptr(), + scaled_height * scaled_width * (FLAGS_color ? 
3 : 1)); + output_protos.SerializeToString(&value); + // Put in db + batch->Put(iter->key(), value); + if (++count % 1000 == 0) { + output_db->Write(leveldb::WriteOptions(), batch.get()); + batch.reset(new leveldb::WriteBatch()); + LOG(INFO) << "Processed " << count << " files."; + } + } + // write the last batch + if (count % 1000 != 0) { + output_db->Write(leveldb::WriteOptions(), batch.get()); + } + LOG(INFO) << "Processed a total of " << count << " files."; +} + +} // namespace caffe2 + + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Converts an image dataset to a leveldb."); + google::ParseCommandLineFlags(&argc, &argv, true); + caffe2::ConvertToRawDataset( + FLAGS_input_db_name, FLAGS_output_db_name); + return 0; +} diff --git a/caffe2/binaries/inspect_gpus.cc b/caffe2/binaries/inspect_gpus.cc new file mode 100644 index 00000000000..0141bf11086 --- /dev/null +++ b/caffe2/binaries/inspect_gpus.cc @@ -0,0 +1,30 @@ +#include +#include + +#include + +#include "caffe2/core/common_gpu.h" +#include "glog/logging.h" + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + + int gpu_count; + CUDA_CHECK(cudaGetDeviceCount(&gpu_count)); + for (int i = 0; i < gpu_count; ++i) { + LOG(INFO) << "Querying device ID = " << i; + caffe2::DeviceQuery(i); + } + + std::stringstream sstream; + // Find topology + int can_access; + for (int i = 0; i < gpu_count; ++i) { + for (int j = 0; j < gpu_count; ++j) { + CUDA_CHECK(cudaDeviceCanAccessPeer(&can_access, i, j)); + sstream << ((i == j || can_access) ? "+" : "-") << " "; + } + sstream << std::endl; + } + LOG(INFO) << "Access pattern: " << std::endl << sstream.str(); +} diff --git a/caffe2/binaries/make_cifar_db.cc b/caffe2/binaries/make_cifar_db.cc new file mode 100644 index 00000000000..85a0aa33944 --- /dev/null +++ b/caffe2/binaries/make_cifar_db.cc @@ -0,0 +1,146 @@ +// +// This script converts the CIFAR dataset to the leveldb format used +// by caffe to perform classification. +// Usage: +// convert_cifar_data input_folder output_db_file +// The CIFAR dataset could be downloaded at +// http://www.cs.toronto.edu/~kriz/cifar.html + +#include // NOLINT(readability/streams) +#include +#include + +#include "caffe2/core/common.h" +#include "caffe2/core/db.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(input_folder, "", "The input image file name."); +DEFINE_string(output_train_db_name, "", "The output training leveldb name."); +DEFINE_string(output_test_db_name, "", "The output testing leveldb name."); +DEFINE_string(db, "leveldb", "The db type."); +DEFINE_bool(is_cifar100, false, + "If set, convert cifar100. Otherwise do cifar10."); +DEFINE_bool(channel_first, false, + "If set, write the data as channel-first (CHW order) as the old " + "Caffe does."); + +namespace caffe2 { + +using std::stringstream; + +const int kCIFARSize = 32; +const int kCIFARImageNBytes = kCIFARSize * kCIFARSize * 3; +const int kCIFAR10BatchSize = 10000; +const int kCIFAR10TestDataSize = 10000; +const int kCIFAR10TrainBatches = 5; + +const int kCIFAR100TrainDataSize = 50000; +const int kCIFAR100TestDataSize = 10000; + +void ReadImage(std::ifstream* file, int* label, char* buffer) { + char label_char; + if (FLAGS_is_cifar100) { + // Skip the coarse label. 
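+    // (CIFAR binary format: a CIFAR-100 record is
+    // <coarse label byte><fine label byte><3072 image bytes>, whereas a
+    // CIFAR-10 record is <label byte><3072 image bytes>, so only the
+    // CIFAR-100 case has an extra leading byte to discard.)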
+ file->read(&label_char, 1); + } + file->read(&label_char, 1); + *label = label_char; + if (FLAGS_channel_first) { + file->read(buffer, kCIFARImageNBytes); + } else { + // Yes, there are better ways to do it, like in-place swap... but I am too + // lazy so let's just write it in a memory-wasteful way. + static char channel_first_storage[kCIFARImageNBytes]; + file->read(channel_first_storage, kCIFARImageNBytes); + for (int c = 0; c < 3; ++c) { + for (int i = 0; i < kCIFARSize * kCIFARSize; ++i) { + buffer[i * 3 + c] = + channel_first_storage[c * kCIFARSize * kCIFARSize + i]; + } + } + } + return; +} + +void WriteToDB(const string& filename, const int num_items, + const int& offset, db::DB* db) { + TensorProtos protos; + TensorProto* data = protos.add_protos(); + TensorProto* label = protos.add_protos(); + data->set_data_type(TensorProto::BYTE); + if (FLAGS_channel_first) { + data->add_dims(1); + data->add_dims(3); + data->add_dims(kCIFARSize); + data->add_dims(kCIFARSize); + } else { + data->add_dims(1); + data->add_dims(kCIFARSize); + data->add_dims(kCIFARSize); + data->add_dims(3); + } + label->set_data_type(TensorProto::INT32); + label->add_dims(1); + label->add_int32_data(0); + + LOG(INFO) << "Converting file " << filename; + std::ifstream data_file(filename.c_str(), + std::ios::in | std::ios::binary); + CHECK(data_file) << "Unable to open file " << filename; + char str_buffer[kCIFARImageNBytes]; + int label_value; + string serialized_protos; + std::unique_ptr transaction(db->NewTransaction()); + for (int itemid = 0; itemid < num_items; ++itemid) { + ReadImage(&data_file, &label_value, str_buffer); + data->set_byte_data(str_buffer, kCIFARImageNBytes); + label->set_int32_data(0, label_value); + protos.SerializeToString(&serialized_protos); + snprintf(str_buffer, kCIFARImageNBytes, "%05d", + offset + itemid); + transaction->Put(string(str_buffer), serialized_protos); + } +} + +void ConvertCIFAR() { + std::unique_ptr train_db( + db::CreateDB(FLAGS_db, FLAGS_output_train_db_name, db::NEW)); + std::unique_ptr test_db( + db::CreateDB(FLAGS_db, FLAGS_output_test_db_name, db::NEW)); + + if (!FLAGS_is_cifar100) { + // This is cifar 10. + for (int fileid = 0; fileid < kCIFAR10TrainBatches; ++fileid) { + stringstream train_file; + train_file << FLAGS_input_folder << "/data_batch_" << fileid + 1 + << ".bin"; + WriteToDB(train_file.str(), kCIFAR10BatchSize, + fileid * kCIFAR10BatchSize, train_db.get()); + } + stringstream test_file; + test_file << FLAGS_input_folder << "/test_batch.bin"; + WriteToDB(test_file.str(), kCIFAR10TestDataSize, 0, test_db.get()); + } else { + // This is cifar 100. + stringstream train_file; + train_file << FLAGS_input_folder << "/train.bin"; + WriteToDB(train_file.str(), kCIFAR100TrainDataSize, 0, train_db.get()); + stringstream test_file; + test_file << FLAGS_input_folder << "/test.bin"; + WriteToDB(test_file.str(), kCIFAR100TestDataSize, 0, test_db.get()); + } +} + +} // namespace caffe2 + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage( + "This script converts the CIFAR dataset to the db format used " + "by caffe to perform classification."); + google::ParseCommandLineFlags(&argc, &argv, true); + caffe2::ConvertCIFAR(); + return 0; +} diff --git a/caffe2/binaries/make_image_db.cc b/caffe2/binaries/make_image_db.cc new file mode 100644 index 00000000000..9bb8abf8614 --- /dev/null +++ b/caffe2/binaries/make_image_db.cc @@ -0,0 +1,146 @@ +// This script converts an image dataset to a database. 
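+//
+// Illustrative invocation (hypothetical paths):
+//   make_image_db --input_folder=/data/images --list_file=/data/list.txt \
+//     --output_db_name=/data/image_db --db=leveldb --shuffle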
+// +// FLAGS_input_folder is the root folder that holds all the images, and +// FLAGS_list_file should be a list of files as well as their labels, in the +// format as +// subfolder1/file1.JPEG 7 +// .... + +#include + +#include +#include // NOLINT(readability/streams) +#include +#include + +#include "caffe2/core/common.h" +#include "caffe2/core/db.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_bool(shuffle, false, + "Randomly shuffle the order of images and their labels"); +DEFINE_string(input_folder, "", "The input image file name."); +DEFINE_string(list_file, "", "The text file containing the list of images."); +DEFINE_string(output_db_name, "", "The output training leveldb name."); +DEFINE_string(db, "leveldb", "The db type."); +DEFINE_bool(raw, false, + "If set, we pre-read the images and store the raw buffer."); +DEFINE_bool(color, true, "If set, load images in color."); +DEFINE_int32(scale, 256, + "If FLAGS_raw is set, scale all the images' shorter edge to the given " + "value."); +DEFINE_bool(warp, false, "If warp is set, warp the images to square."); + + +namespace caffe2 { + +void ConvertImageDataset( + const string& input_folder, const string& list_filename, + const string& output_db_name, const bool shuffle) { + std::ifstream list_file(list_filename); + std::vector > lines; + std::string filename; + int file_label; + while (list_file >> filename >> file_label) { + lines.push_back(std::make_pair(filename, file_label)); + } + if (FLAGS_shuffle) { + // randomly shuffle data + LOG(INFO) << "Shuffling data"; + std::shuffle(lines.begin(), lines.end(), + std::default_random_engine(1701)); + } + LOG(INFO) << "A total of " << lines.size() << " images."; + + + LOG(INFO) << "Opening db " << output_db_name; + std::unique_ptr db(db::CreateDB(FLAGS_db, output_db_name, db::NEW)); + std::unique_ptr transaction(db->NewTransaction()); + + TensorProtos protos; + TensorProto* data = protos.add_protos(); + TensorProto* label = protos.add_protos(); + if (FLAGS_raw) { + data->set_data_type(TensorProto::BYTE); + data->add_dims(0); + data->add_dims(0); + if (FLAGS_color) { + data->add_dims(3); + } + } else { + data->set_data_type(TensorProto::STRING); + data->add_dims(1); + data->add_string_data(""); + } + label->set_data_type(TensorProto::INT32); + label->add_dims(1); + label->add_int32_data(0); + const int kMaxKeyLength = 256; + char key_cstr[kMaxKeyLength]; + string value; + int count = 0; + + for (int item_id = 0; item_id < lines.size(); ++item_id) { + // First, set label. + label->set_int32_data(0, lines[item_id].second); + if (!FLAGS_raw) { + // Second, read images. + std::ifstream image_file_stream(input_folder + lines[item_id].first); + data->mutable_string_data(0)->assign( + (std::istreambuf_iterator(image_file_stream)), + std::istreambuf_iterator()); + } else { + // Need to do some opencv magic. + cv::Mat img = cv::imread( + input_folder + lines[item_id].first, + FLAGS_color ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE); + // Do resizing. 
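+      // For example, in raw mode a 640x480 image with --scale=256 and no
+      // --warp becomes 341x256: the shorter edge is scaled to 256 and the
+      // aspect ratio is preserved; with --warp it is forced to 256x256.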
+ cv::Mat resized_img; + int scaled_width, scaled_height; + if (FLAGS_warp) { + scaled_width = FLAGS_scale; + scaled_height = FLAGS_scale; + } else if (img.rows > img.cols) { + scaled_width = FLAGS_scale; + scaled_height = static_cast(img.rows) * FLAGS_scale / img.cols; + } else { + scaled_height = FLAGS_scale; + scaled_width = static_cast(img.cols) * FLAGS_scale / img.rows; + } + cv::resize(img, resized_img, cv::Size(scaled_width, scaled_height), 0, 0, + cv::INTER_LINEAR); + data->set_dims(0, scaled_height); + data->set_dims(1, scaled_width); + DCHECK(resized_img.isContinuous()); + data->set_byte_data( + resized_img.ptr(), + scaled_height * scaled_width * (FLAGS_color ? 3 : 1)); + } + snprintf(key_cstr, kMaxKeyLength, "%08d_%s", item_id, + lines[item_id].first.c_str()); + protos.SerializeToString(&value); + // Put in db + transaction->Put(string(key_cstr), value); + if (++count % 1000 == 0) { + // Commit the current writes. + transaction->Commit(); + LOG(INFO) << "Processed " << count << " files."; + } + } + LOG(INFO) << "Processed a total of " << count << " files."; +} + +} // namespace caffe2 + + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Converts an image dataset to a db."); + google::ParseCommandLineFlags(&argc, &argv, true); + caffe2::ConvertImageDataset( + FLAGS_input_folder, FLAGS_list_file, + FLAGS_output_db_name, FLAGS_shuffle); + return 0; +} diff --git a/caffe2/binaries/make_mnist_db.cc b/caffe2/binaries/make_mnist_db.cc new file mode 100644 index 00000000000..d25ce0a1fb3 --- /dev/null +++ b/caffe2/binaries/make_mnist_db.cc @@ -0,0 +1,123 @@ +// This script converts the MNIST dataset to leveldb. +// The MNIST dataset could be downloaded at +// http://yann.lecun.com/exdb/mnist/ + +#include // NOLINT(readability/streams) +#include + +#include "caffe2/core/common.h" +#include "caffe2/core/db.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(image_file, "", "The input image file name."); +DEFINE_string(label_file, "", "The label file name."); +DEFINE_string(output_file, "", "The output db name."); +DEFINE_string(db, "leveldb", "The db type."); +DEFINE_int32(data_limit, -1, + "If set, only output this number of data points."); +DEFINE_bool(channel_first, false, + "If set, write the data as channel-first (CHW order) as the old " + "Caffe does."); + +namespace caffe2 { +uint32_t swap_endian(uint32_t val) { + val = ((val << 8) & 0xFF00FF00) | ((val >> 8) & 0xFF00FF); + return (val << 16) | (val >> 16); +} + +void convert_dataset(const char* image_filename, const char* label_filename, + const char* db_path, const int data_limit) { + // Open files + std::ifstream image_file(image_filename, std::ios::in | std::ios::binary); + std::ifstream label_file(label_filename, std::ios::in | std::ios::binary); + CHECK(image_file) << "Unable to open file " << image_filename; + CHECK(label_file) << "Unable to open file " << label_filename; + // Read the magic and the meta data + uint32_t magic; + uint32_t num_items; + uint32_t num_labels; + uint32_t rows; + uint32_t cols; + + image_file.read(reinterpret_cast(&magic), 4); + magic = swap_endian(magic); + CHECK_EQ(magic, 2051) << "Incorrect image file magic."; + label_file.read(reinterpret_cast(&magic), 4); + magic = swap_endian(magic); + CHECK_EQ(magic, 2049) << "Incorrect label file magic."; + image_file.read(reinterpret_cast(&num_items), 4); + num_items = swap_endian(num_items); + label_file.read(reinterpret_cast(&num_labels), 4); + 
num_labels = swap_endian(num_labels); + CHECK_EQ(num_items, num_labels); + image_file.read(reinterpret_cast(&rows), 4); + rows = swap_endian(rows); + image_file.read(reinterpret_cast(&cols), 4); + cols = swap_endian(cols); + + // leveldb + std::unique_ptr mnist_db(db::CreateDB(FLAGS_db, db_path, db::NEW)); + std::unique_ptr transaction(mnist_db->NewTransaction()); + // Storing to db + char label_value; + std::vector pixels(rows * cols); + int count = 0; + const int kMaxKeyLength = 10; + char key_cstr[kMaxKeyLength]; + string value; + + TensorProtos protos; + TensorProto* data = protos.add_protos(); + TensorProto* label = protos.add_protos(); + data->set_data_type(TensorProto::BYTE); + if (FLAGS_channel_first) { + data->add_dims(1); + data->add_dims(1); + data->add_dims(rows); + data->add_dims(cols); + } else { + data->add_dims(1); + data->add_dims(rows); + data->add_dims(cols); + data->add_dims(1); + } + label->set_data_type(TensorProto::INT32); + label->add_dims(1); + label->add_int32_data(0); + + LOG(INFO) << "A total of " << num_items << " items."; + LOG(INFO) << "Rows: " << rows << " Cols: " << cols; + for (int item_id = 0; item_id < num_items; ++item_id) { + image_file.read(pixels.data(), rows * cols); + label_file.read(&label_value, 1); + for (int i = 0; i < rows * cols; ++i) { + data->set_byte_data(pixels.data(), rows * cols); + } + label->set_int32_data(0, static_cast(label_value)); + snprintf(key_cstr, kMaxKeyLength, "%08d", item_id); + protos.SerializeToString(&value); + string keystr(key_cstr); + + // Put in db + transaction->Put(keystr, value); + if (++count % 1000 == 0) { + transaction->Commit(); + } + if (data_limit > 0 && count == data_limit) { + LOG(INFO) << "Reached data limit of " << data_limit << ", stop."; + break; + } + } +} +} // namespace caffe2 + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Converts the raw mnist dataset to a leveldb."); + google::ParseCommandLineFlags(&argc, &argv, true); + caffe2::convert_dataset(FLAGS_image_file.c_str(), FLAGS_label_file.c_str(), + FLAGS_output_file.c_str(), FLAGS_data_limit); + return 0; +} diff --git a/caffe2/binaries/print_registered_core_operators.cc b/caffe2/binaries/print_registered_core_operators.cc new file mode 100644 index 00000000000..4638b08f687 --- /dev/null +++ b/caffe2/binaries/print_registered_core_operators.cc @@ -0,0 +1,11 @@ +#include + +#include "caffe2/core/operator.h" + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + std::cout << "CPU operator registry:" << std::endl; + caffe2::CPUOperatorRegistry()->TEST_PrintRegisteredNames(); + std::cout << "CUDA operator registry:" << std::endl; + caffe2::CUDAOperatorRegistry()->TEST_PrintRegisteredNames(); +} diff --git a/caffe2/binaries/run_client.cc b/caffe2/binaries/run_client.cc new file mode 100644 index 00000000000..9626ffcd94c --- /dev/null +++ b/caffe2/binaries/run_client.cc @@ -0,0 +1,54 @@ +#include +#include + +#include "caffe2/core/client.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(client_file, "", "The given path to the client protobuffer."); +DEFINE_string(output_file, "", "The output file."); +DEFINE_int32(input_size, 0, "The input size."); +DEFINE_int32(iter, 0, "The number of iterations for timing."); +DEFINE_string(input_file, "", + "The input file containing a list of float numbers."); + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Runs a given client."); + 
google::ParseCommandLineFlags(&argc, &argv, true); + LOG(INFO) << "Loading client file: " << FLAGS_client_file; + caffe2::Client client(FLAGS_client_file); + std::vector input; + if (FLAGS_input_file.size()) { + std::ifstream infile; + infile.open(FLAGS_input_file, std::ios::in); + float value; + while (infile >> value) { + input.push_back(value); + } + } else { + input.resize(FLAGS_input_size); + } + LOG(INFO) << "An input of " << input.size() << " values."; + std::vector output; + CHECK(client.Run(input, &output)); + clock_t start = clock(); + for (int i = 0; i < FLAGS_iter; ++i) { + CHECK(client.Run(input, &output)); + } + LOG(INFO) << "Timing: "<< FLAGS_iter << " iters took " + << static_cast(clock() - start) / CLOCKS_PER_SEC + << " seconds."; + LOG(INFO) << "Output: " << output.size() << " dims."; + if (FLAGS_output_file.size()) { + std::ofstream outfile; + outfile.open(FLAGS_output_file, std::ios::out | std::ios::trunc); + for (int i = 0; i < output.size(); ++i) { + outfile << output[i] << std::endl; + } + outfile.close(); + } + // This is to allow us to use memory leak checks. + google::ShutDownCommandLineFlags(); + return 0; +} diff --git a/caffe2/binaries/run_plan.cc b/caffe2/binaries/run_plan.cc new file mode 100644 index 00000000000..a94e583fad0 --- /dev/null +++ b/caffe2/binaries/run_plan.cc @@ -0,0 +1,23 @@ +#include "caffe2/core/operator.h" +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/utils/proto_utils.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(plan, "", "The given path to the plan protobuffer."); + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Runs a given plan."); + google::ParseCommandLineFlags(&argc, &argv, true); + LOG(INFO) << "Loading plan: " << FLAGS_plan; + caffe2::PlanDef plan_def; + CHECK(ReadProtoFromFile(FLAGS_plan, &plan_def)); + std::unique_ptr workspace(new caffe2::Workspace()); + workspace->RunPlan(plan_def); + + // This is to allow us to use memory leak checks. + google::protobuf::ShutdownProtobufLibrary(); + google::ShutDownCommandLineFlags(); + return 0; +} diff --git a/caffe2/binaries/run_plan_mpi.cc b/caffe2/binaries/run_plan_mpi.cc new file mode 100644 index 00000000000..954c77e40db --- /dev/null +++ b/caffe2/binaries/run_plan_mpi.cc @@ -0,0 +1,27 @@ +#include + +#include "caffe2/core/operator.h" +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/utils/proto_utils.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(plan, "", "The given path to the plan protobuffer."); + +int main(int argc, char** argv) { + MPI_Init(&argc, &argv); + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage("Runs a given plan."); + google::ParseCommandLineFlags(&argc, &argv, true); + LOG(INFO) << "Loading plan: " << FLAGS_plan; + caffe2::PlanDef plan_def; + CHECK(ReadProtoFromFile(FLAGS_plan, &plan_def)); + std::unique_ptr workspace(new caffe2::Workspace()); + workspace->RunPlan(plan_def); + + // This is to allow us to use memory leak checks. 
+ google::protobuf::ShutdownProtobufLibrary(); + google::ShutDownCommandLineFlags(); + MPI_Finalize(); + return 0; +} diff --git a/caffe2/binaries/split_db.cc b/caffe2/binaries/split_db.cc new file mode 100644 index 00000000000..a992e76a604 --- /dev/null +++ b/caffe2/binaries/split_db.cc @@ -0,0 +1,52 @@ +#include +#include + +#include "caffe2/core/db.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gflags/gflags.h" +#include "glog/logging.h" + +DEFINE_string(input_db, "", "The input db."); +DEFINE_int32(splits, 0, "The number of splits."); +DEFINE_string(db_type, "", "The db type."); +DEFINE_int32(batch_size, 1000, "The write batch size."); + +using caffe2::db::Cursor; +using caffe2::db::DB; +using caffe2::db::Transaction; + +int main(int argc, char** argv) { + google::InitGoogleLogging(argv[0]); + google::SetUsageMessage( + "This script converts databases between different formats."); + google::ParseCommandLineFlags(&argc, &argv, true); + + std::unique_ptr in_db(caffe2::db::CreateDB( + FLAGS_db_type, FLAGS_input_db, caffe2::db::READ)); + std::unique_ptr cursor(in_db->NewCursor()); + + CHECK_GT(FLAGS_splits, 0) << "Must specify the number of splits."; + std::vector > out_dbs; + std::vector > transactions; + for (int i = 0; i < FLAGS_splits; ++i) { + out_dbs.push_back( + std::unique_ptr(caffe2::db::CreateDB( + FLAGS_db_type, FLAGS_input_db + "_split_" + std::to_string(i), + caffe2::db::NEW))); + transactions.push_back( + std::unique_ptr(out_dbs[i]->NewTransaction())); + } + + int count = 0; + for (; cursor->Valid(); cursor->Next()) { + transactions[count % FLAGS_splits]->Put(cursor->key(), cursor->value()); + if (++count % FLAGS_batch_size == 0) { + for (int i = 0; i < FLAGS_splits; ++i) { + transactions[i]->Commit(); + } + LOG(INFO) << "Splitted " << count << " items so far."; + } + } + LOG(INFO) << "A total of " << count << " items processed."; + return 0; +} diff --git a/caffe2/core/BREW b/caffe2/core/BREW new file mode 100644 index 00000000000..1646ba2c049 --- /dev/null +++ b/caffe2/core/BREW @@ -0,0 +1,94 @@ +cc_library( + name = "core", + srcs = [ + "client.cc", + "db.cc", + "minidb.cc", + "net.cc", + "operator.cc", + "typeid.cc", + "workspace.cc", + ], + hdrs = [ + "blob.h", + "client.h", + "common.h", + "context.h", + "db.h", + "net.h", + "operator.h", + "registry.h", + "typeid.h", + "types.h", + "workspace.h" + ], + deps = [ + "//caffe2/proto:caffe2_proto", + "//caffe2/utils:proto_utils", + "//caffe2/utils:simple_queue", + "//third_party/glog:glog", + ], + whole_archive = True, +) + +cuda_library( + name = "core_gpu", + srcs = [ + "common_gpu.cc", + ], + hdrs = [ + "common_gpu.h", + "context_gpu.h", + ], + deps = [ + ":core", + ] +) + +cc_headers( + name = "core_cudnn", + srcs = [ + "common_cudnn.h", + ], + deps = [ + "//third_party/cudnn:cudnn", + ], +) + +cc_test( + name = "core_test", + srcs = [ + "blob_test.cc", + "context_test.cc", + "operator_test.cc", + "parallel_net_test.cc", + "workspace_test.cc" + ], + deps = [ + ":core", + "//gtest:gtest", + "//gtest:gtest_main", + ], +) + +cc_test( + name = "core_test_gpu", + srcs = [ + "blob_test_gpu.cc", + ], + deps = [ + ":core_gpu", + "//gtest:gtest", + "//gtest:gtest_main", + ], +) + +cc_test( + name = "registry_test", + srcs = ["registry_test.cc"], + deps = [ + ":core", + "//gtest:gtest", + "//gtest:gtest_main", + ], +) diff --git a/caffe2/core/blob.h b/caffe2/core/blob.h new file mode 100644 index 00000000000..4ff67aad7e3 --- /dev/null +++ b/caffe2/core/blob.h @@ -0,0 +1,209 @@ +#ifndef CAFFE2_CORE_BLOB_H_ +#define 
CAFFE2_CORE_BLOB_H_ + +#include +#include + +#include "caffe2/core/common.h" +#include "caffe2/core/context.h" +#include "caffe2/core/typeid.h" +#include "caffe2/proto/caffe2.pb.h" +#include "glog/logging.h" + +namespace caffe2 { + +namespace internal { +// Destroy is a templated function that allows us to memorize the type of the +// pointer we are storing in a void*. +template +void Destroy(void* pointer) { + delete static_cast(pointer); +} +} // namespace internal + +// Blob is a general container that hosts a pointer as well as checking its +// type, and takes charge of deleting it when the blob is deallocated. A blob +// could contain ANYTHING, although the most common case is to contain a Tensor. +class Blob { + public: + typedef void (*DestroyCall)(void *); + + Blob() : id_(internal::gUnknownType), pointer_(nullptr) {} + + ~Blob() { Reset(); } + + template + inline bool IsType() const { return internal::IsTypeId(id_); } + inline string TypeName() const { return internal::TypeName(id_); } + template + const T& Get() const { + CHECK(IsType()) << "wrong type for the Blob instance. Expected " + << internal::TypeName() << " got " + << internal::TypeName(id_); + return *static_cast(pointer_); + } + + template + T* GetMutable() { + if (!IsType()) { + VLOG(1) << "Create new mutable object " << internal::TypeName(); + if (pointer_) destroy_(pointer_); + // If we are not of the right type, create a new instance. + pointer_ = static_cast(new T()); + destroy_ = &internal::Destroy; + } + id_ = internal::GetTypeId(); + return static_cast(pointer_); + } + + inline void Reset() { + if (pointer_) { + destroy_(pointer_); + pointer_ = nullptr; + } + } + + private: + internal::TypeId id_; + void* pointer_; + DestroyCall destroy_; + + DISABLE_COPY_AND_ASSIGN(Blob); +}; + + +template +class Tensor { + public: + Tensor() : ndim_(0), size_(0), data_(nullptr), + own_data_(true), data_source_(nullptr) {} + + // Creates a tensor. The actual data allocation is going to be carried out + // till the first time mutable_data() is called, so there is no overhead of + // creating multiple tensors just as placeholders (although I haven't got a + // clear idea where such cases would happen). + explicit Tensor(const vector& dims) + : data_(nullptr), own_data_(true), data_source_(nullptr) { + Reshape(dims); + } + + template + Tensor(const Tensor& src, Context* context) + : data_(nullptr), own_data_(true), data_source_(nullptr) { + Reshape(src.dims()); + context->template Copy( + mutable_data(), src.data(), src.size()); + } + + // Creates a tensor, and fills its contents with the given values. We need to + // have a context passed in as the copy function is device dependent. + Tensor(const vector& dims, vector values, Context* context) + : data_(nullptr), own_data_(true), data_source_(nullptr) { + Reshape(dims); + CHECK_EQ(values.size(), size_); + context->template Copy( + mutable_data(), values.data(), values.size()); + } + + // Special case of above: create a tensor of shape 1, and the given value. + Tensor(const dtype& value, Context* context) + : data_(nullptr), own_data_(true), data_source_(nullptr) { + Reshape(std::vector(1, 1)); + context->template Copy( + mutable_data(), &value, 1); + } + + virtual ~Tensor() { + Free(); + } + + void Reshape(const vector& dims) { + CHECK_GT(dims.size(), 0); + dims_ = dims; + ndim_ = dims_.size(); + // Calculate the size. + int new_size = 1; + for (int d : dims_) { + CHECK_GT(d, 0); + new_size *= d; + } + // If the size changes, we will call Free(). 
The next data() call will + // re-allocate the memory. + if (data_ && size_ != new_size) { + Free(); + } + size_ = new_size; + } + + template + inline void ReshapeLike(const Tensor& src_tensor) { + Reshape(src_tensor.dims()); + } + + void ShareData(const Tensor& src) { + // To share data, the sizes must be equal. + CHECK_EQ(src.size_, size_) + << "Size mismatch - did you call reshape before sharing the data?"; + if (data_) Free(); + own_data_ = false; + data_source_ = &src; + } + + inline int ndim() const { return ndim_; } + inline int size() const { return size_; } + inline const vector& dims() const { return dims_; } + inline int dim(const int i) const { + CHECK_LT(i, ndim_) << "Exceeding ndim limit " << ndim_; + CHECK_GE(i, 0) << "Cannot have negative index"; + return dims_[i]; + } + + const dtype* data() const { + if (own_data_) { + CHECK_NOTNULL(data_); + return data_; + } else { + CHECK_NOTNULL(data_source_); + CHECK_EQ(data_source_->size_, size_) << "Source data size has changed."; + CHECK_NOTNULL(data_source_->data()); + return data_source_->data(); + } + } + + dtype* mutable_data() { + CHECK(own_data_) << "Cannot call mutable_data() from a shared tensor."; + CHECK_GT(size_, 0) << "Cannot call mutable_data on a size 0 tensor."; + if (!data_) Allocate(); + CHECK_NOTNULL(data_); + return data_; + } + + void Allocate() { + CHECK(data_ == nullptr); + CHECK_GT(size_, 0); + data_ = static_cast(Context::New(size_ * sizeof(dtype))); + } + + void Free() { + if (own_data_) { + if (data_) { + Context::Delete(data_); + } + } + own_data_ = true; + data_ = nullptr; + } + + protected: + int ndim_; + vector dims_; + int size_; + dtype* data_; + bool own_data_; + const Tensor* data_source_; + + DISABLE_COPY_AND_ASSIGN(Tensor); +}; + +} // namespace caffe2 +#endif // CAFFE2_CORE_BLOB_H_ diff --git a/caffe2/core/blob_test.cc b/caffe2/core/blob_test.cc new file mode 100644 index 00000000000..21ff921aea1 --- /dev/null +++ b/caffe2/core/blob_test.cc @@ -0,0 +1,186 @@ +#include + +#include "caffe2/core/blob.h" +#include "caffe2/core/common.h" +#include "caffe2/core/context.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +using namespace internal; // NOLINT + +class Foo {}; +class Bar {}; + +TEST(BlobTest, TypeId) { + TypeId int_id = GetTypeId(); + TypeId float_id = GetTypeId(); + TypeId foo_id = GetTypeId(); + TypeId bar_id = GetTypeId(); + EXPECT_NE(int_id, float_id); + EXPECT_NE(float_id, foo_id); + EXPECT_NE(foo_id, bar_id); + EXPECT_TRUE(IsTypeId(int_id)); + EXPECT_TRUE(IsTypeId(float_id)); + EXPECT_TRUE(IsTypeId(foo_id)); + EXPECT_TRUE(IsTypeId(bar_id)); + EXPECT_FALSE(IsTypeId(float_id)); + EXPECT_FALSE(IsTypeId(foo_id)); + EXPECT_FALSE(IsTypeId(int_id)); + EXPECT_FALSE(IsTypeId(bar_id)); +} + +TEST(BlobTest, Blob) { + Blob blob; + + int* int_unused UNUSED_VARIABLE = blob.GetMutable(); + EXPECT_TRUE(blob.IsType()); + EXPECT_FALSE(blob.IsType()); + + Foo* foo_unused UNUSED_VARIABLE = blob.GetMutable(); + EXPECT_TRUE(blob.IsType()); + EXPECT_FALSE(blob.IsType()); +} + +TEST(BlobDeathTest, BlobUninitialized) { + Blob blob; + ASSERT_DEATH(blob.Get(), ".*wrong type for the Blob instance.*"); +} + +TEST(BlobDeathTest, BlobWrongType) { + Blob blob; + Foo* foo_unused UNUSED_VARIABLE = blob.GetMutable(); + EXPECT_TRUE(blob.IsType()); + EXPECT_FALSE(blob.IsType()); + // When not null, we should only call with the right type. 
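+  // Getting the stored type (Foo) succeeds, while asking the blob for a
+  // different type should trip the type CHECK in Blob::Get() and abort.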
+ EXPECT_NE(&blob.Get(), nullptr); + ASSERT_DEATH(blob.Get(), ".*wrong type for the Blob instance.*"); +} + +template class TensorCPUTest : public ::testing::Test {}; +template class TensorCPUDeathTest : public ::testing::Test {}; +typedef ::testing::Types TensorTypes; +TYPED_TEST_CASE(TensorCPUTest, TensorTypes); +TYPED_TEST_CASE(TensorCPUDeathTest, TensorTypes); + +TYPED_TEST(TensorCPUTest, TensorInitializedEmpty) { + Tensor tensor; + EXPECT_EQ(tensor.ndim(), 0); + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + tensor.Reshape(dims); + EXPECT_EQ(tensor.ndim(), 3); + EXPECT_EQ(tensor.dim(0), 2); + EXPECT_EQ(tensor.dim(1), 3); + EXPECT_EQ(tensor.dim(2), 5); + EXPECT_EQ(tensor.size(), 2 * 3 * 5); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); +} + +TYPED_TEST(TensorCPUTest, TensorInitializedNonEmpty) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + EXPECT_EQ(tensor.ndim(), 3); + EXPECT_EQ(tensor.dim(0), 2); + EXPECT_EQ(tensor.dim(1), 3); + EXPECT_EQ(tensor.dim(2), 5); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); + dims[0] = 7; + dims[1] = 11; + dims[2] = 13; + dims.push_back(17); + tensor.Reshape(dims); + EXPECT_EQ(tensor.ndim(), 4); + EXPECT_EQ(tensor.dim(0), 7); + EXPECT_EQ(tensor.dim(1), 11); + EXPECT_EQ(tensor.dim(2), 13); + EXPECT_EQ(tensor.dim(3), 17); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); +} + +TYPED_TEST(TensorCPUTest, TensorShareData) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); + EXPECT_TRUE(other_tensor.data() != nullptr); + EXPECT_EQ(tensor.data(), other_tensor.data()); + // Set one value, check the other + for (int i = 0; i < tensor.size(); ++i) { + tensor.mutable_data()[i] = i; + EXPECT_EQ(other_tensor.data()[i], i); + } +} + +TYPED_TEST(TensorCPUTest, TensorShareDataCanUseDifferentShapes) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + vector alternate_dims(1); + alternate_dims[0] = 2 * 3 * 5; + Tensor tensor(dims); + Tensor other_tensor(alternate_dims); + other_tensor.ShareData(tensor); + EXPECT_EQ(other_tensor.ndim(), 1); + EXPECT_EQ(other_tensor.dim(0), alternate_dims[0]); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); + EXPECT_TRUE(other_tensor.data() != nullptr); + EXPECT_EQ(tensor.data(), other_tensor.data()); + // Set one value, check the other + for (int i = 0; i < tensor.size(); ++i) { + tensor.mutable_data()[i] = i; + EXPECT_EQ(other_tensor.data()[i], i); + } +} + +TYPED_TEST(TensorCPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + ASSERT_DEATH(other_tensor.mutable_data(), ""); +} + +TYPED_TEST(TensorCPUDeathTest, CannotDoReshapewithAlias) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + dims[0] = 7; + tensor.Reshape(dims); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + ASSERT_DEATH(other_tensor.data(), ".*Source data size has changed..*"); +} + +TYPED_TEST(TensorCPUDeathTest, CannotAccessDataWhenEmpty) { + Tensor tensor; + EXPECT_EQ(tensor.ndim(), 0); + 
ASSERT_DEATH(tensor.data(), ".*Check failed: 'data_' Must be non NULL.*"); +} + + +} // namespace caffe2 + + diff --git a/caffe2/core/blob_test_gpu.cc b/caffe2/core/blob_test_gpu.cc new file mode 100644 index 00000000000..5d99640d352 --- /dev/null +++ b/caffe2/core/blob_test_gpu.cc @@ -0,0 +1,109 @@ +#include // NOLINT + +#include "caffe2/core/blob.h" +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +template class TensorGPUTest : public ::testing::Test {}; +template class TensorGPUDeathTest : public ::testing::Test {}; +typedef ::testing::Types TensorTypes; +TYPED_TEST_CASE(TensorGPUTest, TensorTypes); +TYPED_TEST_CASE(TensorGPUDeathTest, TensorTypes); + +TYPED_TEST(TensorGPUTest, TensorInitializedEmpty) { + Tensor tensor; + EXPECT_EQ(tensor.ndim(), 0); + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + tensor.Reshape(dims); + EXPECT_EQ(tensor.ndim(), 3); + EXPECT_EQ(tensor.dim(0), 2); + EXPECT_EQ(tensor.dim(1), 3); + EXPECT_EQ(tensor.dim(2), 5); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); +} + +TYPED_TEST(TensorGPUTest, TensorInitializedNonEmpty) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + EXPECT_EQ(tensor.ndim(), 3); + EXPECT_EQ(tensor.dim(0), 2); + EXPECT_EQ(tensor.dim(1), 3); + EXPECT_EQ(tensor.dim(2), 5); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); + dims[0] = 7; + dims[1] = 11; + dims[2] = 13; + dims.push_back(17); + tensor.Reshape(dims); + EXPECT_EQ(tensor.ndim(), 4); + EXPECT_EQ(tensor.dim(0), 7); + EXPECT_EQ(tensor.dim(1), 11); + EXPECT_EQ(tensor.dim(2), 13); + EXPECT_EQ(tensor.dim(3), 17); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); +} + +TYPED_TEST(TensorGPUTest, TensorShareData) { + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + EXPECT_TRUE(tensor.data() != nullptr); + EXPECT_TRUE(other_tensor.data() != nullptr); + EXPECT_EQ(tensor.data(), other_tensor.data()); +} + +TYPED_TEST(TensorGPUDeathTest, ShareDataCannotInitializeDataFromSharedTensor) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + ASSERT_DEATH(other_tensor.mutable_data(), ""); +} + +TYPED_TEST(TensorGPUDeathTest, CannotDoReshapewithAlias) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + vector dims(3); + dims[0] = 2; + dims[1] = 3; + dims[2] = 5; + Tensor tensor(dims); + Tensor other_tensor(dims); + other_tensor.ShareData(tensor); + dims[0] = 7; + tensor.Reshape(dims); + EXPECT_TRUE(tensor.mutable_data() != nullptr); + ASSERT_DEATH(other_tensor.data(), "Source data size has changed."); +} + +TYPED_TEST(TensorGPUDeathTest, CannotAccessDataWhenEmpty) { + ::testing::FLAGS_gtest_death_test_style = "threadsafe"; + Tensor tensor; + EXPECT_EQ(tensor.ndim(), 0); + ASSERT_DEATH(tensor.data(), "Check failed: 'data_' Must be non NULL"); +} + +} // namespace caffe2 + + diff --git a/caffe2/core/client.cc b/caffe2/core/client.cc new file mode 100644 index 00000000000..aba42eb0cbb --- /dev/null +++ b/caffe2/core/client.cc @@ -0,0 +1,40 @@ +#include "caffe2/core/client.h" +#include "caffe2/core/net.h" 
+#include "caffe2/core/workspace.h" +#include "caffe2/utils/proto_utils.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +Client::Client(const string& client_def_name) : workspace_(new Workspace()) { + SimpleClientDef client_def; + CHECK(ReadProtoFromFile(client_def_name, &client_def)); + workspace_->RunNetOnce(client_def.init_net()); + client_def.mutable_main_net()->set_name("main"); + CHECK(workspace_->CreateNet(client_def.main_net())); + input_blob_ = workspace_->GetBlob(client_def.input()); + output_blob_ = workspace_->GetBlob(client_def.output()); + CHECK(input_blob_ != nullptr); + CHECK(output_blob_ != nullptr); +} + +Client::~Client() { + delete workspace_; +} + +bool Client::Run(const vector& input, vector* output) { + Tensor* input_tensor = + input_blob_->GetMutable >(); + CHECK_EQ(input_tensor->size(), input.size()); + memcpy(input_tensor->mutable_data(), input.data(), + input.size() * sizeof(float)); + workspace_->RunNet("main"); + const Tensor& output_tensor = + output_blob_->Get >(); + output->resize(output_tensor.size()); + memcpy(output->data(), output_tensor.data(), output->size() * sizeof(float)); + return true; +} + +} // namespace caffe2 + diff --git a/caffe2/core/client.h b/caffe2/core/client.h new file mode 100644 index 00000000000..d2e33e6f04e --- /dev/null +++ b/caffe2/core/client.h @@ -0,0 +1,41 @@ +// Client is a very thin wrapper over a Caffe2 interface, allowing us to do +// a very primitive caffe network call without the need of revealing all +// the header files inside Caffe2. Also, what we are going to deal with is +// always float inputs and float outputs, and the input and output shapes +// should be fixed. This is minimal and is only used by Yangqing to deal +// with quick demo cases. + +#ifndef CAFFE2_CORE_CLIENT_H_ +#define CAFFE2_CORE_CLIENT_H_ + +#include +#include + +namespace caffe2 { + +// Forward declaration of a Caffe workspace. +class Blob; +class Workspace; + +// Workspace is a class that holds all the blobs in this run and also runs +// the operators. +class Client { + public: + explicit Client(const std::string& client_def_name); + ~Client(); + + // TODO(Yangqing): Figure out how we can deal with different types of + // inputs. + bool Run(const std::vector& input, std::vector* output); + + private: + // TODO(Yangqing): Are we really going to share workspaces? If not, let's + // remove this unnecessity. + Workspace* workspace_; + Blob* input_blob_; + Blob* output_blob_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_CORE_CLIENT_H_ diff --git a/caffe2/core/common.h b/caffe2/core/common.h new file mode 100644 index 00000000000..6bc895d06b8 --- /dev/null +++ b/caffe2/core/common.h @@ -0,0 +1,42 @@ +#ifndef CAFFE2_CORE_COMMON_H_ +#define CAFFE2_CORE_COMMON_H_ + +#include +#include +#include +#include + +namespace caffe2 { + +using std::string; +using std::unique_ptr; +// Note(Yangqing): NVCC does not play well with unordered_map on some platforms, +// forcing us to use std::map instead of unordered_map. This may affect speed +// in some cases, but in most of the computation code we do not access map very +// often, so it should be fine for us. I am putting a CaffeMap alias so we can +// change it more easily if things work out for unordered_map down the road. +template +using CaffeMap = std::map; +// using CaffeMap = std::unordered_map; +using std::vector; + +// Just in order to mark things as not implemented. Do not use in final code. +#define NOT_IMPLEMENTED LOG(FATAL) << "Not Implemented." + +// suppress an unused variable. 
+#define UNUSED_VARIABLE __attribute__((unused)) + +// Disable the copy and assignment operator for a class. Note that this will +// disable the usage of the class in std containers. +#define DISABLE_COPY_AND_ASSIGN(classname) \ +private: \ + classname(const classname&); \ + classname& operator=(const classname&) + + +inline string GetGradientName(const string& name) { + return name + ".grad"; +} + +} // namespace caffe2 +#endif // CAFFE2_CORE_COMMON_H_ diff --git a/caffe2/core/common_cudnn.h b/caffe2/core/common_cudnn.h new file mode 100644 index 00000000000..21c6c96cf12 --- /dev/null +++ b/caffe2/core/common_cudnn.h @@ -0,0 +1,162 @@ +#ifndef CAFFE2_CORE_COMMON_CUDNN_H_ +#define CAFFE2_CORE_COMMON_CUDNN_H_ + +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/core/types.h" +#include "caffe2/proto/caffe2.pb.h" +#include "cudnn.h" +#include "glog/logging.h" + +namespace caffe2 { + +namespace internal { +inline const char* cudnnGetErrorString(cudnnStatus_t status) { + switch (status) { + case CUDNN_STATUS_SUCCESS: + return "CUDNN_STATUS_SUCCESS"; + case CUDNN_STATUS_NOT_INITIALIZED: + return "CUDNN_STATUS_NOT_INITIALIZED"; + case CUDNN_STATUS_ALLOC_FAILED: + return "CUDNN_STATUS_ALLOC_FAILED"; + case CUDNN_STATUS_BAD_PARAM: + return "CUDNN_STATUS_BAD_PARAM"; + case CUDNN_STATUS_INTERNAL_ERROR: + return "CUDNN_STATUS_INTERNAL_ERROR"; + case CUDNN_STATUS_INVALID_VALUE: + return "CUDNN_STATUS_INVALID_VALUE"; + case CUDNN_STATUS_ARCH_MISMATCH: + return "CUDNN_STATUS_ARCH_MISMATCH"; + case CUDNN_STATUS_MAPPING_ERROR: + return "CUDNN_STATUS_MAPPING_ERROR"; + case CUDNN_STATUS_EXECUTION_FAILED: + return "CUDNN_STATUS_EXECUTION_FAILED"; + case CUDNN_STATUS_NOT_SUPPORTED: + return "CUDNN_STATUS_NOT_SUPPORTED"; + case CUDNN_STATUS_LICENSE_ERROR: + return "CUDNN_STATUS_LICENSE_ERROR"; + } +} +} // namespace internal + +#define CUDNN_CHECK(condition) \ + do { \ + cudnnStatus_t status = condition; \ + CHECK_EQ(status, CUDNN_STATUS_SUCCESS) << " " \ + << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ + << ::caffe2::internal::cudnnGetErrorString(status); \ + } while (0) + + +template class cudnnTypeWrapper; +template<> class cudnnTypeWrapper { + public: + static const cudnnDataType_t type = CUDNN_DATA_FLOAT; +}; +template<> class cudnnTypeWrapper { + public: + static const cudnnDataType_t type = CUDNN_DATA_DOUBLE; +}; + +inline cudnnTensorFormat_t GetCudnnTensorFormat(const StorageOrder& order) { + switch (order) { + case StorageOrder::NHWC: + return CUDNN_TENSOR_NHWC; + case StorageOrder::NCHW: + return CUDNN_TENSOR_NCHW; + default: + LOG(FATAL) << "Unknown cudnn equivalent for order: " << order; + } + // Just to suppress compiler warnings + return CUDNN_TENSOR_NCHW; +} + +// cudnnDescriptorMeta is the placeholder that wraps around a +// cudnnTensorDescriptor_t, allowing us to do descriptor change as-needed. 
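CUDNN_CHECK, cudnnTypeWrapper, and GetCudnnTensorFormat above are the pieces the descriptor and handle wrappers below build on. The sketch that follows combines them directly against the raw cuDNN descriptor API; it is illustrative only, the function name is hypothetical, the 2x3x5x5 shape is arbitrary, and the float specialization of cudnnTypeWrapper is assumed.

// Illustrative sketch, not part of this commit.
#include "caffe2/core/common_cudnn.h"

namespace caffe2 {

void DescriptorRoundTrip() {
  cudnnTensorDescriptor_t desc;
  CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc));
  // NCHW float tensor with batch 2, 3 channels and a 5x5 spatial extent.
  CUDNN_CHECK(cudnnSetTensor4dDescriptor(
      desc, GetCudnnTensorFormat(StorageOrder::NCHW),
      cudnnTypeWrapper<float>::type, 2, 3, 5, 5));
  CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc));
}

}  // namespace caffe2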
+class cudnnDescriptorMeta { + public: + cudnnDescriptorMeta() { + CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_)); + } + cudnnDescriptorMeta(const cudnnDescriptorMeta& src) { + CUDNN_CHECK(cudnnCreateTensorDescriptor(&desc_)); + CHECK_NOTNULL(Descriptor(src.format_, src.type_, src.dims_, nullptr)); + } + ~cudnnDescriptorMeta() { + CUDNN_CHECK(cudnnDestroyTensorDescriptor(desc_)); + } + + inline cudnnTensorDescriptor_t Descriptor( + const cudnnTensorFormat_t format, const cudnnDataType_t type, + const vector& dims, bool* changed) { + if (type_ == type && format_ == format && dims_ == dims) { + // if not changed, simply return the current descriptor. + if (changed) *changed = false; + return desc_; + } + CHECK_EQ(dims.size(), 4) + << "Currently only 4-dimensional descriptor supported."; + format_ = format; + type_ = type; + dims_ = dims; + CUDNN_CHECK(cudnnSetTensor4dDescriptor( + desc_, format, type, dims_[0], + (format == CUDNN_TENSOR_NCHW? dims_[1] : dims_[3]), + (format == CUDNN_TENSOR_NCHW? dims_[2] : dims_[1]), + (format == CUDNN_TENSOR_NCHW? dims_[3] : dims_[2]))); + if (changed) *changed = true; + return desc_; + } + + private: + cudnnTensorDescriptor_t desc_; + cudnnTensorFormat_t format_; + cudnnDataType_t type_; + vector dims_; + cudnnDescriptorMeta& operator=(const cudnnDescriptorMeta&); +}; + +class CuDNNWrapper { + public: + // The default cuda context constructor. + explicit CuDNNWrapper(CUDAContext* context) + : cuda_context_(context), cudnn_handle_(nullptr) {} + + virtual ~CuDNNWrapper() { + if (cudnn_handle_) { + CUDNN_CHECK(cudnnDestroy(cudnn_handle_)); + } + } + + cudnnHandle_t& cudnn_handle() { + if (!cudnn_handle_) { + CUDNN_CHECK(cudnnCreate(&cudnn_handle_)); + CUDNN_CHECK(cudnnSetStream( + cudnn_handle_, cuda_context_->cuda_stream())); + } + return cudnn_handle_; + } + + void cudnnSetNumTensorDescriptors(int n) { + cudnn_tensor_descriptors_.resize(n); + } + + template + inline cudnnTensorDescriptor_t cudnnGetTensor4dDesc( + const int index, const cudnnTensorFormat_t cudnn_format, + const vector& dims, bool* changed) { + return cudnn_tensor_descriptors_.at(index).Descriptor( + cudnn_format, cudnnTypeWrapper::type, dims, changed); + } + + protected: + // Pointer to an external cuda context that the cudnn wrapper will use. 
+ CUDAContext* cuda_context_; + cudnnHandle_t cudnn_handle_; + std::vector cudnn_tensor_descriptors_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_CORE_COMMON_CUDNN_H_ diff --git a/caffe2/core/common_gpu.cc b/caffe2/core/common_gpu.cc new file mode 100644 index 00000000000..bddffc6a434 --- /dev/null +++ b/caffe2/core/common_gpu.cc @@ -0,0 +1,113 @@ +#include + +#include "caffe2/core/common_gpu.h" + +namespace caffe2 { + +namespace { +int gDefaultGPUID = 0; +} + +void SetDefaultGPUID(const int deviceid) { gDefaultGPUID = deviceid; } +int GetDefaultGPUID() { return gDefaultGPUID; } + +void DeviceQuery(const int device) { + cudaDeviceProp prop; + CUDA_CHECK(cudaGetDeviceProperties(&prop, device)); + std::stringstream ss; + ss << std::endl; + ss << "Device id: " << device << std::endl; + ss << "Major revision number: " << prop.major << std::endl; + ss << "Minor revision number: " << prop.minor << std::endl; + ss << "Name: " << prop.name << std::endl; + ss << "Total global memory: " << prop.totalGlobalMem << std::endl; + ss << "Total shared memory per block: " << prop.sharedMemPerBlock + << std::endl; + ss << "Total registers per block: " << prop.regsPerBlock << std::endl; + ss << "Warp size: " << prop.warpSize << std::endl; + ss << "Maximum memory pitch: " << prop.memPitch << std::endl; + ss << "Maximum threads per block: " << prop.maxThreadsPerBlock + << std::endl; + ss << "Maximum dimension of block: " + << prop.maxThreadsDim[0] << ", " << prop.maxThreadsDim[1] << ", " + << prop.maxThreadsDim[2] << std::endl; + ss << "Maximum dimension of grid: " + << prop.maxGridSize[0] << ", " << prop.maxGridSize[1] << ", " + << prop.maxGridSize[2] << std::endl; + ss << "Clock rate: " << prop.clockRate << std::endl; + ss << "Total constant memory: " << prop.totalConstMem << std::endl; + ss << "Texture alignment: " << prop.textureAlignment << std::endl; + ss << "Concurrent copy and execution: " + << (prop.deviceOverlap ? "Yes" : "No") << std::endl; + ss << "Number of multiprocessors: " << prop.multiProcessorCount + << std::endl; + ss << "Kernel execution timeout: " + << (prop.kernelExecTimeoutEnabled ? 
"Yes" : "No") << std::endl; + LOG(INFO) << ss.str(); + return; +} + +namespace internal { + +const char* cublasGetErrorString(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; +#if CUDA_VERSION >= 6000 + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; +#if CUDA_VERSION >= 6050 + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; +#endif // CUDA_VERSION >= 6050 +#endif // CUDA_VERSION >= 6000 + } +} + +const char* curandGetErrorString(curandStatus_t error) { + switch (error) { + case CURAND_STATUS_SUCCESS: + return "CURAND_STATUS_SUCCESS"; + case CURAND_STATUS_VERSION_MISMATCH: + return "CURAND_STATUS_VERSION_MISMATCH"; + case CURAND_STATUS_NOT_INITIALIZED: + return "CURAND_STATUS_NOT_INITIALIZED"; + case CURAND_STATUS_ALLOCATION_FAILED: + return "CURAND_STATUS_ALLOCATION_FAILED"; + case CURAND_STATUS_TYPE_ERROR: + return "CURAND_STATUS_TYPE_ERROR"; + case CURAND_STATUS_OUT_OF_RANGE: + return "CURAND_STATUS_OUT_OF_RANGE"; + case CURAND_STATUS_LENGTH_NOT_MULTIPLE: + return "CURAND_STATUS_LENGTH_NOT_MULTIPLE"; + case CURAND_STATUS_DOUBLE_PRECISION_REQUIRED: + return "CURAND_STATUS_DOUBLE_PRECISION_REQUIRED"; + case CURAND_STATUS_LAUNCH_FAILURE: + return "CURAND_STATUS_LAUNCH_FAILURE"; + case CURAND_STATUS_PREEXISTING_FAILURE: + return "CURAND_STATUS_PREEXISTING_FAILURE"; + case CURAND_STATUS_INITIALIZATION_FAILED: + return "CURAND_STATUS_INITIALIZATION_FAILED"; + case CURAND_STATUS_ARCH_MISMATCH: + return "CURAND_STATUS_ARCH_MISMATCH"; + case CURAND_STATUS_INTERNAL_ERROR: + return "CURAND_STATUS_INTERNAL_ERROR"; + } +} + +} // namespace internal +} // namespace caffe2 diff --git a/caffe2/core/common_gpu.h b/caffe2/core/common_gpu.h new file mode 100644 index 00000000000..459bcdbf056 --- /dev/null +++ b/caffe2/core/common_gpu.h @@ -0,0 +1,68 @@ +#ifndef CAFFE2_CORE_COMMON_GPU_H_ +#define CAFFE2_CORE_COMMON_GPU_H_ + +#include +#include +#include +#include +#include // cuda driver types +// #include +// #include + +#include "glog/logging.h" +#include "caffe2/core/common.h" + +namespace caffe2 { + +// Sets and gets the default GPU id. If the function is not called, we will use +// GPU 0 ast he default gpu id. If there is an operator that says it runs on the +// GPU but did not specify which GPU, this default gpuid is going to be used. +void SetDefaultGPUID(const int deviceid); +int GetDefaultGPUID(); +void DeviceQuery(const int deviceid); + +namespace internal { +const char* cublasGetErrorString(cublasStatus_t error); +const char* curandGetErrorString(curandStatus_t error); +} // namespace internal + +// CUDA: various checks for different function calls. 
+#define CUDA_CHECK(condition) \ + do { \ + cudaError_t error = condition; \ + CHECK_EQ(error, cudaSuccess) \ + << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ + << cudaGetErrorString(error); \ + } while (0) + +#define CUBLAS_CHECK(condition) \ + do { \ + cublasStatus_t status = condition; \ + CHECK_EQ(status, CUBLAS_STATUS_SUCCESS) \ + << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ + << ::caffe2::internal::cublasGetErrorString(status); \ + } while (0) + +#define CURAND_CHECK(condition) \ + do { \ + curandStatus_t status = condition; \ + CHECK_EQ(status, CURAND_STATUS_SUCCESS) \ + << "Error at: " << __FILE__ << ":" << __LINE__ << ": " \ + << ::caffe2::internal::curandGetErrorString(status); \ + } while (0) + +#define CUDA_1D_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +// TODO(Yangqing): Yuck. Figure out a better way? +const int CAFFE_CUDA_NUM_THREADS = 1024; + +// CUDA: number of blocks for threads. +inline int CAFFE_GET_BLOCKS(const int N) { + return (N + CAFFE_CUDA_NUM_THREADS - 1) / CAFFE_CUDA_NUM_THREADS; +} + +} // namespace caffe2 +#endif // CAFFE2_CORE_COMMON_GPU_H_ diff --git a/caffe2/core/context.h b/caffe2/core/context.h new file mode 100644 index 00000000000..e097ea5faff --- /dev/null +++ b/caffe2/core/context.h @@ -0,0 +1,53 @@ +#ifndef CAFFE2_CORE_CONTEXT_H_ +#define CAFFE2_CORE_CONTEXT_H_ + +#include + +#include "caffe2/proto/caffe2.pb.h" +#include "glog/logging.h" + +namespace caffe2 { + +class CPUContext { + public: + CPUContext() : random_generator_(0) {} + explicit CPUContext(const DeviceOption& device_option) + : random_generator_(device_option.random_seed()) { + DCHECK_EQ(device_option.device_type(), CPU); + } + virtual ~CPUContext() {} + inline void SwitchToDevice() {} + inline bool FinishDeviceComputation() { return true; } + + inline std::mt19937& RandGenerator() { return random_generator_; } + + static void* New(size_t nbytes) { + void* data = new char[nbytes]; + memset(data, 0, nbytes); + return data; + } + static void Delete(void* data) { delete[] static_cast(data); } + + // Two copy functions that deals with cross-device copies. + template + inline void Memcpy(void* dst, const void* src, size_t nbytes); + template + inline void Copy(T* dst, const T* src, int n) { + Memcpy(static_cast(dst), + static_cast(src), + n * sizeof(T)); + } + + protected: + std::mt19937 random_generator_; +}; + +template<> +inline void CPUContext::Memcpy( + void* dst, const void* src, size_t nbytes) { + memcpy(dst, src, nbytes); +} + +} // namespace caffe2 + +#endif // CAFFE2_CORE_CONTEXT_H_ diff --git a/caffe2/core/context_gpu.h b/caffe2/core/context_gpu.h new file mode 100644 index 00000000000..8cefae18a03 --- /dev/null +++ b/caffe2/core/context_gpu.h @@ -0,0 +1,143 @@ +#ifndef CAFFE2_CORE_CONTEXT_GPU_H_ +#define CAFFE2_CORE_CONTEXT_GPU_H_ + +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context.h" +#include "caffe2/core/types.h" +#include "caffe2/proto/caffe2.pb.h" +#include "glog/logging.h" + +namespace caffe2 { + +class CUDAContext { + public: + // The default cuda context constructor. 
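CUDA_1D_KERNEL_LOOP together with CAFFE_GET_BLOCKS and CAFFE_CUDA_NUM_THREADS encodes the usual grid-stride launch pattern. The sketch below shows a kernel written against those helpers; the scaling operation, the function names, and the launch on the default stream are illustrative assumptions, not code from this commit.

// Illustrative sketch, not part of this commit.
#include "caffe2/core/common_gpu.h"

namespace caffe2 {

template <typename T>
__global__ void ScaleKernel(const int n, const T alpha, const T* x, T* y) {
  // Each thread strides over the grid until every index below n is covered.
  CUDA_1D_KERNEL_LOOP(i, n) {
    y[i] = alpha * x[i];
  }
}

void ScaleOnGPU(const int N, const float alpha, const float* x, float* y) {
  ScaleKernel<float><<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS>>>(
      N, alpha, x, y);
  CUDA_CHECK(cudaPeekAtLastError());
}

}  // namespace caffe2

The grid-stride form keeps the kernel body independent of the launch size, which is why CAFFE_GET_BLOCKS can stay a one-line rounding-up division.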
+ CUDAContext() + : cuda_stream_(nullptr), cublas_handle_(nullptr), + random_seed_(1701), curand_generator_(nullptr) { + cuda_gpu_id_ = GetDefaultGPUID(); + CUDA_CHECK(cudaSetDevice(cuda_gpu_id_)); + CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); + } + + explicit CUDAContext(const DeviceOption& option) + : cuda_stream_(nullptr), cublas_handle_(nullptr), + random_seed_(option.random_seed()), curand_generator_(nullptr) { + DCHECK_EQ(option.device_type(), CUDA); + cuda_gpu_id_ = option.has_cuda_gpu_id() ? + option.cuda_gpu_id() : GetDefaultGPUID(); + CUDA_CHECK(cudaSetDevice(cuda_gpu_id_)); + CUDA_CHECK(cudaStreamCreate(&cuda_stream_)); + } + + virtual ~CUDAContext() { + if (curand_generator_) { + CURAND_CHECK(curandDestroyGenerator(curand_generator_)); + } + if (cublas_handle_) { + CUBLAS_CHECK(cublasDestroy(cublas_handle_)); + } + if (cuda_stream_) { + CUDA_CHECK(cudaStreamDestroy(cuda_stream_)); + } + } + + inline void SwitchToDevice() { + CUDA_CHECK(cudaSetDevice(cuda_gpu_id_)); + } + + inline bool FinishDeviceComputation() { + cudaError_t error = cudaStreamSynchronize(cuda_stream_); + if (error != cudaSuccess) { + LOG(ERROR) << cudaGetErrorString(error); + return false; + } + error = cudaPeekAtLastError(); + if (error != cudaSuccess) { + LOG(ERROR) << cudaGetErrorString(error); + return false; + } + return true; + } + + int cuda_gpu_id() { return cuda_gpu_id_; } + + inline cudaStream_t& cuda_stream() { return cuda_stream_; } + + cublasHandle_t& cublas_handle() { + if (!cublas_handle_) { + CUBLAS_CHECK(cublasCreate(&cublas_handle_)); + CUBLAS_CHECK(cublasSetPointerMode( + cublas_handle_, CUBLAS_POINTER_MODE_DEVICE)); + CUBLAS_CHECK(cublasSetStream(cublas_handle_, cuda_stream_)); + } + return cublas_handle_; + } + + curandGenerator_t& curand_generator() { + if (!curand_generator_) { + CURAND_CHECK(curandCreateGenerator( + &curand_generator_, CURAND_RNG_PSEUDO_DEFAULT)); + CURAND_CHECK(curandSetPseudoRandomGeneratorSeed( + curand_generator_, random_seed_)); + CURAND_CHECK(curandSetStream(curand_generator_, cuda_stream_)); + } + return curand_generator_; + } + + static void* New(size_t nbytes) { + void* dev_ptr; + CUDA_CHECK(cudaMalloc(&dev_ptr, nbytes)); + CUDA_CHECK(cudaMemset(dev_ptr, 0, nbytes)); + return dev_ptr; + } + + static void Delete(void* data) { + cudaError_t error = cudaFree(data); + // For some reason, in Python runtime we sometimes delete a data pointer + // after the cuda runtime exits - this is odd but is probably caused by + // a static workspace that pycaffe2 uses, and the destruction got entangled + // in some race condition. Anyway, since cuda runtime is exiting anyway, we + // will not need to worry about memory leak, so we basically ignore it. + // This is definitely not ideal but works for now. + if (error != cudaSuccess && error != cudaErrorCudartUnloading) { + LOG(FATAL) << "Error at: " << __FILE__ << ":" << __LINE__ << ": " + << cudaGetErrorString(error); + } + } + + template + inline void Copy(void* dst, const void* src, size_t nbytes) { + CUDA_CHECK(cudaMemcpyAsync( + dst, src, nbytes, cudaMemcpyDefault, cuda_stream_)); + // TODO(Yangqing): do we want to synchronize inside copy? 
+ CUDA_CHECK(cudaStreamSynchronize(cuda_stream_)); + } + + template + inline void Copy(T* dst, const T* src, int n) { + Copy(static_cast(dst), + static_cast(src), + n * sizeof(T)); + } + + protected: + int cuda_gpu_id_; + cudaStream_t cuda_stream_; + cublasHandle_t cublas_handle_; + int random_seed_; + curandGenerator_t curand_generator_; +}; + +// For the CPU context, we also allow a (probably expensive) function +// to copy the data from a cuda context. +template<> +inline void CPUContext::Memcpy( + void* dst, const void* src, size_t nbytes) { + CUDAContext context; + context.Copy(dst, src, nbytes); +} + +} // namespace caffe2 + +#endif // CAFFE2_CORE_CONTEXT_GPU_H_ diff --git a/caffe2/core/context_test.cc b/caffe2/core/context_test.cc new file mode 100644 index 00000000000..100680df031 --- /dev/null +++ b/caffe2/core/context_test.cc @@ -0,0 +1,45 @@ +#include + +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/core/context.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +// This is a test that make sure the random number generator works as expected, +// with a specific seed that generates specific responses. I think it should +// be the same number across platforms since we use mt19937 explicitly. +TEST(CPUContextTest, TestRandomNumberGenerator) { + DeviceOption option; + option.set_random_seed(1701); + CPUContext context(option); + std::uniform_int_distribution dist(0, 100); + /* + // These numbers are manually verified off-line. + EXPECT_EQ(dist(context.RandGenerator()), 46); + EXPECT_EQ(dist(context.RandGenerator()), 4); + EXPECT_EQ(dist(context.RandGenerator()), 94); + EXPECT_EQ(dist(context.RandGenerator()), 26); + EXPECT_EQ(dist(context.RandGenerator()), 67); + */ +} + +TEST(CPUContextTest, TestAllocDealloc) { + float* data = static_cast(CPUContext::New(10 * sizeof(float))); + EXPECT_NE(data, nullptr); + float* dst_data = static_cast(CPUContext::New(10 * sizeof(float))); + EXPECT_NE(dst_data, nullptr); + for (int i = 0; i < 10; ++i) { + data[i] = i; + } + DeviceOption option; + CPUContext context(option); + context.Copy(dst_data, data, 10); + for (int i = 0; i < 10; ++i) { + EXPECT_FLOAT_EQ(dst_data[i], i); + } + CPUContext::Delete(data); + CPUContext::Delete(dst_data); +} + +} // namespace caffe2 diff --git a/caffe2/core/db.cc b/caffe2/core/db.cc new file mode 100644 index 00000000000..d282ded8e92 --- /dev/null +++ b/caffe2/core/db.cc @@ -0,0 +1,9 @@ +#include "caffe2/core/db.h" + +namespace caffe2 { +namespace db { + +DEFINE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode); + +} // namespacd db +} // namespace caffe2 diff --git a/caffe2/core/db.h b/caffe2/core/db.h new file mode 100644 index 00000000000..e3de4ca65f9 --- /dev/null +++ b/caffe2/core/db.h @@ -0,0 +1,62 @@ +#ifndef CAFFE2_CORE_DB_H_ +#define CAFFE2_CORE_DB_H_ + +#include "caffe2/core/registry.h" + +namespace caffe2 { +namespace db { + +enum Mode { READ, WRITE, NEW }; + +class Cursor { + public: + Cursor() { } + virtual ~Cursor() { } + virtual void SeekToFirst() = 0; + virtual void Next() = 0; + virtual string key() = 0; + virtual string value() = 0; + virtual bool Valid() = 0; + + DISABLE_COPY_AND_ASSIGN(Cursor); +}; + +class Transaction { + public: + Transaction() { } + virtual ~Transaction() { } + virtual void Put(const string& key, const string& value) = 0; + virtual void Commit() = 0; + + DISABLE_COPY_AND_ASSIGN(Transaction); +}; + +class DB { + public: + DB(const string& source, Mode mode) : mode_(mode) { + // This constructor does nothing. 
The actual opening should be done in the + // derived constructors. + } + virtual ~DB() { } + virtual void Close() = 0; + virtual Cursor* NewCursor() = 0; + virtual Transaction* NewTransaction() = 0; + + protected: + Mode mode_; + + DISABLE_COPY_AND_ASSIGN(DB); +}; + +DECLARE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode); +#define REGISTER_CAFFE2_DB(name, ...) \ + REGISTER_CLASS(Caffe2DBRegistry, name, __VA_ARGS__) + +inline DB* CreateDB(const string& db_type, const string& source, Mode mode) { + return Caffe2DBRegistry()->Create(db_type, source, mode); +} + +} // namespace db +} // namespace caffe2 + +#endif // CAFFE2_CORE_DB_H_ diff --git a/caffe2/core/minidb.cc b/caffe2/core/minidb.cc new file mode 100644 index 00000000000..3577fc92d33 --- /dev/null +++ b/caffe2/core/minidb.cc @@ -0,0 +1,134 @@ +#include +#include + +#include "caffe2/core/db.h" +#include "glog/logging.h" + +namespace caffe2 { +namespace db { + +class MiniDBCursor : public Cursor { + public: + explicit MiniDBCursor(FILE* f, std::mutex* mutex) + : file_(f), lock_(*mutex) {} + ~MiniDBCursor() {} + + void SeekToFirst() override { + fseek(file_, 0, SEEK_SET); + CHECK(!feof(file_)) << "Hmm, empty file?"; + // Read the first item. + valid_ = true; + Next(); + } + + void Next() override { + if (fread(&key_len_, sizeof(int), 1, file_) == 0) { + // Reaching EOF. + valid_ = false; + return; + } + CHECK_EQ(fread(&value_len_, sizeof(int), 1, file_), 1); + CHECK_GT(key_len_, 0); + CHECK_GT(value_len_, 0); + if (key_len_ > key_.size()) { + key_.resize(key_len_); + } + if (value_len_ > value_.size()) { + value_.resize(value_len_); + } + CHECK_EQ(fread(key_.data(), sizeof(char), key_len_, file_), key_len_); + CHECK_EQ(fread(value_.data(), sizeof(char), value_len_, file_), value_len_); + } + + string key() override { + CHECK(valid_) << "Invalid position!"; + return string(key_.data(), key_len_); + } + + string value() override { + CHECK(valid_) << "Invalid position!"; + return string(value_.data(), value_len_); + } + + bool Valid() override { return valid_; } + + private: + FILE* file_; + std::lock_guard lock_; + bool valid_; + int key_len_; + vector key_; + int value_len_; + vector value_; +}; + +class MiniDBTransaction : public Transaction { + public: + explicit MiniDBTransaction(FILE* f, std::mutex* mutex) + : file_(f), lock_(*mutex) {} + ~MiniDBTransaction() { Commit(); } + + void Put(const string& key, const string& value) override { + int key_len = key.size(); + int value_len = value.size(); + CHECK_EQ(fwrite(&key_len, sizeof(int), 1, file_), 1); + CHECK_EQ(fwrite(&value_len, sizeof(int), 1, file_), 1); + CHECK_EQ(fwrite(key.c_str(), sizeof(char), key_len, file_), key_len); + CHECK_EQ(fwrite(value.c_str(), sizeof(char), value_len, file_), value_len); + } + + void Commit() override { + CHECK_EQ(fflush(file_), 0); + } + + private: + FILE* file_; + std::lock_guard lock_; + + DISABLE_COPY_AND_ASSIGN(MiniDBTransaction); +}; + +class MiniDB : public DB { + public: + MiniDB(const string& source, Mode mode) : DB(source, mode), file_(nullptr) { + switch (mode) { + case NEW: + file_ = fopen(source.c_str(), "wb"); + break; + case WRITE: + file_ = fopen(source.c_str(), "ab"); + fseek(file_, 0, SEEK_END); + break; + case READ: + file_ = fopen(source.c_str(), "rb"); + break; + } + CHECK(file_) << "Cannot open file: " << source; + LOG(INFO) << "Opened MiniDB " << source; + } + ~MiniDB() { Close(); } + + void Close() override { fclose(file_); } + + Cursor* NewCursor() override { + CHECK_EQ(this->mode_, READ); + return new MiniDBCursor(file_, 
&file_access_mutex_); + } + + Transaction* NewTransaction() override { + CHECK(this->mode_ == NEW || this->mode_ == WRITE); + return new MiniDBTransaction(file_, &file_access_mutex_); + } + + private: + FILE* file_; + // access mutex makes sure we don't have multiple cursors/transactions + // reading the same file. + std::mutex file_access_mutex_; +}; + +REGISTER_CAFFE2_DB(MiniDB, MiniDB); +REGISTER_CAFFE2_DB(minidb, MiniDB); + +} // namespace db +} // namespace caffe2 diff --git a/caffe2/core/net.cc b/caffe2/core/net.cc new file mode 100644 index 00000000000..127d8ee2b8e --- /dev/null +++ b/caffe2/core/net.cc @@ -0,0 +1,191 @@ +#include "caffe2/core/net.h" +#include "caffe2/core/operator.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +NetBase* CreateNet(const NetDef& net_def, Workspace* ws) { + if (!net_def.has_net_type() || net_def.net_type() == "simple") { + VLOG(1) << "Creating simple net."; + return new SimpleNet(net_def, ws); + } else if (net_def.net_type() == "parallel") { + VLOG(1) << "Creating parallel net."; + return new ParallelNet(net_def, ws); + } else { + LOG(ERROR) << "Unknown net type: " << net_def.net_type(); + return nullptr; + } + // Just to suppress compiler warning + return nullptr; +} + +SimpleNet::SimpleNet(const NetDef& net_def, Workspace* ws) + : NetBase(net_def, ws) { + // Initialize the operators + for (const OperatorDef& operator_def : net_def.operators()) { + VLOG(1) << "Creating operator " << operator_def.name() + << ":" << operator_def.type(); + if (!operator_def.has_device_option()) { + operators_.emplace_back( + CreateOperator(operator_def, net_def.device_option(), ws)); + } else { + operators_.emplace_back(CreateOperator(operator_def, ws)); + } + } +} + +bool SimpleNet::Verify() { + for (auto& op : operators_) { + VLOG(1) << "Verifying operator " << op->def().name() + << "(" << op->def().type() << ")."; + if (op.get() == nullptr || !op->Verify()) { + return false; + } + } + return true; +} + +bool SimpleNet::Run() { + VLOG(1) << "Running net."; + for (const auto& op : operators_) { + VLOG(1) << "Running operator " << op->def().name() + << "(" << op->def().type() << ")."; + // TODO(Yangqing): convert this sequential run to event-based. + if (!op->Run()) return false; + } + return true; +} + +ParallelNet::ParallelNet(const NetDef& net_def, Workspace* ws) + : NetBase(net_def, ws), operator_nodes_(net_def.operators_size()) { + // Blob creator allows us to track which operator created which blob. + std::map blob_creator; + // Initialize the operators + for (int idx = 0; idx < net_def.operators_size(); ++idx) { + const OperatorDef& op_def = net_def.operators(idx); + VLOG(1) << "Creating operator #" << idx << ": " + << op_def.name() << ":" << op_def.type(); + if (!op_def.has_device_option()) { + operator_nodes_[idx].operator_.reset( + CreateOperator(op_def, net_def.device_option(), ws)); + } else { + operator_nodes_[idx].operator_.reset(CreateOperator(op_def, ws)); + } + // Check the inputs, and set up parents if necessary. + for (const string& input : op_def.inputs()) { + if (blob_creator.count(input) == 0) { + VLOG(1) << "Input " << input << " not produced by this net. 
" + << "Assuming it is pre-existing."; + } else { + int parent = blob_creator[input]; + VLOG(1) << "op dependency: " << parent << "->" << idx; + operator_nodes_[idx].parents_.push_back(parent); + operator_nodes_[parent].children_.push_back(idx); + } + } + for (const string& output : op_def.outputs()) { + if (blob_creator.count(output) != 0) { + LOG(WARNING) << "Output " << output << " produced again. " + << "Such operation is not strictly tested. " + << "Use at your own risk."; + } + blob_creator[output] = idx; + } + } + // Figure out the initial frontier - this is the one we will feed into the job + // queue to start a run. + for (int idx = 0; idx < operator_nodes_.size(); ++idx) { + if (operator_nodes_[idx].parents_.size() == 0) { + initial_frontier_.push_back(idx); + } + } + // Finally, start the workers. + CHECK_GT(net_def.num_workers(), 0) << "Must specify the number of workers."; + for (int i = 0; i < net_def.num_workers(); ++i) { + VLOG(1) << "Start worker #" << i; + workers_.push_back(std::thread(&ParallelNet::WorkerFunction, this)); + } +} + +ParallelNet::~ParallelNet() { + // Safely join all the workers before exiting. + job_queue_.NoMoreJobs(); + VLOG(1) << "Joining workers."; + for (auto& worker : workers_) { + worker.join(); + } +} + +bool ParallelNet::Verify() { + for (auto& op_node : operator_nodes_) { + auto& op = op_node.operator_; + VLOG(1) << "Verifying operator " << op->def().name() + << "(" << op->def().type() << ")."; + if (op.get() == nullptr || !op->Verify()) { + return false; + } + } + return true; +} + +bool ParallelNet::Run() { + VLOG(1) << "Running parallel net."; + // First, set up job queue. + remaining_ops_ = operator_nodes_.size(); + success_ = true; + // TODO(jiayq): Start all worker threads. + // Initialize the runtime parent count. + for (auto& node : operator_nodes_) { + node.runtime_parent_count_ = node.parents_.size(); + } + // Kickstart the job queue. + for (auto& value : initial_frontier_) { + job_queue_.Push(value); + } + std::unique_lock mutex_lock(remaining_ops_mutex_); + while (remaining_ops_ > 0) { + VLOG(2) << "Remaining ops to run: " << remaining_ops_; + cv_.wait(mutex_lock); + } + VLOG(2) << "All ops finished running."; + // If the above while loop finished, we know that the current run finished. + return success_; +} + +void ParallelNet::WorkerFunction() { + // WorkerFunctions() is an infinite loop until there are no more jobs to run. + while (true) { + int idx; + // If there is no more jobs - meaning that the ParallelNet is destructing - + // we will exit safely. + if (!job_queue_.Pop(&idx)) { + return; + } + VLOG(1) << "Running operator #" << idx << " " + << operator_nodes_[idx].operator_->def().name() + << "(" << operator_nodes_[idx].operator_->def().type() << ")."; + bool this_success = operator_nodes_[idx].operator_->Run(); + for (int child : operator_nodes_[idx].children_) { + int count = --operator_nodes_[child].runtime_parent_count_; + // The count should never be smaller than zero. + DCHECK_GE(count, 0) + << "Found runtime parent count smaller than zero for " + << "operator node " + << operator_nodes_[child].operator_->def().name() + << "(" << operator_nodes_[child].operator_->def().type() << ")."; + if (count == 0) { + VLOG(2) << "Pushing operator #" << child << " to queue."; + job_queue_.Push(child); + } + } + // Notify that the processed op is incremented by one. 
+ std::unique_lock mutex_lock(remaining_ops_mutex_); + --remaining_ops_; + success_ &= this_success; + DCHECK_GE(remaining_ops_, 0); + cv_.notify_one(); + VLOG(2) << "Finished executing operator #" << idx; + } +} + +} // namespace caffe2 diff --git a/caffe2/core/net.h b/caffe2/core/net.h new file mode 100644 index 00000000000..d06410919bd --- /dev/null +++ b/caffe2/core/net.h @@ -0,0 +1,90 @@ +#ifndef CAFFE2_CORE_NET_H_ +#define CAFFE2_CORE_NET_H_ + +#include +#include +#include +#include // NOLINT +#include +#include + +#include "caffe2/core/blob.h" +#include "caffe2/core/common.h" +#include "caffe2/core/registry.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" +#include "caffe2/utils/simple_queue.h" + +namespace caffe2 { + +class OperatorBase; + +// Net is a thin struct that owns all the operators together with the operator +// contexts. +class NetBase { + public: + NetBase(const NetDef& net_def, Workspace* ws) {} + virtual ~NetBase() {} + virtual bool Verify() = 0; + virtual bool Run() = 0; + + DISABLE_COPY_AND_ASSIGN(NetBase); +}; + +// Essentially, we won't expect too many Net instances, so we will simply +// have a function that produces different net implementations. If needed we can +// switch to a registration pattern later. +NetBase* CreateNet(const NetDef& net_def, Workspace* ws); + +// This is the very basic structure you need to run a network - all it +// does is simply to run everything in sequence. If you want more fancy control +// such as a DAG-like execution, check out other better net implementations. +class SimpleNet final : public NetBase { + public: + SimpleNet(const NetDef& net_def, Workspace* ws); + bool Verify() override; + bool Run() override; + + protected: + vector > operators_; + + DISABLE_COPY_AND_ASSIGN(SimpleNet); +}; + +namespace internal { +struct OperatorNode { + unique_ptr operator_; + vector children_; + vector parents_; + std::atomic runtime_parent_count_; +}; +} + +class ParallelNet final : public NetBase { + public: + ParallelNet(const NetDef& net_def, Workspace* ws); + ~ParallelNet(); + bool Verify() override; + bool Run() override; + // WorkerFunction() is a function wrapper to allow us to run worker threads. + // It checks out one ready-to-run operator from the job queue, runs it, + // notifies all its children, and for any children that is ready, enqueues + // it to the job queue. + void WorkerFunction(); + + protected: + vector operator_nodes_; + vector initial_frontier_; + SimpleQueue job_queue_; + std::vector workers_; + int remaining_ops_; + bool success_; + std::mutex remaining_ops_mutex_; + std::condition_variable cv_; + + DISABLE_COPY_AND_ASSIGN(ParallelNet); +}; + +} // namespace caffe2 + +#endif // CAFFE2_CORE_NET_H_ diff --git a/caffe2/core/operator.cc b/caffe2/core/operator.cc new file mode 100644 index 00000000000..a9e2b45a73d --- /dev/null +++ b/caffe2/core/operator.cc @@ -0,0 +1,121 @@ +#include +#include + +#include "caffe2/core/net.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +// TODO(Yangqing): move all the checks to a less fatal check mechanism. 
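SimpleNet and ParallelNet share the same three-call surface: construct through CreateNet, Verify, then Run. A hedged sketch of driving that interface from an already-populated NetDef follows; filling in the NetDef itself (for instance from a text-format proto, as the tests later in this commit do) and creating the net's external input blobs are left out.

// Illustrative sketch, not part of this commit.
#include <memory>

#include "caffe2/core/net.h"
#include "caffe2/core/workspace.h"

bool RunOnce(const caffe2::NetDef& net_def) {
  caffe2::Workspace ws;
  // Blobs consumed but not produced by the net are expected to already exist
  // in the workspace; creating them is omitted here.
  std::unique_ptr<caffe2::NetBase> net(caffe2::CreateNet(net_def, &ws));
  if (net == nullptr || !net->Verify()) {
    return false;
  }
  return net->Run();
}

With net_type set to "parallel" and num_workers specified in the NetDef, the same call exercises the worker-thread scheduler instead of the sequential runner.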
+OperatorBase::OperatorBase(const OperatorDef& operator_def, Workspace* ws) + : operator_def_(operator_def) { + for (auto& arg : operator_def.args()) { + CHECK_GT(arg.name().size(), 0) << "Argument must have a name."; + CHECK_EQ(arg_map_.count(arg.name()), 0) << "Duplicated argument name."; + arg_map_[arg.name()] = &arg; + } + for (const string& input_str : operator_def_.inputs()) { + inputs_.push_back(CHECK_NOTNULL(ws->GetBlob(input_str))); + } + for (const string& output_str : operator_def_.outputs()) { + outputs_.push_back(CHECK_NOTNULL(ws->CreateBlob(output_str))); + } +} + +// Parameter getters. You can use these to get the arguments that you want. +// We need to deal with the fact that we cannot really template into +// protocol buffers... yuck. +#define INSTANTIATE_GET_SINGLE_ARGUMENT(dtype, fieldname) \ +template <> \ +dtype OperatorBase::GetSingleArgument( \ + const string& name, const dtype& default_value) { \ + if (arg_map_.count(name) == 0) { \ + DVLOG(1) << "Using default parameter value " << default_value; \ + return default_value; \ + } \ + CHECK(arg_map_[name]->has_##fieldname()) \ + << "Argument does not have the right field: expected " \ + << #fieldname; \ + return arg_map_[name]->fieldname(); \ +} + +INSTANTIATE_GET_SINGLE_ARGUMENT(float, f) +INSTANTIATE_GET_SINGLE_ARGUMENT(int, i) +INSTANTIATE_GET_SINGLE_ARGUMENT(string, s) +// Undefine the argument just to be safe. +#undef INSTANTIATE_GET_SINGLE_ARGUMENT + +#define INSTANTIATE_GET_REPEATED_ARGUMENT(dtype, fieldname) \ +template <> \ +vector OperatorBase::GetRepeatedArgument( \ + const string& name) { \ + if (arg_map_.count(name) == 0) { \ + return vector(); \ + } \ + vector values; \ + CHECK(arg_map_[name]->fieldname##_size()) \ + << "Argument does not have the right field: expected " \ + << #fieldname; \ + for (const auto& v : arg_map_[name]->fieldname()) values.push_back(v); \ + return values; \ +} + +INSTANTIATE_GET_REPEATED_ARGUMENT(float, floats) +INSTANTIATE_GET_REPEATED_ARGUMENT(int, ints) +INSTANTIATE_GET_REPEATED_ARGUMENT(string, strings) +#undef INSTANTIATE_GET_REPEATED_ARGUMENT + +bool OperatorBase::Verify() { + // Check Blob counts. + if (operator_def_.inputs_size() < MinInput() || + operator_def_.inputs_size() > MaxInput()) { + LOG(ERROR) << "Input size " << operator_def_.inputs_size() + << " not in range [min=" << MinInput() << ", max=" + << MaxInput() << "]."; + LOG(ERROR) << "Error at operator " << operator_def_.name() << ":" + << operator_def_.type(); + return false; + } + if (operator_def_.outputs_size() < MinOutput() || + operator_def_.outputs_size() > MaxOutput()) { + LOG(ERROR) << "Output size " << operator_def_.outputs_size() + << " not in range [min=" << MinOutput() << ", max=" + << MaxOutput() << "]."; + LOG(ERROR) << "Error at operator " << operator_def_.name() << ":" + << operator_def_.type(); + return false; + } + return true; +} + +OperatorBase* CreateOperator(const OperatorDef& operator_def, + const DeviceOption& device_option, + Workspace* ws) { + const string& key = operator_def.type(); + switch (operator_def.device_option().device_type()) { + case CPU: + VLOG(1) << "Creating CPU operator " << key; + return CPUOperatorRegistry()->Create(key, operator_def, ws); + case CUDA: + VLOG(1) << "Creating CUDA operator " << key; + // In Cuda, if we have cudnn, we will prefer to use cudnn first. 
+ if (CUDNNOperatorRegistry()->Has(key)) { + VLOG(1) << "Using CuDNN implementation."; + return CUDNNOperatorRegistry()->Create(key, operator_def, ws); + } + return CUDAOperatorRegistry()->Create(key, operator_def, ws); + } + // Just to suppress some compiler error + return nullptr; +} + +DEFINE_REGISTRY(CPUOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); +DEFINE_REGISTRY(CUDAOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); +DEFINE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); + +} // namespace caffe2 diff --git a/caffe2/core/operator.h b/caffe2/core/operator.h new file mode 100644 index 00000000000..9d60c63d38f --- /dev/null +++ b/caffe2/core/operator.h @@ -0,0 +1,233 @@ +#ifndef CAFFE2_CORE_OPERATOR_H_ +#define CAFFE2_CORE_OPERATOR_H_ + +#include +#include +#include +#include + +#include "caffe2/core/blob.h" +#include "caffe2/core/common.h" +#include "caffe2/core/net.h" +#include "caffe2/core/registry.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +class OperatorBase { + public: + // The constructor of the operator. Note that you should not do any + // custom initializations in the constructor; instead, do those in the + // SetUp() function. + explicit OperatorBase(const OperatorDef& operator_def, Workspace* ws); + virtual ~OperatorBase() {} + + // Verify return true if an operator is set up correctly. This cannot be + // implemented in the constructor, because there will be calls to overridden + // functions. + virtual bool Verify(); + + // Parameter getters. You can use these to get the arguments that you want. + bool HasArgument(const string& name) { return (arg_map_.count(name) > 0); } + template + + // Functions that deal with arguments. Basically, this allows us to map an + // argument mane to a specific type of argument that we are trying to access. + T GetSingleArgument(const string& name, const T& default_value); + template + vector GetRepeatedArgument(const string& name); + + template + MessageType GetAnyMessageArgument(const string& name) { + CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name; + MessageType message; + CHECK(message.ParseFromString(arg_map_[name]->s())) + << "Faild to parse content from the string"; + return message; + } + template + vector GetAnyRepeatedMessageArgument(const string& name) { + CHECK(arg_map_.count(name)) << "Cannot find parameter named " << name; + vector messages(arg_map_[name]->strings_size()); + for (int i = 0; i < messages.size(); ++i) { + CHECK(messages[i].ParseFromString(arg_map_[name]->strings(i))) + << "Faild to parse content from the string"; + } + return messages; + } + + // Get the inputs and outputs as specific types. + template + inline const T& Input(int idx) { + DCHECK_LT(idx, inputs_.size()); + return inputs_.at(idx)->template Get(); + } + template + inline T* Output(int idx) { + DCHECK_LT(idx, outputs_.size()); + return outputs_.at(idx)->template GetMutable(); + } + template + inline bool InputIsType(int idx) { + return inputs_.at(idx)->template IsType(); + } + inline int InputSize() { return inputs_.size(); } + inline int OutputSize() { return outputs_.size(); } + inline const vector& Inputs() const { return inputs_; } + inline const vector& Outputs() { return outputs_; } + + virtual bool Run() { NOT_IMPLEMENTED; return false; } + + inline const OperatorDef& def() { return operator_def_; } + + protected: + // Do not manually override these functions. 
Instead, use INPUT_OUTPUT_STATS + // macro below. + virtual int MinInput() { return 0; } + virtual int MaxInput() { return INT_MAX; } + virtual int MinOutput() { return 0; } + virtual int MaxOutput() { return INT_MAX; } + + private: + CaffeMap arg_map_; + OperatorDef operator_def_; + vector inputs_; + vector outputs_; + + DISABLE_COPY_AND_ASSIGN(OperatorBase); +}; + +// If your operator does not need any specialized contructor or destructor, +// you can simply use this to save two lines of code. +#define USE_SIMPLE_BASE_CTOR_DTOR(name) \ + name(const OperatorDef& operator_def, Workspace* ws) \ + : OperatorBase(operator_def, ws) {} \ + virtual ~name() {} + +// INPUT_OUTPUT_STATS gives the statistics of the input and output that are +// legal. If the max input/output is not limited, you can specify INT_MAX. +// TODO(Yangqing): If necessary, add ability to specify that n_input = n_output. +#define INPUT_OUTPUT_STATS(min_input, max_input, min_output, max_output) \ + protected: \ + int MinInput() override { return min_input; } \ + int MaxInput() override { return max_input; } \ + int MinOutput() override { return min_output; } \ + int MaxOutput() override { return max_output; } + +// INPUT_TAGS and OUTPUT_TAGS are optional features to name the indices of the +// operator's inputs and outputs, in order to avoid confusion. For example, for +// a fully convolution layer that has input, weight and bias, you can define its +// input tags as: +// INPUT_TAGS(INPUT, WEIGHT, BIAS); +// And in the code, instead of doing +// auto& weight = Input(1); +// you can now do +// auto& weight = Input(WEIGHT); +// to make it more clear. +#define INPUT_TAGS(first_input, ...) \ + enum _InputTags { first_input = 0, __VA_ARGS__ } +#define OUTPUT_TAGS(first_input, ...) \ + enum _OutputTags { first_input = 0, __VA_ARGS__ } + + +// Operator is the class that you usually want to derive, if your operator will +// run on different devices. You should then implement the RunOnDevice() +// function. +template +class Operator : public OperatorBase { + public: + // The constructor of the operator. Note that you should not do any + // custom initializations in the constructor; instead, do those in the + // SetUp() function. + explicit Operator(const OperatorDef& operator_def, Workspace* ws) + : OperatorBase(operator_def, ws), + device_context_(operator_def.device_option()) { + // In the constructor, we switch to the device so that the child class + // constructors will run on that device. + device_context_.SwitchToDevice(); + } + virtual ~Operator() {} + + inline const Tensor& Input(int idx) { + return OperatorBase::template Input >(idx); } + inline Tensor* Output(int idx) { + return OperatorBase::template Output >(idx); + } + + // The run function of Operator switches to the device, and then carries out + // the actual computation with RunOnDevice(). You should implement RunOnDevice + // instead of Run(). 
+ bool Run() final { + device_context_.SwitchToDevice(); + bool result = RunOnDevice(); + result &= device_context_.FinishDeviceComputation(); + return result; + } + + virtual bool RunOnDevice() = 0; + + protected: + DeviceContext device_context_; + DISABLE_COPY_AND_ASSIGN(Operator); +}; + +#define USE_OPERATOR_BASE_FUNCTIONS \ + using OperatorBase::GetSingleArgument; \ + using OperatorBase::GetRepeatedArgument; \ + using OperatorBase::def; \ + using OperatorBase::InputIsType; \ + using OperatorBase::InputSize; \ + using OperatorBase::OutputSize; \ + using Operator::device_context_; \ + using Operator::Input; \ + using Operator::Output + +#define USE_SIMPLE_CTOR_DTOR(name) \ + name(const OperatorDef& operator_def, Workspace* ws) \ + : Operator(operator_def, ws) {} \ + virtual ~name() {} + +// The operator registry. Since we are not expecting a great number of devices, +// we will simply have an if-then type command and allocate the actual +// generation to device-specific registerers. +// Note that although we have CUDA and CUDNN here, the registerers themselves do +// not depend on specific cuda or cudnn libraries. This means that we will be +// able to compile it even when there is no cuda available - we simply do not +// link any cuda or cudnn operators. +DECLARE_REGISTRY(CPUOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); +#define REGISTER_CPU_OPERATOR_CREATOR(key, ...) \ + REGISTER_CREATOR(CPUOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_CPU_OPERATOR(name, ...) \ + REGISTER_CLASS(CPUOperatorRegistry, name, __VA_ARGS__) + +DECLARE_REGISTRY(CUDAOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); +#define REGISTER_CUDA_OPERATOR_CREATOR(key, ...) \ + REGISTER_CREATOR(CUDAOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_CUDA_OPERATOR(name, ...) \ + REGISTER_CLASS(CUDAOperatorRegistry, name, __VA_ARGS__) + +DECLARE_REGISTRY(CUDNNOperatorRegistry, OperatorBase, + const OperatorDef&, Workspace*); +#define REGISTER_CUDNN_OPERATOR_CREATOR(key, ...) \ + REGISTER_CREATOR(CUDNNOperatorRegistry, key, __VA_ARGS__) +#define REGISTER_CUDNN_OPERATOR(name, ...) \ + REGISTER_CLASS(CUDNNOperatorRegistry, name, __VA_ARGS__) + +// Creates an operator with the given operator definition and device option. +OperatorBase* CreateOperator(const OperatorDef& operator_def, + const DeviceOption& device_option, + Workspace* ws); + +// Create an operator with the given operator definition, and the device +// option that is specified in the operator definition. 
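Everything an operator needs in order to become creatable by name is now in place: derive from OperatorBase (or from Operator for device-templated code), declare the legal input/output counts, and register the class. The sketch below mirrors the JustTest and Sleep operators used in the unit tests later in this commit; the operator name and its "tag" argument are illustrative assumptions, not framework API.

// Illustrative sketch, not part of this commit.
#include "caffe2/core/operator.h"

namespace caffe2 {

class LogNameOp final : public OperatorBase {
 public:
  LogNameOp(const OperatorDef& operator_def, Workspace* ws)
      : OperatorBase(operator_def, ws),
        tag_(OperatorBase::GetSingleArgument<string>("tag", "default")) {}
  bool Run() override {
    LOG(INFO) << "Running " << def().name() << " with tag " << tag_;
    return true;
  }

 private:
  string tag_;
  // Accept any number of inputs and produce no outputs.
  INPUT_OUTPUT_STATS(0, INT_MAX, 0, 0);
  DISABLE_COPY_AND_ASSIGN(LogNameOp);
};
REGISTER_CPU_OPERATOR(LogName, LogNameOp);

}  // namespace caffe2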
+inline OperatorBase* CreateOperator(const OperatorDef& operator_def, + Workspace* ws) { + return CreateOperator(operator_def, operator_def.device_option(), ws); +} + +} // namespace caffe2 + +#endif // CAFFE2_CORE_OPERATOR_H_ diff --git a/caffe2/core/operator_test.cc b/caffe2/core/operator_test.cc new file mode 100644 index 00000000000..764dec41fa4 --- /dev/null +++ b/caffe2/core/operator_test.cc @@ -0,0 +1,213 @@ +#include + +#include "caffe2/core/net.h" +#include "caffe2/core/operator.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +class JustTest : public OperatorBase { + public: + explicit JustTest(const OperatorDef& op_def, Workspace* ws) + : OperatorBase(op_def, ws) {} + bool Run() override { return true; } + INPUT_OUTPUT_STATS(0, 1, 0, 1); +}; +REGISTER_CPU_OPERATOR(JustTest, JustTest); +REGISTER_CUDA_OPERATOR(JustTest, JustTest); + + +TEST(OperatorTest, RegistryWorks) { + OperatorDef op_def; + Workspace ws; + op_def.set_type("JustTest"); + EXPECT_NE(nullptr, CreateOperator(op_def, &ws)); + op_def.mutable_device_option()->set_device_type(CUDA); + EXPECT_NE(nullptr, CreateOperator(op_def, &ws)); + + CPUOperatorRegistry()->TEST_PrintRegisteredNames(); +} + +TEST(OperatorDeathTest, CannotUseUninitializedBlob) { + Workspace ws; + OperatorDef op_def; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("output"); + EXPECT_DEATH(CreateOperator(op_def, &ws), "Check failed"); +} + +TEST(OperatorTest, TestParameterAccess) { + OperatorDef op_def; + Workspace ws; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("output"); + { + Argument* arg = op_def.add_args(); + arg->set_name("arg0"); + arg->set_f(0.1); + } + { + Argument* arg = op_def.add_args(); + arg->set_name("arg1"); + arg->add_ints(1); + arg->add_ints(2); + } + { + Argument* arg = op_def.add_args(); + arg->set_name("arg2"); + arg->set_s("argstring"); + } + EXPECT_NE(ws.CreateBlob("input"), nullptr); + OperatorBase op(op_def, &ws); + EXPECT_TRUE(op.Verify()); + EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); + vector i = op.GetRepeatedArgument("arg1"); + EXPECT_EQ(i.size(), 2); + EXPECT_EQ(i[0], 1); + EXPECT_EQ(i[1], 2); + EXPECT_EQ(op.GetSingleArgument("arg2", "default"), "argstring"); +} + + +TEST(OperatorDeathTest, CannotAccessParameterWithWrongType) { + OperatorDef op_def; + Workspace ws; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("output"); + { + Argument* arg = op_def.add_args(); + arg->set_name("arg0"); + arg->set_f(0.1); + } + EXPECT_NE(ws.CreateBlob("input"), nullptr); + OperatorBase op(op_def, &ws); + EXPECT_TRUE(op.Verify()); + EXPECT_FLOAT_EQ(op.GetSingleArgument("arg0", 0.0), 0.1); + EXPECT_DEATH(op.GetSingleArgument("arg0", 0), + "Argument does not have the right field: expected i"); +} + +TEST(OperatorDeathTest, CannotAccessRepeatedParameterWithWrongType) { + OperatorDef op_def; + Workspace ws; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("output"); + { + Argument* arg = op_def.add_args(); + arg->set_name("arg0"); + arg->add_floats(0.1); + } + EXPECT_NE(ws.CreateBlob("input"), nullptr); + OperatorBase op(op_def, &ws); + EXPECT_TRUE(op.Verify()); + auto args = op.GetRepeatedArgument("arg0"); + EXPECT_EQ(args.size(), 1); + EXPECT_FLOAT_EQ(args[0], 0.1); + EXPECT_DEATH(op.GetRepeatedArgument("arg0"), + "Argument does not have the right field: 
expected ints"); +} + +TEST(OperatorTest, TestDefaultValue) { + OperatorDef op_def; + Workspace ws; + OperatorBase op(op_def, &ws); + EXPECT_FLOAT_EQ( + op.GetSingleArgument("arg-nonexisting", 0.5), 0.5); +} + +TEST(OperatorTest, TestSetUp) { + Workspace ws; + OperatorDef op_def; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("output"); + EXPECT_NE(nullptr, ws.CreateBlob("input")); + unique_ptr op(CreateOperator(op_def, &ws)); + EXPECT_NE(nullptr, op.get()); + EXPECT_TRUE(op->Verify()); + EXPECT_TRUE(ws.HasBlob("output")); +} + +TEST(OperatorTest, TestSetUpInputOutputCount) { + Workspace ws; + OperatorDef op_def; + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_inputs("input2"); + op_def.add_outputs("output"); + EXPECT_NE(nullptr, ws.CreateBlob("input")); + EXPECT_NE(nullptr, ws.CreateBlob("input2")); + unique_ptr op(CreateOperator(op_def, &ws)); + EXPECT_NE(nullptr, op.get()); + EXPECT_TRUE(ws.HasBlob("output")); + // Because JustTest will only accept one single input, this will return false. + EXPECT_FALSE(op->Verify()); + + op_def.clear_inputs(); + op_def.add_inputs("input"); + op_def.add_outputs("output2"); + op.reset(CreateOperator(op_def, &ws)); + EXPECT_NE(nullptr, op.get()); + // Because JustTest will only produce one single output, this will return + // false. + EXPECT_FALSE(op->Verify()); +} + +NetDef GetNetDefForTest() { + NetDef net_def; + OperatorDef op_def; + net_def.set_name("NetForTest"); + op_def.set_name("JustTest0"); + op_def.set_type("JustTest"); + op_def.add_inputs("input"); + op_def.add_outputs("hidden"); + net_def.add_operators()->CopyFrom(op_def); + op_def.set_name("JustTest1"); + op_def.set_inputs(0, "hidden"); + op_def.set_outputs(0, "output"); + net_def.add_operators()->CopyFrom(op_def); + return net_def; +} + +TEST(NetTest, TestScaffoldingSimpleNet) { + NetDef net_def = GetNetDefForTest(); + net_def.set_net_type("simple"); + Workspace ws; + EXPECT_NE(nullptr, ws.CreateBlob("input")); + unique_ptr net(CreateNet(net_def, &ws)); + EXPECT_NE(nullptr, net.get()); + EXPECT_TRUE(net->Verify()); + EXPECT_TRUE(ws.HasBlob("input")); + EXPECT_TRUE(ws.HasBlob("hidden")); + EXPECT_TRUE(ws.HasBlob("output")); + EXPECT_TRUE(net->Run()); +} + +TEST(NetTest, TestScaffoldingParallelNet) { + NetDef net_def = GetNetDefForTest(); + net_def.set_net_type("parallel"); + net_def.set_num_workers(1); + Workspace ws; + EXPECT_NE(nullptr, ws.CreateBlob("input")); + unique_ptr net(CreateNet(net_def, &ws)); + EXPECT_NE(nullptr, net.get()); + EXPECT_TRUE(net->Verify()); + EXPECT_TRUE(ws.HasBlob("input")); + EXPECT_TRUE(ws.HasBlob("hidden")); + EXPECT_TRUE(ws.HasBlob("output")); + EXPECT_TRUE(net->Run()); +} + +} // namespace caffe2 + + diff --git a/caffe2/core/parallel_net_test.cc b/caffe2/core/parallel_net_test.cc new file mode 100644 index 00000000000..4f311d46fcd --- /dev/null +++ b/caffe2/core/parallel_net_test.cc @@ -0,0 +1,134 @@ +#include // NOLINT +#include +#include // NOLINT + +#include "caffe2/core/net.h" +#include "caffe2/core/operator.h" +#include "google/protobuf/text_format.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +using std::clock_t; +using std::clock; + +// SleepOp basically sleeps for a given number of seconds. 
+class SleepOp final : public OperatorBase { + public: + SleepOp(const OperatorDef& operator_def, Workspace* ws) + : OperatorBase(operator_def, ws), + ms_(OperatorBase::GetSingleArgument("ms", 1000)) { + DCHECK_GT(ms_, 0); + DCHECK_LT(ms_, 3600 * 1000) << "Really? This long?"; + } + + bool Run() final { + clock_t start = clock(); + std::this_thread::sleep_for(std::chrono::milliseconds(ms_)); + clock_t end = clock(); + if (OperatorBase::OutputSize()) { + vector* output = OperatorBase::Output >(0); + output->resize(2); + (*output)[0] = start; + (*output)[1] = end; + } + return true; + } + + private: + int ms_; + // We allow arbitrary inputs and at most one output so that we can + // test scaffolding of networks. If the output is 1, it will be filled with + // vector with two elements: start time and end time. + INPUT_OUTPUT_STATS(0, INT_MAX, 0, 1); + DISABLE_COPY_AND_ASSIGN(SleepOp); +}; + +namespace { +REGISTER_CPU_OPERATOR(Sleep, SleepOp) +REGISTER_CUDA_OPERATOR(Sleep, SleepOp) +} // namespace + +const char kSleepNetDefString[] = +" name: \"sleepnet\"" +" net_type: \"parallel\"" +" num_workers: 2" +" operators {" +" outputs: \"sleep1\"" +" name: \"sleep1\"" +" type: \"Sleep\"" +" args {" +" name: \"ms\"" +" i: 100" +" }" +" }" +" operators {" +" inputs: \"sleep1\"" +" outputs: \"sleep2\"" +" name: \"sleep2\"" +" type: \"Sleep\"" +" args {" +" name: \"ms\"" +" i: 100" +" }" +" }" +" operators {" +" outputs: \"sleep3\"" +" name: \"sleep3\"" +" type: \"Sleep\"" +" args {" +" name: \"ms\"" +" i: 150" +" }" +" }"; + + +TEST(ParallelNetTest, TestParallelNetTiming) { + NetDef net_def; + CHECK(google::protobuf::TextFormat::ParseFromString( + string(kSleepNetDefString), &net_def)); + // Below is the parallel version + Workspace ws; + unique_ptr net(CreateNet(net_def, &ws)); + EXPECT_NE(nullptr, net.get()); + EXPECT_TRUE(net->Verify()); + auto start_time = std::chrono::system_clock::now(); + EXPECT_TRUE(net->Run()); + // Inspect the time - it should be around 2000 milliseconds, since sleep3 can + // run in parallel with sleep1 and sleep2. + auto duration = std::chrono::duration_cast( + std::chrono::system_clock::now() - start_time); + int milliseconds = duration.count(); + // We should be seeing 200 ms. This adds a little slack time. + EXPECT_GT(milliseconds, 180); + EXPECT_LT(milliseconds, 220); +} + +// For sanity check, we also test the sequential time - it should take 0.35 +// seconds instead since everything has to be sequential. +TEST(SimpleNetTest, TestSimpleNetTiming) { + NetDef net_def; + CHECK(google::protobuf::TextFormat::ParseFromString( + string(kSleepNetDefString), &net_def)); + net_def.set_net_type("simple"); + Workspace ws; + unique_ptr net(CreateNet(net_def, &ws)); + EXPECT_NE(nullptr, net.get()); + EXPECT_TRUE(net->Verify()); + auto start_time = std::chrono::system_clock::now(); + EXPECT_TRUE(net->Run()); + // Inspect the time - it should be around 2000 milliseconds, since sleep3 can + // run in parallel with sleep1 and sleep2. + auto duration = std::chrono::duration_cast( + std::chrono::system_clock::now() - start_time); + int milliseconds = duration.count(); + // We should be seeing 350 ms. This adds a little slack time. 
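+  // Spelling out the arithmetic behind both timing tests: sleep2 consumes the
+  // "sleep1" blob, so sleep1 and sleep2 always run back to back (100 + 100 =
+  // 200 ms), while sleep3 (150 ms) has no data dependency on either.
+  //   parallel net: max(100 + 100, 150) = 200 ms  (checked in the test above)
+  //   simple net:   100 + 100 + 150     = 350 ms  (checked below)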
+ EXPECT_GT(milliseconds, 330); + EXPECT_LT(milliseconds, 370); +} + + +} // namespace caffe2 + + + diff --git a/caffe2/core/registry.h b/caffe2/core/registry.h new file mode 100644 index 00000000000..8d199d8b7df --- /dev/null +++ b/caffe2/core/registry.h @@ -0,0 +1,112 @@ +#ifndef CAFFE2_CORE_REGISTRY_H_ +#define CAFFE2_CORE_REGISTRY_H_ + +#include +#include +#include + +#include "caffe2/core/common.h" + +namespace caffe2 { + +// Registry is a class that allows one to register classes by a specific +// key, usually a string specifying the name. For each key type and object type, +// there should be only one single registry responsible for it. + +template +class Registry { + public: + typedef ObjectType* (*Creator)(Args ...); + typedef CaffeMap CreatorRegistry; + + Registry() : registry_() {} + + void Register(const string& key, Creator creator) { + // The if statement below is essentially the same as the following line: + // CHECK_EQ(registry_.count(key), 0) << "Key " << key + // << " registered twice."; + // However, CHECK_EQ depends on google logging, and since registration is + // carried out at static initialization time, we do not want to have an + // explicit dependency on glog's initialization function. + if (registry_.count(key) != 0) { + std::cerr << "Key " << key << " already registered." << std::endl; + std::exit(1); + } + registry_[key] = creator; + } + + inline bool Has(const string& key) { return (registry_.count(key) != 0); } + + ObjectType* Create(const string& key, Args ... args) { + if (registry_.count(key) == 0) { + std::cerr << "Key " << key << " not found." << std::endl; + std::cerr << "Available keys:" << std::endl; + TEST_PrintRegisteredNames(); + std::cerr << "Returning null pointer."; + return nullptr; + } + return registry_[key](args...); + } + + // This function should only used in test code to inspect registered names. + // You should only call this function after google glog is initialized - + // do NOT call it in static initializations. + void TEST_PrintRegisteredNames() { + std::vector keys; + for (const auto& it : registry_) { + keys.push_back(it.first); + } + std::sort(keys.begin(), keys.end()); + for (const string& key : keys) { + std::cout << "Registry key: " << key << std::endl; + } + std::cout << "A total of " << keys.size() << " registered keys." + << std::endl; + } + + private: + CreatorRegistry registry_; + + DISABLE_COPY_AND_ASSIGN(Registry); +}; + +template +class Registerer { + public: + Registerer(const string& key, Registry* registry, + typename Registry::Creator creator) { + registry->Register(key, creator); + } + + template + static ObjectType* DefaultCreator(Args ... args) { + return new DerivedType(args...); + } +}; + + +#define DECLARE_REGISTRY(RegistryName, ObjectType, ...) \ + Registry* RegistryName(); \ + typedef Registerer Registerer##RegistryName; + +#define DEFINE_REGISTRY(RegistryName, ObjectType, ...) \ + Registry* RegistryName() { \ + static Registry* registry = \ + new Registry(); \ + return registry; \ + } +// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated +// creator with comma in its templated arguments. +#define REGISTER_CREATOR(RegistryName, key, ...) \ + Registerer##RegistryName g_##RegistryName##_##key( \ + #key, RegistryName(), __VA_ARGS__); + +// Note(Yangqing): The __VA_ARGS__ below allows one to specify a templated class +// with comma in its templated arguments. +#define REGISTER_CLASS(RegistryName, key, ...) 
\ + Registerer##RegistryName g_##RegistryName##_##key( \ + #key, RegistryName(), \ + Registerer##RegistryName::DefaultCreator<__VA_ARGS__>); + +} // namespace caffe2 +#endif // CAFFE2_CORE_REGISTRY_H_ diff --git a/caffe2/core/registry_test.cc b/caffe2/core/registry_test.cc new file mode 100644 index 00000000000..4ca2d0faf6c --- /dev/null +++ b/caffe2/core/registry_test.cc @@ -0,0 +1,48 @@ +#include +#include + +#include "caffe2/core/registry.h" +#include "gtest/gtest.h" +#include "glog/logging.h" + +namespace caffe2 { + +class Foo { + public: + explicit Foo(int x) { LOG(INFO) << "Foo " << x; } +}; + +DECLARE_REGISTRY(FooRegistry, Foo, int); +DEFINE_REGISTRY(FooRegistry, Foo, int); +#define REGISTER_FOO(clsname) \ + REGISTER_CLASS(FooRegistry, clsname, clsname) + +class Bar : public Foo { + public: + explicit Bar(int x) : Foo(x) { LOG(INFO) << "Bar " << x; } +}; +REGISTER_FOO(Bar); + +class AnotherBar : public Foo { + public: + explicit AnotherBar(int x) : Foo(x) { + LOG(INFO) << "AnotherBar " << x; + } +}; +REGISTER_FOO(AnotherBar); + +TEST(RegistryTest, CanRunCreator) { + unique_ptr bar(FooRegistry()->Create("Bar", 1)); + EXPECT_TRUE(bar != nullptr) << "Cannot create bar."; + unique_ptr another_bar(FooRegistry()->Create("AnotherBar", 1)); + EXPECT_TRUE(another_bar != nullptr); +} + +TEST(RegistryTest, ReturnNullOnNonExistingCreator) { + EXPECT_EQ( + FooRegistry()->Create("Non-existing bar", 1), nullptr); +} + +} // namespace caffe2 + + diff --git a/caffe2/core/typeid.cc b/caffe2/core/typeid.cc new file mode 100644 index 00000000000..90192fb2909 --- /dev/null +++ b/caffe2/core/typeid.cc @@ -0,0 +1,11 @@ +#include "caffe2/core/typeid.h" + +#include + +namespace caffe2 { +namespace internal { + +std::map g_caffe2_type_name_map; + +} // namespace internal +} // namespace caffe2 diff --git a/caffe2/core/typeid.h b/caffe2/core/typeid.h new file mode 100644 index 00000000000..897e4587e84 --- /dev/null +++ b/caffe2/core/typeid.h @@ -0,0 +1,63 @@ +#ifndef CAFFE2_CORE_TYPEID_H_ +#define CAFFE2_CORE_TYPEID_H_ + +#include +#include + +#include "caffe2/core/common.h" +#include "glog/logging.h" + +namespace caffe2 { +namespace internal { + +static_assert(sizeof(void*) <= sizeof(int64_t), + "This does not happen often, but int64_t is not enough for " + "pointers on this platform."); +typedef int64_t TypeId; +extern std::map g_caffe2_type_name_map; +const TypeId gUnknownType = 0; + +template +class TypeIdRegisterer { + public: + TypeIdRegisterer() { + CHECK_EQ(g_caffe2_type_name_map.count(id()), 0) + << "Registerer instantiated twice."; + g_caffe2_type_name_map[id()] = typeid(T).name(); + } + inline TypeId id() { + return reinterpret_cast(type_id_bit); + } + + private: + bool type_id_bit[1]; +}; + +// id = TypeId() gives a unique type id for the given class, which can be +// verified by IsType(id). This allows us to check the type of object +// pointers during run-time. 
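+// A small usage sketch (the reported names come from typeid(T).name(), so the
+// exact strings are compiler-dependent):
+//
+//   TypeId fid = GetTypeId<float>();
+//   IsTypeId<float>(fid);   // true
+//   IsTypeId<int>(fid);     // false
+//   TypeName<float>();      // e.g. "f" under GCC's name mangling
+//
+// The id is simply the address of a per-type function-local static, so it is
+// unique within one binary but not stable across builds or processes.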
+template +TypeId GetTypeId() { + static TypeIdRegisterer reg; + return reg.id(); +} + +template +inline bool IsTypeId(TypeId id) { + return (id == GetTypeId()); +} + +inline string TypeName(TypeId id) { + if (id == gUnknownType) return "UNKNOWN"; + return g_caffe2_type_name_map[id]; +} + +template +inline string TypeName() { + return TypeName(GetTypeId()); +} + +} // namespace internal +} // namespace caffe2 + +#endif // CAFFE2_CORE_TYPEID_H_ diff --git a/caffe2/core/types.h b/caffe2/core/types.h new file mode 100644 index 00000000000..8a14c279778 --- /dev/null +++ b/caffe2/core/types.h @@ -0,0 +1,27 @@ +#ifndef CAFFE2_CORE_TYPES_H_ +#define CAFFE2_CORE_TYPES_H_ + +#include + +namespace caffe2 { + +// Storage orders that are often used in the image applications. +enum StorageOrder { + UNKNOWN = 0, + NHWC = 1, + NCHW = 2, +}; + +inline StorageOrder StringToStorageOrder(const string& str) { + if (str == "NHWC") { + return StorageOrder::NHWC; + } else if (str == "NCHW") { + return StorageOrder::NCHW; + } else { + return StorageOrder::UNKNOWN; + } +} + +} // namespace caffe2 + +#endif // CAFFE2_CORE_TYPES_H_ diff --git a/caffe2/core/workspace.cc b/caffe2/core/workspace.cc new file mode 100644 index 00000000000..36d6bf3648a --- /dev/null +++ b/caffe2/core/workspace.cc @@ -0,0 +1,177 @@ +#include +#include + +#include "caffe2/core/operator.h" +#include "caffe2/core/net.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +Blob* Workspace::CreateBlob(const string& name) { + if (HasBlob(name)) { + VLOG(1) << "Blob " << name << " already exists. Skipping."; + } else { + VLOG(1) << "Creating blob " << name; + (*blob_map_)[name] = unique_ptr(new Blob()); + } + return (*blob_map_)[name].get(); +} + +const Blob* Workspace::GetBlob(const string& name) const { + if (!HasBlob(name)) { + LOG(WARNING) << "Blob " << name << " not in the workspace."; + // TODO(Yangqing): do we want to always print out the list of blobs here? + LOG(WARNING) << "Current blobs:"; + for (const auto& entry : *blob_map_) { + LOG(WARNING) << entry.first; + } + return nullptr; + } else { + return (*blob_map_)[name].get(); + } +} + +bool Workspace::CreateNet(const NetDef& net_def) { + CHECK(net_def.has_name()) << "Net definition should have a name."; + if (net_map_.count(net_def.name()) > 0) { + LOG(WARNING) << "Overwriting existing network of the same name."; + // Note(Yangqing): Why do we explicitly erase it here? Some components of + // the old network, such as a opened LevelDB, may prevent us from creating a + // new network before the old one is deleted. Thus we will need to first + // erase the old one before the new one can be constructed. + net_map_.erase(net_def.name()); + } + // Create a new net with its name. 
+ LOG(INFO) << "Initializing network " << net_def.name(); + net_map_[net_def.name()] = + unique_ptr(caffe2::CreateNet(net_def, this)); + if (net_map_[net_def.name()].get() == nullptr) { + LOG(ERROR) << "Error when creating the network."; + net_map_.erase(net_def.name()); + return false; + } + if (!net_map_[net_def.name()]->Verify()) { + LOG(ERROR) << "Error when setting up network " << net_def.name(); + return false; + } + return true; +} + +void Workspace::DeleteNet(const string& name) { + if (net_map_.count(name)) { + net_map_.erase(name); + } +} + +bool Workspace::RunNet(const string& name) { + if (!net_map_.count(name)) { + LOG(ERROR) << "Network " << name << " does not exist yet."; + return false; + } + return net_map_[name]->Run(); +} + +bool Workspace::RunOperatorOnce(const OperatorDef& op_def) { + std::unique_ptr op(CreateOperator(op_def, this)); + if (!op->Verify()) { + LOG(ERROR) << "Error when setting up operator " << op_def.name(); + return false; + } + if (!op->Run()) { + LOG(ERROR) << "Error when running operator " << op_def.name(); + return false; + } + return true; +} +bool Workspace::RunNetOnce(const NetDef& net_def) { + std::unique_ptr net(caffe2::CreateNet(net_def, this)); + if (!net->Verify()) { + LOG(ERROR) << "Error when setting up network " << net_def.name(); + return false; + } + if (!net->Run()) { + LOG(ERROR) << "Error when running network " << net_def.name(); + return false; + } + return true; +} + +bool Workspace::RunPlan(const PlanDef& plan) { + LOG(INFO) << "Started executing plan."; + if (plan.networks_size() == 0 || plan.execution_steps_size() == 0) { + LOG(WARNING) << "Nothing to run - did you define a correct plan?"; + // We will do nothing, but the plan is still legal so we will return true. + return true; + } + LOG(INFO) << "Initializing networks."; + + for (const NetDef& net_def : plan.networks()) { + if (!CreateNet(net_def)) { + LOG(ERROR) << "Failed initializing the networks."; + return false; + } + } + clock_t start_time = clock(); + for (const ExecutionStep& step : plan.execution_steps()) { + clock_t step_start_time = clock(); + if (!ExecuteStepRecursive(step)) { + LOG(ERROR) << "Failed initializing step " << step.name(); + return false; + } + LOG(INFO) << "Step " << step.name() << " took " + << static_cast(clock() - step_start_time) / CLOCKS_PER_SEC + << " seconds."; + } + LOG(INFO) << "Total plan took " + << static_cast(clock() - start_time) / CLOCKS_PER_SEC + << " seconds."; + LOG(INFO) << "Plan executed successfully."; + return true; +} + +bool Workspace::ExecuteStepRecursive(const ExecutionStep& step) { + LOG(INFO) << "Running execution step " << step.name(); + if (!(step.substeps_size() == 0 || step.networks_size() == 0)) { + LOG(ERROR) << "An ExecutionStep should either have substeps or networks " + << "but not both."; + return false; + } + + if (step.substeps_size()) { + int iterations = step.has_iterations() ? step.iterations() : 1; + for (int i = 0; i < iterations; ++i) { + for (const ExecutionStep& substep : step.substeps()) { + if (!ExecuteStepRecursive(substep)) { + return false; + } + } + } + return true; + } else { + // If this ExecutionStep just contains nets, we can directly run it. + vector networks; + // Collect the networks to run. 
+ for (const string& network_name : step.networks()) { + if (!net_map_.count(network_name)) { + LOG(ERROR) << "Network " << network_name << " not found."; + return false; + } + VLOG(1) << "Going to execute network " << network_name; + networks.push_back(net_map_[network_name].get()); + } + int iterations = step.has_iterations() ? step.iterations() : 1; + VLOG(1) << "Executing networks for " << iterations << " iterations."; + for (int iter = 0; iter < iterations; ++iter) { + VLOG(1) << "Executing network iteration " << iter; + for (NetBase* network : networks) { + if (!network->Run()) { + return false; + } + } + } + } + return true; +} + +} // namespace caffe2 diff --git a/caffe2/core/workspace.h b/caffe2/core/workspace.h new file mode 100644 index 00000000000..c8ea754fe7c --- /dev/null +++ b/caffe2/core/workspace.h @@ -0,0 +1,93 @@ +#ifndef CAFFE2_CORE_WORKSPACE_H_ +#define CAFFE2_CORE_WORKSPACE_H_ + +#include +#include +#include +#include + +#include "caffe2/core/blob.h" +#include "caffe2/core/common.h" +#include "caffe2/core/registry.h" +#include "caffe2/proto/caffe2.pb.h" + +namespace caffe2 { + +class NetBase; + +// Workspace is a class that holds all the blobs in this run and also runs +// the operators. +class Workspace { + public: + typedef CaffeMap > BlobMap; + typedef CaffeMap > NetMap; + // Initializes an empty workspace. + Workspace() : blob_map_(new BlobMap()), root_folder_(".") {} + explicit Workspace(const string& root_folder) + : blob_map_(new BlobMap()), net_map_(), root_folder_(root_folder) {} + ~Workspace() {} + + // Return a list of blob names. This may be a bit slow since it will involve + // creation of multiple temp variables - if possible, use HasBlob() or + // GetBlob() below with given names. + vector Blobs() { + vector names; + for (auto& entry : *blob_map_) { + names.push_back(entry.first); + } + return names; + } + // Return the root folder of the workspace. + const string& RootFolder() { return root_folder_; } + inline bool HasBlob(const string& name) const { + return blob_map_->count(name); + } + Blob* CreateBlob(const string& name); + const Blob* GetBlob(const string& name) const; + inline Blob* GetBlob(const string& name) { + return const_cast( + static_cast(this)->GetBlob(name)); + } + + // CreateNet creates a network in the current workspace. It can then + // be referred to by RunNet(). + bool CreateNet(const NetDef& net_def); + void DeleteNet(const string& net_name); + bool RunNet(const string& net_name); + vector Nets() { + vector names; + for (auto& entry : net_map_) { + names.push_back(entry.first); + } + return names; + } + + // RunPlan runs a plan that has multiple nets and execution steps. + bool RunPlan(const PlanDef& plan_def); + + // RunOperatorOnce and RunNetOnce runs an operator or net once. The difference + // between RunNet and RunNetOnce lies in the fact that RunNet allows you to + // have a persistent net object, while RunNetOnce creates a net and discards + // it on the fly - this may make things like database read and random number + // generators repeat the same thing over multiple calls. + bool RunOperatorOnce(const OperatorDef& op_def); + bool RunNetOnce(const NetDef& net_def); + + + protected: + bool ExecuteStepRecursive(const ExecutionStep& execution); + + private: + // If a workspace is shared with another one, the blob_map_ is going to be + // shared, but net_map_ will not be. + // TODO(Yangqing): Are we really going to share workspaces? If not, let's + // remove this unnecessity. 
+ unique_ptr blob_map_; + NetMap net_map_; + string root_folder_; + DISABLE_COPY_AND_ASSIGN(Workspace); +}; + +} // namespace caffe2 + +#endif // CAFFE2_CORE_WORKSPACE_H_ diff --git a/caffe2/core/workspace_test.cc b/caffe2/core/workspace_test.cc new file mode 100644 index 00000000000..3697d258245 --- /dev/null +++ b/caffe2/core/workspace_test.cc @@ -0,0 +1,50 @@ +#include + +#include "caffe2/core/operator.h" +#include "gtest/gtest.h" + + +namespace caffe2 { + +class Foo {}; + +TEST(WorkspaceTest, BlobAccess) { + Workspace ws; + + EXPECT_FALSE(ws.HasBlob("nonexisting")); + EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); + + EXPECT_EQ(ws.GetBlob("newblob"), nullptr); + EXPECT_NE(nullptr, ws.CreateBlob("newblob")); + EXPECT_NE(nullptr, ws.GetBlob("newblob")); + EXPECT_TRUE(ws.HasBlob("newblob")); + + // Different names should still be not created. + EXPECT_FALSE(ws.HasBlob("nonexisting")); + EXPECT_EQ(ws.GetBlob("nonexisting"), nullptr); + + // Check if the returned Blob is OK for all operations + Blob* blob = ws.GetBlob("newblob"); + int* int_unused UNUSED_VARIABLE = blob->GetMutable(); + EXPECT_TRUE(blob->IsType()); + EXPECT_FALSE(blob->IsType()); + EXPECT_NE(&blob->Get(), nullptr); + + // Re-creating the blob does not change the content as long as it already + // exists. + EXPECT_NE(nullptr, ws.CreateBlob("newblob")); + EXPECT_TRUE(blob->IsType()); + EXPECT_FALSE(blob->IsType()); + // When not null, we should only call with the right type. + EXPECT_NE(&blob->Get(), nullptr); +} + +TEST(WorkspaceTest, RunEmptyPlan) { + PlanDef plan_def; + Workspace ws; + EXPECT_TRUE(ws.RunPlan(plan_def)); +} + +} // namespace caffe2 + + diff --git a/caffe2/db/BREW b/caffe2/db/BREW new file mode 100644 index 00000000000..9e8c4b3dc07 --- /dev/null +++ b/caffe2/db/BREW @@ -0,0 +1,33 @@ +# This folder contains database implementations that has third third_party +# dependencies. 
+ +cc_library( + name = "db", + srcs = [ + "leveldb.cc", + "lmdb.cc", + ], + deps = [ + ":zmqdb", + "//caffe2/core:core", + "//third_party/glog:glog", + "//third_party/leveldb:leveldb", + "//third_party/liblmdb:lmdb", + ], + whole_archive = True, +) + +cc_library( + name = "zmqdb", + srcs = [ + "zmqdb.cc", + ], + deps = [ + "//caffe2/core:core", + "//third_party/glog:glog", + "//third_party/leveldb:leveldb", + "//third_party/liblmdb:lmdb", + "//third_party/libzmq:libzmq", + ], + whole_archive = True, +) diff --git a/caffe2/db/leveldb.cc b/caffe2/db/leveldb.cc new file mode 100644 index 00000000000..6e2b0f3a1d6 --- /dev/null +++ b/caffe2/db/leveldb.cc @@ -0,0 +1,82 @@ +#include "caffe2/core/db.h" +#include "glog/logging.h" +#include "leveldb/db.h" +#include "leveldb/write_batch.h" + +namespace caffe2 { +namespace db { + +class LevelDBCursor : public Cursor { + public: + explicit LevelDBCursor(leveldb::Iterator* iter) + : iter_(iter) { SeekToFirst(); } + ~LevelDBCursor() { delete iter_; } + void SeekToFirst() override { iter_->SeekToFirst(); } + void Next() override { iter_->Next(); } + string key() override { return iter_->key().ToString(); } + string value() override { return iter_->value().ToString(); } + bool Valid() override { return iter_->Valid(); } + + private: + leveldb::Iterator* iter_; +}; + +class LevelDBTransaction : public Transaction { + public: + explicit LevelDBTransaction(leveldb::DB* db) : db_(db) { + CHECK_NOTNULL(db_); + batch_.reset(new leveldb::WriteBatch()); + } + ~LevelDBTransaction() { Commit(); } + void Put(const string& key, const string& value) override { + batch_->Put(key, value); + } + void Commit() override { + leveldb::Status status = db_->Write(leveldb::WriteOptions(), batch_.get()); + batch_.reset(new leveldb::WriteBatch()); + CHECK(status.ok()) << "Failed to write batch to leveldb " + << std::endl << status.ToString(); + } + + private: + leveldb::DB* db_; + std::unique_ptr batch_; + + DISABLE_COPY_AND_ASSIGN(LevelDBTransaction); +}; + +class LevelDB : public DB { + public: + LevelDB(const string& source, Mode mode) : DB(source, mode) { + leveldb::Options options; + options.block_size = 65536; + options.write_buffer_size = 268435456; + options.max_open_files = 100; + options.error_if_exists = mode == NEW; + options.create_if_missing = mode != READ; + leveldb::DB* db_temp; + leveldb::Status status = leveldb::DB::Open(options, source, &db_temp); + CHECK(status.ok()) << "Failed to open leveldb " << source + << std::endl << status.ToString(); + db_.reset(db_temp); + LOG(INFO) << "Opened leveldb " << source; + } + + void Close() override { db_.reset(); } + Cursor* NewCursor() override { + return new LevelDBCursor(db_->NewIterator(leveldb::ReadOptions())); + } + Transaction* NewTransaction() override { + return new LevelDBTransaction(db_.get()); + } + + private: + std::unique_ptr db_; +}; + +REGISTER_CAFFE2_DB(LevelDB, LevelDB); +// For lazy-minded, one can also call with lower-case name. 
+REGISTER_CAFFE2_DB(leveldb, LevelDB); + +} // namespace db +} // namespace caffe2 diff --git a/caffe2/db/lmdb.cc b/caffe2/db/lmdb.cc new file mode 100644 index 00000000000..e17737b9e55 --- /dev/null +++ b/caffe2/db/lmdb.cc @@ -0,0 +1,136 @@ +#include + +#include "caffe2/core/db.h" +#include "glog/logging.h" +#include "lmdb.h" + +namespace caffe2 { +namespace db { + +constexpr size_t LMDB_MAP_SIZE = 1099511627776; // 1 TB + +inline void MDB_CHECK(int mdb_status) { + CHECK_EQ(mdb_status, MDB_SUCCESS) << mdb_strerror(mdb_status); +} + +class LMDBCursor : public Cursor { + public: + explicit LMDBCursor(MDB_env* mdb_env) + : mdb_env_(mdb_env), valid_(false) { + MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, MDB_RDONLY, &mdb_txn_)); + MDB_CHECK(mdb_dbi_open(mdb_txn_, NULL, 0, &mdb_dbi_)); + MDB_CHECK(mdb_cursor_open(mdb_txn_, mdb_dbi_, &mdb_cursor_)); + SeekToFirst(); + } + virtual ~LMDBCursor() { + mdb_cursor_close(mdb_cursor_); + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + } + void SeekToFirst() override { Seek(MDB_FIRST); } + void Next() override { Seek(MDB_NEXT); } + string key() override { + return string(static_cast(mdb_key_.mv_data), mdb_key_.mv_size); + } + string value() override { + return string(static_cast(mdb_value_.mv_data), + mdb_value_.mv_size); + } + bool Valid() override { return valid_; } + + private: + void Seek(MDB_cursor_op op) { + int mdb_status = mdb_cursor_get(mdb_cursor_, &mdb_key_, &mdb_value_, op); + if (mdb_status == MDB_NOTFOUND) { + valid_ = false; + } else { + MDB_CHECK(mdb_status); + valid_ = true; + } + } + + MDB_env* mdb_env_; + MDB_txn* mdb_txn_; + MDB_dbi mdb_dbi_; + MDB_cursor* mdb_cursor_; + MDB_val mdb_key_, mdb_value_; + bool valid_; +}; + +class LMDBTransaction final : public Transaction { + public: + explicit LMDBTransaction(MDB_env* mdb_env) + : mdb_env_(mdb_env) { + MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn_)); + MDB_CHECK(mdb_dbi_open(mdb_txn_, NULL, 0, &mdb_dbi_)); + } + ~LMDBTransaction() { + MDB_CHECK(mdb_txn_commit(mdb_txn_)); + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + } + void Put(const string& key, const string& value) override; + void Commit() override { + MDB_CHECK(mdb_txn_commit(mdb_txn_)); + mdb_dbi_close(mdb_env_, mdb_dbi_); + mdb_txn_abort(mdb_txn_); + // Begin a new transaction. 
+ MDB_CHECK(mdb_txn_begin(mdb_env_, NULL, 0, &mdb_txn_)); + MDB_CHECK(mdb_dbi_open(mdb_txn_, NULL, 0, &mdb_dbi_)); + } + + private: + MDB_env* mdb_env_; + MDB_dbi mdb_dbi_; + MDB_txn* mdb_txn_; + + DISABLE_COPY_AND_ASSIGN(LMDBTransaction); +}; + +class LMDB : public DB { + public: + LMDB(const string& source, Mode mode); + virtual ~LMDB() { Close(); } + void Close() override { + if (mdb_env_ != NULL) { + mdb_env_close(mdb_env_); + mdb_env_ = NULL; + } + } + Cursor* NewCursor() override { return new LMDBCursor(mdb_env_); } + Transaction* NewTransaction() override { + return new LMDBTransaction(mdb_env_); + } + + private: + MDB_env* mdb_env_; +}; + +LMDB::LMDB(const string& source, Mode mode) : DB(source, mode) { + MDB_CHECK(mdb_env_create(&mdb_env_)); + MDB_CHECK(mdb_env_set_mapsize(mdb_env_, LMDB_MAP_SIZE)); + if (mode == NEW) { + CHECK_EQ(mkdir(source.c_str(), 0744), 0) << "mkdir " << source << "failed"; + } + int flags = 0; + if (mode == READ) { + flags = MDB_RDONLY | MDB_NOTLS; + } + MDB_CHECK(mdb_env_open(mdb_env_, source.c_str(), flags, 0664)); + LOG(INFO) << "Opened lmdb " << source; +} + +void LMDBTransaction::Put(const string& key, const string& value) { + MDB_val mdb_key, mdb_value; + mdb_key.mv_data = const_cast(key.data()); + mdb_key.mv_size = key.size(); + mdb_value.mv_data = const_cast(value.data()); + mdb_value.mv_size = value.size(); + MDB_CHECK(mdb_put(mdb_txn_, mdb_dbi_, &mdb_key, &mdb_value, 0)); +} + +REGISTER_CAFFE2_DB(LMDB, LMDB); +REGISTER_CAFFE2_DB(lmdb, LMDB); + +} // namespace db +} // namespace caffe2 diff --git a/caffe2/db/zmqdb.cc b/caffe2/db/zmqdb.cc new file mode 100644 index 00000000000..a8fd3ff11ad --- /dev/null +++ b/caffe2/db/zmqdb.cc @@ -0,0 +1,103 @@ +#include + +#include + +#include "caffe2/core/db.h" +#include "glog/logging.h" +#include "zmq.h" + +namespace caffe2 { +namespace db { + +typedef char ZmqCommand; +typedef int ZmqMessageSize; +const ZmqCommand kQueryMessageSize = 's'; +const ZmqCommand kGet = 'g'; + +class ZmqDBCursor : public Cursor { + public: + explicit ZmqDBCursor(void* requester) + : requester_(requester), buffer_(nullptr), received_size_(0), + buffer_size_(0) { + // Figure out the buffer size. + CHECK_EQ( + zmq_send(requester_, &kQueryMessageSize, sizeof(ZmqCommand), 0), + sizeof(ZmqCommand)) + << "Incorrect zmq communication when querying message size."; + CHECK_EQ( + zmq_recv(requester_, &buffer_size_, sizeof(ZmqMessageSize), 0), + sizeof(ZmqMessageSize)) + << "Incorrect zmq communication when fetching message size."; + CHECK_GT(buffer_size_, 0) << "Incorrect buffer size obtained."; + buffer_.reset(new char[buffer_size_]); + // obtain the first value. 
+ Next(); + } + + ~ZmqDBCursor() {} + void SeekToFirst() override { /* do nothing */ } + void Next() override { + CHECK_EQ( + zmq_send(requester_, &kGet, sizeof(ZmqCommand), 0), sizeof(ZmqCommand)) + << "Incorrect zmq communication when sending request."; + received_size_ = zmq_recv(requester_, buffer_.get(), buffer_size_, 0); + CHECK_GT(received_size_, 0) << "Received no message."; + } + string key() override { return ""; } + string value() override { + return string(buffer_.get(), received_size_); + } + virtual bool Valid() { return true; } + + private: + void* requester_; + unique_ptr buffer_; + int received_size_; + ZmqMessageSize buffer_size_; +}; + + +class ZmqDB : public DB { + public: + ZmqDB(const string& source, Mode mode) + : DB(source, mode), context_(zmq_ctx_new()), + requester_(zmq_socket(context_, ZMQ_REQ)) { + CHECK_EQ(mode, READ) << "ZeroMQ DB only supports read mode."; + VLOG(1) << "Connecting to ZeroMQ server: " << source; + int ret = zmq_connect(requester_, source.c_str()); + CHECK_EQ(ret, 0) << "Error in connecting to zmq server. " + << "Error is: " << errno; + VLOG(1) << "Opened ZeroMQ server: " << source; + } + + ~ZmqDB() { Close(); } + + void Close() override { + if (!requester_) { + zmq_close(requester_); + requester_ = nullptr; + zmq_ctx_destroy(context_); + context_ = nullptr; + } + } + + Cursor* NewCursor() override { + return new ZmqDBCursor(requester_); + } + Transaction* NewTransaction() override { + // TODO(Yangqing): Do I really need to just do log fatal? + LOG(FATAL) << "ZeroMQ DB does not support writing with a transaction."; + return nullptr; // dummy placeholder to suppress old compiler warnings. + } + + private: + void* context_; + void* requester_; +}; + +REGISTER_CAFFE2_DB(ZmqDB, ZmqDB); +// For lazy-minded, one can also call with lower-case name. 
+REGISTER_CAFFE2_DB(zmqdb, ZmqDB); + +} // namespace db +} // namespace caffe2 diff --git a/caffe2/end_to_end_test/BREW b/caffe2/end_to_end_test/BREW new file mode 100644 index 00000000000..5d985ac0694 --- /dev/null +++ b/caffe2/end_to_end_test/BREW @@ -0,0 +1,17 @@ +cc_test( + name = "end_to_end_tests", + srcs = [ + "end_to_end_tests.cc", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_gpu", + "//caffe2/operators:core_ops_cudnn", + "//caffe2/utils:proto_utils", + "//data/toy:toy_models", + "//data/mnist:mnist_models", + "//gtest:gtest_main", + ], +) diff --git a/caffe2/end_to_end_test/end_to_end_tests.cc b/caffe2/end_to_end_test/end_to_end_tests.cc new file mode 100644 index 00000000000..2501bffb593 --- /dev/null +++ b/caffe2/end_to_end_test/end_to_end_tests.cc @@ -0,0 +1,189 @@ +#include "caffe2/core/context.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/proto_utils.h" +#include "gflags/gflags.h" +#include "glog/logging.h" +#include "gtest/gtest.h" + +DECLARE_string(caffe_test_root); + +namespace caffe2 { + +const char kToyRegressionTestPlanPath[] = "/data/toy/toy_regression.pbtxt"; +const char kMNISTLinearClassificationPath[] = + "/data/mnist/linear_classifier_plan.pbtxt"; +const char kMNISTTwoLayerReluClassificationPath[] = + "/data/mnist/mnist_relu_network.pbtxt"; +const char kMNISTLeNetClassificationPath[] = + "/data/mnist/mnist_lenet.pbtxt"; +const char kMNISTLeNetClassificationGPUPath[] = + "/data/mnist/mnist_lenet_gpu.pbtxt"; +const char kMNISTLeNetNHWCClassificationPath[] = + "/data/mnist/mnist_lenet_nhwc.pbtxt"; +const char kMNISTLeNetNHWCClassificationGPUPath[] = + "/data/mnist/mnist_lenet_nhwc_gpu.pbtxt"; +const char kMNISTLeNetGroupConvClassificationPath[] = + "/data/mnist/mnist_lenet_group_convolution.pbtxt"; +const char kMNISTLeNetGroupConvNHWCClassificationPath[] = + "/data/mnist/mnist_lenet_group_convolution_nhwc.pbtxt"; + + +template +void ExpectTensorEquivalence(const Workspace& ws, const string& name_a, + const string& name_b, + const float relative_error) { + const Blob* a = ws.GetBlob(name_a); + EXPECT_TRUE(a != nullptr); + EXPECT_TRUE((a->IsType >())); + int size = a->Get >().size(); + const dtype* a_data = a->Get >().data(); + const Blob* b = ws.GetBlob(name_b); + EXPECT_TRUE(b != nullptr); + EXPECT_TRUE((b->IsType >())); + EXPECT_EQ(size, (b->Get >().size())); + const dtype* b_data = b->Get >().data(); + for (int i = 0; i < size; ++i) { + EXPECT_NEAR(a_data[i], b_data[i], relative_error); + } +} + +TEST(ToyRegressionTest, TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kToyRegressionTestPlanPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + ExpectTensorEquivalence(workspace, "W", "W_gt", 0.005); +} + +TEST(MNISTLinearClassificationTest, TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLinearClassificationPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 85%. 
+ EXPECT_GT(accuracy_tensor.data()[0], 0.85); +} + +TEST(MNISTTwoLayerReluClassificationTest, TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTTwoLayerReluClassificationPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + +TEST(MNISTLeNetClassificationTest, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetClassificationPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + +TEST(MNISTLeNetClassificationTestGPU, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetClassificationGPUPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + CPUContext context; + Tensor accuracy_tensor( + accuracy->Get >(), &context); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + + +TEST(MNISTLeNetNHWCClassificationTest, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetNHWCClassificationPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + +TEST(MNISTLeNetNHWCClassificationGPUTest, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetNHWCClassificationGPUPath, &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + CPUContext context; + Tensor accuracy_tensor( + accuracy->Get >(), &context); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + + + +TEST(MNISTLeNetGroupConvolutionClassificationTest, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetGroupConvClassificationPath, + &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. 
+ EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + +TEST(MNISTLeNetGroupConvolutionNHWCClassificationTest, LARGE_TestRunPlan) { + PlanDef plan_def; + CHECK(ReadProtoFromFile( + FLAGS_caffe_test_root + kMNISTLeNetGroupConvNHWCClassificationPath, + &plan_def)); + Workspace workspace; + workspace.RunPlan(plan_def); + const Blob* accuracy = workspace.GetBlob("accuracy"); + EXPECT_TRUE(accuracy != nullptr); + EXPECT_TRUE((accuracy->IsType >())); + auto& accuracy_tensor = accuracy->Get >(); + EXPECT_EQ(accuracy_tensor.size(), 1); + // Accuracy should be above 90%. + EXPECT_GT(accuracy_tensor.data()[0], 0.90); +} + +} // namespace caffe2 diff --git a/caffe2/image/BREW b/caffe2/image/BREW new file mode 100644 index 00000000000..f72cfa9aefa --- /dev/null +++ b/caffe2/image/BREW @@ -0,0 +1,32 @@ +cc_library( + name = "image_ops", + srcs = [ + "image_input_op.cc", + ], + hdrs = [ + "image_input_op.h", + ], + deps = [ + "//caffe2/core:core", + "//caffe2/operators:core_ops", + "//caffe2/utils:math", + "//caffe2/utils:proto_utils", + ], + external_libs = [ + "opencv_core", + "opencv_highgui", + "opencv_imgproc", + ], + whole_archive = True, +) + +cuda_library( + name = "image_ops_gpu", + srcs = Glob(["*_gpu.cc"]) + Glob(["*.cu"]), + deps = [ + ":image_ops", + "//caffe2/core:core_gpu", + "//caffe2/utils:math_gpu", + ], + whole_archive = True, +) diff --git a/caffe2/image/image_input_op.cc b/caffe2/image/image_input_op.cc new file mode 100644 index 00000000000..5e89627e377 --- /dev/null +++ b/caffe2/image/image_input_op.cc @@ -0,0 +1,7 @@ +#include "caffe2/image/image_input_op.h" + +namespace caffe2 { + +REGISTER_CPU_OPERATOR(ImageInput, ImageInputOp); + +} // namespace caffe2 diff --git a/caffe2/image/image_input_op.h b/caffe2/image/image_input_op.h new file mode 100644 index 00000000000..597036fbc42 --- /dev/null +++ b/caffe2/image/image_input_op.h @@ -0,0 +1,205 @@ +#ifndef CAFFE2_IMAGE_IMAGE_INPUT_OP_H_ +#define CAFFE2_IMAGE_IMAGE_INPUT_OP_H_ + +#include + +#include + +#include "caffe2/core/db.h" +#include "caffe2/operators/prefetch_op.h" + +namespace caffe2 { + +template +class ImageInputOp final + : public PrefetchOperator { + public: + using OperatorBase::OutputSize; + using PrefetchOperator::prefetch_thread_; + explicit ImageInputOp(const OperatorDef& operator_def, + Workspace* ws); + ~ImageInputOp() { + if (prefetch_thread_.get() != nullptr) { + prefetch_thread_->join(); + } + } + + bool Prefetch() override; + bool CopyPrefetched() override; + + private: + unique_ptr db_; + unique_ptr cursor_; + CPUContext cpu_context_; + Tensor prefetched_image_; + Tensor prefetched_label_; + int batch_size_; + string db_name_; + string db_type_; + float mean_; + float std_; + bool color_; + int scale_; + bool warp_; + int crop_; + bool mirror_; + INPUT_OUTPUT_STATS(0, 0, 2, 2); + DISABLE_COPY_AND_ASSIGN(ImageInputOp); +}; + +template +ImageInputOp::ImageInputOp( + const OperatorDef& operator_def, Workspace* ws) + : PrefetchOperator(operator_def, ws), + batch_size_( + OperatorBase::template GetSingleArgument("batch_size", 0)), + db_name_( + OperatorBase::template GetSingleArgument("db", "")), + db_type_(OperatorBase::template GetSingleArgument( + "db_type", "leveldb")), + mean_(OperatorBase::template GetSingleArgument("mean", 0.)), + std_(OperatorBase::template GetSingleArgument("std", 1.)), + color_(OperatorBase::template GetSingleArgument("color", 1)), + scale_(OperatorBase::template GetSingleArgument("scale", -1)), + warp_(OperatorBase::template GetSingleArgument("warp", 0)), + crop_(OperatorBase::template 
GetSingleArgument("crop", -1)), + mirror_(OperatorBase::template GetSingleArgument("mirror", 0)) { + CHECK_GT(batch_size_, 0) << "Batch size should be nonnegative."; + CHECK_GT(db_name_.size(), 0) << "Must provide a leveldb name."; + CHECK_GT(scale_, 0) << "Must provide the scaling factor."; + CHECK_GT(crop_, 0) << "Must provide the cropping value."; + CHECK_GE(scale_, crop_) + << "The scale value must be no smaller than the crop value."; + + DLOG(INFO) << "Creating an image input op with the following setting: "; + DLOG(INFO) << " Outputting in batches of " << batch_size_ << " images;"; + DLOG(INFO) << " Treating input image as " + << (color_ ? "color " : "grayscale ") << "image;"; + DLOG(INFO) << " Scaling image to " << scale_ + << (warp_ ? " with " : " without ") << "warping;"; + DLOG(INFO) << " Cropping image to " << crop_ + << (mirror_ ? " with " : " without ") << "random mirroring;"; + DLOG(INFO) << " Subtract mean " << mean_ << " and divide by std " << std_ + << "."; + db_.reset(db::CreateDB(db_type_, db_name_, db::READ)); + cursor_.reset(db_->NewCursor()); + cursor_->SeekToFirst(); + prefetched_image_.Reshape( + vector{batch_size_, crop_, crop_, (color_ ? 3 : 1)}); + prefetched_label_.Reshape(vector(1, batch_size_)); +} + +template +bool ImageInputOp::Prefetch() { + std::bernoulli_distribution mirror_this_image(0.5); + float* image_data = prefetched_image_.mutable_data(); + int channels = color_ ? 3 : 1; + for (int item_id = 0; item_id < batch_size_; ++item_id) { + // LOG(INFO) << "Prefetching item " << item_id; + // process data + TensorProtos protos; + CHECK(protos.ParseFromString(cursor_->value())) << cursor_->value(); + const TensorProto& image = protos.protos(0); + const TensorProto& label = protos.protos(1); + cv::Mat final_img; + if (image.data_type() == TensorProto::STRING) { + // Do the image manipuiation, and copy the content. + DCHECK_EQ(image.string_data_size(), 1); + + const string& encoded_image = image.string_data(0); + int encoded_size = encoded_image.size(); + cv::Mat img = cv::imdecode( + cv::Mat(1, &encoded_size, CV_8UC1, + const_cast(encoded_image.data())), + color_ ? CV_LOAD_IMAGE_COLOR : CV_LOAD_IMAGE_GRAYSCALE); + // Do resizing. + int scaled_width, scaled_height; + if (warp_) { + scaled_width = scale_; + scaled_height = scale_; + } else if (img.rows > img.cols) { + scaled_width = scale_; + scaled_height = static_cast(img.rows) * scale_ / img.cols; + } else { + scaled_height = scale_; + scaled_width = static_cast(img.cols) * scale_ / img.rows; + } + cv::resize(img, final_img, cv::Size(scaled_width, scaled_height), 0, 0, + cv::INTER_LINEAR); + } else if (image.data_type() == TensorProto::BYTE) { + // In this case, we will always just take the bytes as the raw image. + CHECK_EQ(image.dims_size(), (color_ ? 3 : 2)); + CHECK_GE(image.dims(0), crop_) + << "Image height must be bigger than crop."; + CHECK_GE(image.dims(1), crop_) << "Image width must be bigger than crop."; + CHECK(!color_ || image.dims(2) == 3); + final_img = cv::Mat( + image.dims(0), image.dims(1), color_ ? CV_8UC3 : CV_8UC1, + const_cast(image.byte_data().data())); + } + // find the cropped region, and copy it to the destination matrix with + // mean subtraction and scaling. 
+ int width_offset = + std::uniform_int_distribution<>(0, final_img.cols - crop_)( + cpu_context_.RandGenerator()); + int height_offset = + std::uniform_int_distribution<>(0, final_img.rows - crop_)( + cpu_context_.RandGenerator()); + // DVLOG(1) << "offset: " << height_offset << ", " << width_offset; + if (mirror_ && mirror_this_image(cpu_context_.RandGenerator())) { + // Copy mirrored image. + for (int h = height_offset; h < height_offset + crop_; ++h) { + for (int w = width_offset + crop_ - 1; w >= width_offset; --w) { + const cv::Vec3b& cv_data = final_img.at(h, w); + for (int c = 0; c < channels; ++c) { + *(image_data++) = + (static_cast(cv_data[c]) - mean_) / std_; + } + } + } + } else { + // Copy normally. + for (int h = height_offset; h < height_offset + crop_; ++h) { + for (int w = width_offset; w < width_offset + crop_; ++w) { + const cv::Vec3b& cv_data = final_img.at(h, w); + for (int c = 0; c < channels; ++c) { + *(image_data++) = + (static_cast(cv_data[c]) - mean_) / std_; + } + } + } + } + // Copy the label + DCHECK_EQ(label.data_type(), TensorProto::INT32); + DCHECK_EQ(label.int32_data_size(), 1); + prefetched_label_.mutable_data()[item_id] = label.int32_data(0); + // Advance to the next item. + cursor_->Next(); + if (!cursor_->Valid()) { + cursor_->SeekToFirst(); + } + } + return true; +} + +template +bool ImageInputOp::CopyPrefetched() { + // The first output is the image data. + auto* image_output = OperatorBase::Output >(0); + image_output->ReshapeLike(prefetched_image_); + this->device_context_.template Copy( + image_output->mutable_data(), prefetched_image_.data(), + prefetched_image_.size()); + // The second output is the label. + auto* label_output = OperatorBase::Output >(1); + label_output->ReshapeLike(prefetched_label_); + this->device_context_.template Copy( + label_output->mutable_data(), prefetched_label_.data(), + prefetched_label_.size()); + return true; +} + +} // namespace caffe2 + +#endif // CAFFE2_IMAGE_IMAGE_INPUT_OP_H_ + diff --git a/caffe2/image/image_input_op_gpu.cc b/caffe2/image/image_input_op_gpu.cc new file mode 100644 index 00000000000..c69889c3f81 --- /dev/null +++ b/caffe2/image/image_input_op_gpu.cc @@ -0,0 +1,9 @@ +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/image/image_input_op.h" + +namespace caffe2 { + +REGISTER_CUDA_OPERATOR(ImageInput, ImageInputOp); + +} // namespace caffe2 diff --git a/caffe2/mpi/BREW b/caffe2/mpi/BREW new file mode 100644 index 00000000000..4b839f5f011 --- /dev/null +++ b/caffe2/mpi/BREW @@ -0,0 +1,19 @@ +cc_headers( + name = "mpi_common", + srcs = [ + "mpi_common.h", + ], +) + +cc_library( + name = "mpi_ops", + srcs = [ + "allreduce_op.cc" + ], + deps = [ + ":mpi_common", + "//caffe2/core:core", + ], + external_libs = Env.MPI_LIBS, + whole_archive = True, +) \ No newline at end of file diff --git a/caffe2/mpi/allreduce_op.cc b/caffe2/mpi/allreduce_op.cc new file mode 100644 index 00000000000..c9a0411ca68 --- /dev/null +++ b/caffe2/mpi/allreduce_op.cc @@ -0,0 +1,37 @@ +#include + +#include "caffe2/core/operator.h" +#include "caffe2/mpi/mpi_common.h" + +namespace caffe2 { + +// AllreduceOp does Allreduce using MPI. Currently, only SUM is supported. 
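+// As a worked example (a sketch; assumes MPI has been initialized and that
+// every rank runs the same net), with four ranks each holding a length-3
+// blob X:
+//
+//   rank 0: X = [1, 1, 1]    rank 1: X = [2, 2, 2]
+//   rank 2: X = [3, 3, 3]    rank 3: X = [4, 4, 4]
+//
+// running  Allreduce(inputs: "X", outputs: "X_sum")  leaves every rank with
+// X_sum = [10, 10, 10], the elementwise MPI_SUM over all ranks.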
+template +class AllreduceOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(AllreduceOp); + + bool RunOnDevice() { + auto& input = Input(0); + auto* output = Output(0); + output->ReshapeLike(input); + MPI_Allreduce(const_cast(input.data()), + output->mutable_data(), input.size(), + MPIDataTypeWrapper::type(), MPI_SUM, MPI_COMM_WORLD); + return true; + } + + protected: + // Input: X; Output: X_reduced. + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(AllreduceOp); +}; + +namespace { +REGISTER_CPU_OPERATOR(Allreduce, AllreduceOp); +// Note: Allreduce does not work on CUDA devices as of OpenMPI 1.8.4 yet. In the +// future we can simply initialize it here. +} + +} // namespace caffe2 diff --git a/caffe2/mpi/mpi_common.h b/caffe2/mpi/mpi_common.h new file mode 100644 index 00000000000..7ef9898b5bc --- /dev/null +++ b/caffe2/mpi/mpi_common.h @@ -0,0 +1,26 @@ +#ifndef CAFFE2_MPI_MPI_COMMON_H_ +#define CAFFE2_MPI_MPI_COMMON_H_ + +namespace caffe2 { + +inline void CheckInitializedMPI() { + int flag; + MPI_Initialized(&flag); + CHECK(flag) << "MPI does not seem to have been initialized."; +} + +template class MPIDataTypeWrapper; + +#define MPI_DATATYPE_WRAPPER(c_type, mpi_type) \ + template<> class MPIDataTypeWrapper { \ + public: \ + inline static MPI_Datatype type() { return mpi_type; } \ + }; + +MPI_DATATYPE_WRAPPER(float, MPI_FLOAT) +MPI_DATATYPE_WRAPPER(double, MPI_DOUBLE) +// Note(Yangqing): as necessary, add more specializations. + +} // namespace caffe2 + +#endif // CAFFE2_MPI_MPI_COMMON_H_ diff --git a/caffe2/operators/BREW b/caffe2/operators/BREW new file mode 100644 index 00000000000..380a4614cb4 --- /dev/null +++ b/caffe2/operators/BREW @@ -0,0 +1,98 @@ +cc_headers( + name = "operators_headers", + srcs = Glob(["*.h"]), +) + +cc_library( + name = "core_ops", + srcs = [ + "accumulate_op.cc", + "accuracy_op.cc", + "averagepool_op.cc", + "conv_op.cc", + "cross_entropy_op.cc", + "depth_split_op.cc", + "dropout_op.cc", + "elementwise_op.cc", + "filler_op.cc", + "fully_connected_op.cc", + "l2_distance_op.cc", + "load_save_op.cc", + "local_response_normalization_op.cc", + "loss_op.cc", + "maxpool_op.cc", + "order_switch_ops.cc", + "relu_op.cc", + "softmax_op.cc", + "summarize_op.cc", + "tensor_protos_db_input.cc", + "utility_ops.cc", + ], + deps = [ + ":operators_headers", + "//caffe2/core:core", + "//caffe2/utils:math", + "//caffe2/utils:proto_utils", + ], + whole_archive = True, +) + +cuda_library( + name = "core_ops_gpu", + srcs = [ + "accumulate_op.cu", + "accuracy_op.cu", + "averagepool_op.cu", + "conv_op.cu", + "cross_entropy_op.cu", + "depth_split_op.cu", + "dropout_op.cu", + "elementwise_op_gpu.cc", + "filler_op.cu", + "fully_connected_op_gpu.cc", + "l2_distance_op.cu", + "load_save_op.cu", + "local_response_normalization_op.cu", + "loss_op_gpu.cc", + "maxpool_op.cu", + "order_switch_ops.cu", + "relu_op.cu", + "softmax_op.cu", + "summarize_op.cu", + "tensor_protos_db_input_gpu.cc", + "utility_ops_gpu.cc", + ], + deps = [ + ":operators_headers", + "//caffe2/core:core_gpu", + "//caffe2/utils:math_gpu", + "//caffe2/utils:proto_utils", + ], + whole_archive = True, +) + +cc_library( + name = "core_ops_cudnn", + srcs = [ + "softmax_op_cudnn.cc", + ], + deps = [ + ":operators_headers", + "//caffe2/core:core_cudnn", + "//caffe2/core:core_gpu", + "//caffe2/utils:math_gpu", + "//third_party/cudnn:cudnn", + ], + whole_archive = True, +) + +cc_test( + name = "core_ops_test", + srcs = Glob(["*_test.cc"]), + deps = [ + ":core_ops", + 
":core_ops_gpu", + ":core_ops_cudnn", + "//gtest:gtest_main", + ] +) diff --git a/caffe2/operators/accumulate_op.cc b/caffe2/operators/accumulate_op.cc new file mode 100644 index 00000000000..0b4cca52843 --- /dev/null +++ b/caffe2/operators/accumulate_op.cc @@ -0,0 +1,7 @@ +#include "caffe2/operators/accumulate_op.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(Accumulate, AccumulateOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/accumulate_op.cu b/caffe2/operators/accumulate_op.cu new file mode 100644 index 00000000000..d30fd7437ab --- /dev/null +++ b/caffe2/operators/accumulate_op.cu @@ -0,0 +1,8 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/accumulate_op.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(Accumulate, AccumulateOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/accumulate_op.h b/caffe2/operators/accumulate_op.h new file mode 100644 index 00000000000..9e1816788e0 --- /dev/null +++ b/caffe2/operators/accumulate_op.h @@ -0,0 +1,50 @@ +#ifndef CAFFE2_OPERATORS_ACCUMULATE_OP_H_ +#define CAFFE2_OPERATORS_ACCUMULATE_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +// Accumulate operator accumulates the input tensor to the output tensor. If the +// output tensor already has the right size, we add to it; otherwise, we first +// initialize the output tensor to all zeros, and then do accumulation. Any +// further calls to the operator, given that no one else fiddles with the output +// in the interim, will do simple accumulations. +template +class AccumulateOp final : public Operator { + public: + AccumulateOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + kOne(static_cast(1), &device_context_), + gamma_(static_cast( + OperatorBase::template GetSingleArgument("gamma", 1.0)), + &device_context_) {} + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + auto& input = Input(0); + auto* output = Output(0); + if (output->dims() != input.dims()) { + LOG(INFO) << "Reshaping and initializing output."; + output->ReshapeLike(input); + math::Set( + output->size(), 0, output->mutable_data(), &device_context_); + } + math::Axpby( + input.size(), kOne.data(), input.data(), gamma_.data(), + output->mutable_data(), &device_context_); + return true; + } + + protected: + Tensor kOne; + Tensor gamma_; + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(AccumulateOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ACCUMULATE_OP_H_ diff --git a/caffe2/operators/accuracy_op.cc b/caffe2/operators/accuracy_op.cc new file mode 100644 index 00000000000..ad199afb7b4 --- /dev/null +++ b/caffe2/operators/accuracy_op.cc @@ -0,0 +1,40 @@ +#include "caffe2/operators/accuracy_op.h" + +namespace caffe2 { + +template <> +bool AccuracyOp::RunOnDevice() { + auto& X = Input(PREDICTION); + auto& label = OperatorBase::Input >(LABEL); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + Y->Reshape(std::vector{1}); + const auto* Xdata = X.data(); + const auto* labeldata = label.data(); + int correct = 0; + for (int i = 0; i < N; ++i) { + float maxval = std::numeric_limits::lowest(); + int maxid = 0; + for (int j = 0; j < D; ++j) { + if (Xdata[i * D + j] > maxval) { + maxval = Xdata[i * D + j]; + maxid = j; + } + } + if (maxid == labeldata[i]) { + ++correct; + } + } + DCHECK_LE(correct, 
N); + Y->mutable_data()[0] = static_cast(correct) / N; + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(Accuracy, AccuracyOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/accuracy_op.cu b/caffe2/operators/accuracy_op.cu new file mode 100644 index 00000000000..9d0961eb6fc --- /dev/null +++ b/caffe2/operators/accuracy_op.cu @@ -0,0 +1,56 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/accuracy_op.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +namespace { +__global__ void AccuracyKernel(const int N, const int D, const float* Xdata, + const int* labeldata, float* accuracy) { + int count = 0; + CUDA_1D_KERNEL_LOOP(i, N) { + float maxval = Xdata[i * D]; + int maxid = 0; + for (int j = 1; j < D; ++j) { + if (Xdata[i * D + j] > maxval) { + maxval = Xdata[i * D + j]; + maxid = j; + } + } + if (maxid == labeldata[i]) { + ++count; + } + } + atomicAdd(accuracy, static_cast(count)); +} +__global__ void AccuracyDivideKernel(const int N, float* accuracy) { + *accuracy /= N; +} +} // namespace + +template <> +bool AccuracyOp::RunOnDevice() { + auto& X = Input(PREDICTION); + auto& label = OperatorBase::Input >(LABEL); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + Y->Reshape(std::vector(1, 1)); + math::Set(1, 0, Y->mutable_data(), &device_context_); + AccuracyKernel<<>>( + N, D, X.data(), label.data(), Y->mutable_data()); + // This is going to be executed only in one single kernel. Not very beautiful, + // but probably we have to do this? + AccuracyDivideKernel<<<1, 1, 0, device_context_.cuda_stream()>>>( + N, Y->mutable_data()); + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(Accuracy, AccuracyOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/accuracy_op.h b/caffe2/operators/accuracy_op.h new file mode 100644 index 00000000000..5bbae563fb5 --- /dev/null +++ b/caffe2/operators/accuracy_op.h @@ -0,0 +1,24 @@ +#ifndef CAFFE2_OPERATORS_ACCURACY_OP_H_ +#define CAFFE2_OPERATORS_ACCURACY_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class AccuracyOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(AccuracyOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + INPUT_OUTPUT_STATS(2, 2, 1, 1); + INPUT_TAGS(PREDICTION, LABEL); + DISABLE_COPY_AND_ASSIGN(AccuracyOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ACCURACY_OP_H_ diff --git a/caffe2/operators/averagepool_op.cc b/caffe2/operators/averagepool_op.cc new file mode 100644 index 00000000000..dd58d70bf13 --- /dev/null +++ b/caffe2/operators/averagepool_op.cc @@ -0,0 +1,194 @@ +#include "caffe2/operators/averagepool_op.h" + +namespace caffe2 { + +using std::max; +using std::min; + +template <> +bool AveragePoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(1)); + + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + math::Set( + Y->size(), 0, Ydata, &device_context_); + // The main loop + int channels = X.dim(1); + int height = X.dim(2); + int width = X.dim(3); + int pooled_height = Y->dim(2); + int pooled_width = Y->dim(3); + for (int n = 0; n < X.dim(0); ++n) { + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * 
stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = ph * pooled_width + pw; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = h * width + w; + Ydata[pool_index] += Xdata[input_index]; + } + } + Ydata[pool_index] /= (hend - hstart) * (wend - wstart); + } + } + // Do offset. + Xdata += height * width; + Ydata += pooled_height * pooled_width; + } + } + return true; +} + +template <> +bool AveragePoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + int height = X.dim(1); + int width = X.dim(2); + int channels = X.dim(3); + ConvPoolOpBase::SetOutputSize(X, Y, channels); + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + math::Set(Y->size(), 0, Ydata, &device_context_); + // The main loop + int pooled_height = Y->dim(1); + int pooled_width = Y->dim(2); + for (int n = 0; n < X.dim(0); ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = (ph * pooled_width + pw) * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = (h * width + w) * channels; + for (int c = 0; c < channels; ++c) { + Ydata[pool_index + c] += Xdata[input_index + c]; + } + } + } + float scale = 1. / (hend - hstart) / (wend - wstart); + for (int c = 0; c < channels; ++c) { + Ydata[pool_index + c] *= scale; + } + } + } + // Do offset. + Xdata += X.size() / X.dim(0); + Ydata += Y->size() / Y->dim(0); + } + return true; +} + +template <> +bool AveragePoolGradientOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + // TODO(Yangqing): Add shape checks. + dX->ReshapeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &device_context_); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + int channels = X.dim(1); + CHECK_EQ(channels, dY.dim(1)); + int height = X.dim(2); + int width = X.dim(3); + ConvPoolOpBase::ComputePads(height, width); + int pooled_height = dY.dim(2); + int pooled_width = dY.dim(3); + // The main loop + for (int n = 0; n < X.dim(0); ++n) { + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + float scale = 1. / (hend - hstart) / (wend - wstart); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + dXdata[h * width + w] += + dYdata[ph * pooled_width + pw] * scale; + } + } + } + } + // offset + dXdata += height * width; + dYdata += pooled_height * pooled_width; + } + } + return true; +} + +template <> +bool AveragePoolGradientOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto& dY = Input(1); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + // TODO(Yangqing): Add shape checks. 
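// The gradient of average pooling spreads each output gradient uniformly over
// the window it was averaged from: every input position (h, w) covered by an
// output position (ph, pw) receives dY[ph, pw] / pool_size, which is what the
// scale factor in the loops below computes. As a small worked example,
// assuming a 2x2 kernel with stride 2 and no padding: the windows do not
// overlap, each output averaged exactly 4 inputs in the forward pass, and
// each of those 4 inputs gets dY[ph, pw] / 4 here.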
+ dX->ReshapeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &device_context_); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + // The main loop + int height = X.dim(1); + int width = X.dim(2); + ConvPoolOpBase::ComputePads(height, width); + int pooled_height = dY.dim(1); + int pooled_width = dY.dim(2); + int channels = X.dim(3); + CHECK_EQ(channels, dY.dim(3)); + for (int n = 0; n < X.dim(0); ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + float scale = 1. / (hend - hstart) / (wend - wstart); + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + for (int c = 0; c < channels; ++c) { + dXdata[(h * width + w) * channels + c] += + dYdata[(ph * pooled_width + pw) * channels + c] * scale; + } + } + } + } + } + // offset + dXdata += X.size() / X.dim(0); + dYdata += dY.size() / dY.dim(0); + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(AveragePool, AveragePoolOp) +REGISTER_CPU_OPERATOR(AveragePoolGradient, AveragePoolGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/averagepool_op.cu b/caffe2/operators/averagepool_op.cu new file mode 100644 index 00000000000..cba94992703 --- /dev/null +++ b/caffe2/operators/averagepool_op.cu @@ -0,0 +1,218 @@ +#include + +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/averagepool_op.h" + +namespace caffe2 { + +namespace { +template +__global__ void AveragePoolForwardNCHW( + const int nthreads, const dtype* bottom_data, + const int num, const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, dtype* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype output = 0; + bottom_data += n * channels * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = c * height * width + h * width + w; + output += bottom_data[idx]; + } + } + int pool_size = (hend - hstart) * (wend - wstart); + top_data[index] = output / pool_size; + } +} + +template +__global__ void AveragePoolForwardNHWC( + const int nthreads, const dtype* bottom_data, + const int num, const int height, const int width, + const int channels, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, dtype* top_data) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int c = index % channels; + int pw = (index / channels) % pooled_width; + int ph = (index / channels / pooled_width) % pooled_height; + int n = index / channels / pooled_width / pooled_height; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int 
wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype output = 0; + bottom_data += n * height * width * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + output += bottom_data[(h * width + w) * channels + c]; + } + } + int pool_size = (hend - hstart) * (wend - wstart); + top_data[index] = output / pool_size; + } +} + +template +__global__ void AvePoolBackwardNCHW(const int nthreads, + const dtype* const top_diff, const int num, const int channels, + const int height, const int width, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, + const int pad_l, dtype* const bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int w = index % width + pad_l; + const int h = (index / width) % height + pad_t; + const int c = (index / width / height) % channels; + const int n = index / width / height / channels; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + dtype gradient = 0; + const dtype* const top_diff_slice = + top_diff + (n * channels + c) * pooled_height * pooled_width; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + gradient += top_diff_slice[ph * pooled_width + pw] / pool_size; + } + } + bottom_diff[index] = gradient; + } +} + +template +__global__ void AvePoolBackwardNHWC(const int nthreads, + const dtype* const top_diff, const int num, const int height, + const int width, const int channels, const int pooled_height, + const int pooled_width, const int kernel_h, const int kernel_w, + const int stride_h, const int stride_w, const int pad_t, + const int pad_l, dtype* const bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + // find out the local index + // find out the local offset + const int c = index % channels; + const int w = index / channels % width + pad_l; + const int h = (index / channels / width) % height + pad_t; + const int n = index / channels / width / height; + const int phstart = (h < kernel_h) ? 0 : (h - kernel_h) / stride_h + 1; + const int phend = min(h / stride_h + 1, pooled_height); + const int pwstart = (w < kernel_w) ? 
0 : (w - kernel_w) / stride_w + 1; + const int pwend = min(w / stride_w + 1, pooled_width); + dtype gradient = 0; + const dtype* const top_diff_slice = + top_diff + n * pooled_height * pooled_width * channels + c; + for (int ph = phstart; ph < phend; ++ph) { + for (int pw = pwstart; pw < pwend; ++pw) { + // figure out the pooling size + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + int pool_size = (hend - hstart) * (wend - wstart); + gradient += + top_diff_slice[(ph * pooled_width + pw) * channels] / pool_size; + } + } + bottom_diff[index] = gradient; + } +} + +} // namespace + +template <> +bool AveragePoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(1)); + int output_size = Y->size(); + AveragePoolForwardNCHW<<>>( + output_size, X.data(), X.dim(0), X.dim(1), X.dim(2), X.dim(3), + Y->dim(2), Y->dim(3), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, Y->mutable_data()); + return true; +} + +template <> +bool AveragePoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(3)); + int output_size = Y->size(); + AveragePoolForwardNHWC<<>>( + output_size, X.data(), X.dim(0), X.dim(1), X.dim(2), X.dim(3), + Y->dim(1), Y->dim(2), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, Y->mutable_data()); + return true; +} + +template <> +bool AveragePoolGradientOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto& dY = Input(1); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + dX->ReshapeLike(X); + ConvPoolOpBase::ComputePads(X.dim(2), X.dim(3)); + AvePoolBackwardNCHW<<>>( + X.size(), dY.data(), X.dim(0), X.dim(1), X.dim(2), X.dim(3), + dY.dim(2), dY.dim(3), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, dX->mutable_data()); + return true; +} + +template <> +bool AveragePoolGradientOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto& dY = Input(1); + CHECK_EQ(dY.ndim(), 4); + auto* dX = Output(0); + dX->ReshapeLike(X); + ConvPoolOpBase::ComputePads(X.dim(1), X.dim(2)); + AvePoolBackwardNHWC<<>>( + X.size(), dY.data(), X.dim(0), X.dim(1), X.dim(2), X.dim(3), + dY.dim(1), dY.dim(2), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, dX->mutable_data()); + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(AveragePool, AveragePoolOp) +REGISTER_CUDA_OPERATOR(AveragePoolGradient, AveragePoolGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/averagepool_op.h b/caffe2/operators/averagepool_op.h new file mode 100644 index 00000000000..7fdb6aff956 --- /dev/null +++ b/caffe2/operators/averagepool_op.h @@ -0,0 +1,50 @@ +#ifndef CAFFE2_OPERATORS_AVERAGEPOOL_OP_H_ +#define CAFFE2_OPERATORS_AVERAGEPOOL_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/operators/conv_pool_op_base.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class AveragePoolOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + AveragePoolOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws) {} + ~AveragePoolOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + // Input: X + // Output: Y + INPUT_OUTPUT_STATS(1, 1, 1, 1); + 
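// Judging from its uses throughout this commit, INPUT_OUTPUT_STATS(a, b, c, d)
// appears to declare the allowed arity as (min inputs, max inputs,
// min outputs, max outputs): (1, 1, 1, 1) here means exactly one input X and
// one output Y, while the gradient op below uses (2, 2, 1, 1) because it
// takes X and dY and produces a single dX.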
DISABLE_COPY_AND_ASSIGN(AveragePoolOp); +}; + +template +class AveragePoolGradientOp final : + public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + AveragePoolGradientOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws) {} + ~AveragePoolGradientOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + // Input: X, Y_grad + // Output: X_grad + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(AveragePoolGradientOp); +}; + + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_AVERAGEPOOL_OP_H_ diff --git a/caffe2/operators/conv_op.cc b/caffe2/operators/conv_op.cc new file mode 100644 index 00000000000..aab80f08bc7 --- /dev/null +++ b/caffe2/operators/conv_op.cc @@ -0,0 +1,10 @@ +#include "caffe2/operators/conv_op.h" +#include "caffe2/operators/conv_op_impl.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(Conv, ConvOp) +REGISTER_CPU_OPERATOR(ConvGradient, ConvGradientOp) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/conv_op.cu b/caffe2/operators/conv_op.cu new file mode 100644 index 00000000000..c4b22c1c7db --- /dev/null +++ b/caffe2/operators/conv_op.cu @@ -0,0 +1,10 @@ +#include "caffe2/operators/conv_op.h" +#include "caffe2/operators/conv_op_impl.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(Conv, ConvOp) +REGISTER_CUDA_OPERATOR(ConvGradient, ConvGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/conv_op.h b/caffe2/operators/conv_op.h new file mode 100644 index 00000000000..3636d48535b --- /dev/null +++ b/caffe2/operators/conv_op.h @@ -0,0 +1,61 @@ +#ifndef CAFFE2_OPERATORS_CONV_OP_H_ +#define CAFFE2_OPERATORS_CONV_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/operators/conv_pool_op_base.h" + +namespace caffe2 { + +template +class ConvOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + ConvOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws), + kOne(1, &device_context_), kZero(0, &device_context_) {} + ~ConvOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + private: + Tensor col_buffer_; + Tensor bias_multiplier_; + Tensor kOne; + Tensor kZero; + // Input: X, W, b + // Output: Y + INPUT_TAGS(INPUT, FILTER, BIAS); + INPUT_OUTPUT_STATS(3, 3, 1, 1); + DISABLE_COPY_AND_ASSIGN(ConvOp); +}; + +template +class ConvGradientOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + ConvGradientOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws), + kOne(1, &device_context_), kZero(0, &device_context_) {} + ~ConvGradientOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + private: + Tensor col_buffer_; + Tensor bias_multiplier_; + Tensor kOne; + Tensor kZero; + // input: X, W, b, dY + // output: dW, db, and optionally dX + INPUT_TAGS(INPUT, FILTER, BIAS, OUTPUT_GRAD); + OUTPUT_TAGS(FILTER_GRAD, BIAS_GRAD, INPUT_GRAD); + INPUT_OUTPUT_STATS(4, 4, 2, 3); + DISABLE_COPY_AND_ASSIGN(ConvGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_CONV_OP_H_ diff --git a/caffe2/operators/conv_op_cudnn.cu.working b/caffe2/operators/conv_op_cudnn.cu.working new file mode 100644 index 00000000000..2d2482f736c --- /dev/null +++ b/caffe2/operators/conv_op_cudnn.cu.working @@ -0,0 +1,63 @@ +#include 
"caffe2/core/context_gpu.h" +#include "caffe2/operators/conv_pool_op_base.h" + +namespace caffe2 { + +template +class CudnnConvOp final : public ConvPoolOpBase { + public: + CudnnConvOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws), + kOne(1, &device_context_), kZero(0, &device_context_) {} + ~CudnnConvOp() {} + + bool ConfigureCudnnConvolution() { + CUDNN_CHECK(cudnnCreateFilterDescriptor(&filter_desc_)); + CUDNN_CHECK(cudnnSetFilter4dDescriptor( + filter_desc, GetCudnnTensorFormat(order_), )) + } + + bool RunOnDevice() override { + // TODO: Reshape + + for (int i) + } + + private: + cudnnTensorDescriptor_t bottom_desc_; + cudnnFilterDescriptor_t filter_desc_; + cudnnTensorDescriptor_t bias_desc_; + cudnnTensorDescriptor_t top_desc_; + cudnnConvolutionDescriptor_t conv_desc_; + // Input: X, W, b + // Output: Y + INPUT_OUTPUT_STATS(3, 3, 1, 1); + DISABLE_COPY_AND_ASSIGN(ConvOp); +}; + +/* +template +class ConvGradientOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + ConvGradientOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws), + kOne(1, &device_context_), kZero(0, &device_context_) {} + ~ConvGradientOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + private: + Tensor col_buffer_; + Tensor bias_multiplier_; + Tensor kOne; + Tensor kZero; + // input: X, W, b, dY + // output: dW, db, and optionally dX + INPUT_OUTPUT_STATS(4, 4, 2, 3); + DISABLE_COPY_AND_ASSIGN(ConvGradientOp); +}; +*/ + +} // namespace caffe2 diff --git a/caffe2/operators/conv_op_impl.h b/caffe2/operators/conv_op_impl.h new file mode 100644 index 00000000000..dac02416d7a --- /dev/null +++ b/caffe2/operators/conv_op_impl.h @@ -0,0 +1,336 @@ +// conv_op_impl.h is the templated implementation of the conv_op.h file. +#ifndef CAFFE2_OPERATORS_CONV_OP_IMPL_H_ +#define CAFFE2_OPERATORS_CONV_OP_IMPL_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/operators/conv_op.h" +#include "caffe2/operators/conv_pool_op_base.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +bool ConvOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(INPUT); + auto& filter = Input(FILTER); + auto& bias = Input(BIAS); + auto* Y = Output(0); + const int N = X.dim(0), C = X.dim(1), H = X.dim(2), W = X.dim(3); + DCHECK_EQ(filter.ndim(), 4); + const int M = filter.dim(0); + DCHECK_EQ(filter.dim(1), C); + DCHECK_EQ(filter.dim(2), kernel_h_); + DCHECK_EQ(filter.dim(3), kernel_w_); + DCHECK_EQ(bias.ndim(), 1); + DCHECK_EQ(bias.dim(0), M); + ConvPoolOpBase::SetOutputSize(X, Y, filter.dim(0)); + // The dimension of each kernel + const int kernel_dim = C * kernel_h_ * kernel_w_; + // The offset corresponding to a single input image, and a single output + // image. + const int input_offset = C * H * W; + const int output_offset = Y->size() / Y->dim(0); + // The output image size is the spatial size of the output. + const int output_image_size = Y->dim(2) * Y->dim(3); + // The col buffer is stored in CHW order as well - kernel_dim, and the height + // and width. + col_buffer_.Reshape(std::vector{ + C, kernel_h_, kernel_w_, Y->dim(2), Y->dim(3)}); + if (bias_multiplier_.size() != output_image_size) { + // If the helper bias multiplier is not M, reshape and fill it with one. 
+ bias_multiplier_.Reshape(std::vector(1, output_image_size)); + math::Set( + output_image_size, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + const dtype* Xdata = X.data(); + dtype* col_buffer_data = col_buffer_.mutable_data(); + dtype* Ydata = Y->mutable_data(); + // Im2col, followed by gemm. + for (int image_id = 0; image_id < N; ++image_id) { + math::Im2col( + Xdata, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, stride_h_, stride_w_, col_buffer_data, + &device_context_); + // Weight term + math::Gemm( + CblasNoTrans, CblasNoTrans, M, output_image_size, kernel_dim, + kOne.data(), filter.data(), col_buffer_data, kZero.data(), Ydata, + &device_context_); + // Bias term + math::Gemm( + CblasNoTrans, CblasNoTrans, M, output_image_size, 1, kOne.data(), + bias.data(), bias_multiplier_.data(), kOne.data(), Ydata, + &device_context_); + Xdata += input_offset; + Ydata += output_offset; + } + return true; +} + +// The implementations. +template +bool ConvOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(INPUT); + auto& filter = Input(FILTER); + auto& bias = Input(BIAS); + auto* Y = Output(0); + const int N = X.dim(0), H = X.dim(1), W = X.dim(2), C = X.dim(3); + DCHECK_EQ(filter.ndim(), 4); + const int M = filter.dim(0); + DCHECK_EQ(filter.dim(1), kernel_h_); + DCHECK_EQ(filter.dim(2), kernel_w_); + DCHECK_EQ(filter.dim(3), C); + DCHECK_EQ(bias.ndim(), 1); + DCHECK_EQ(bias.dim(0), M); + ConvPoolOpBase::SetOutputSize(X, Y, filter.dim(0)); + // The dimension of each kernel + const int kernel_dim = kernel_h_ * kernel_w_ * C; + // The offset corresponding to a single input image, and a single output + // image. + const int input_offset = H * W * C; + const int output_offset = Y->size() / Y->dim(0); + // The output image size is the spatial size of the output. + const int output_image_size = Y->dim(1) * Y->dim(2); + // The col buffer is stored in HWC order as well - kernel_dim, and the height + // and width. + const dtype* Xdata = X.data(); + dtype* Ydata = Y->mutable_data(); + if (bias_multiplier_.size() != output_image_size) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector(1, output_image_size)); + math::Set( + output_image_size, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + // Specialized path for 1 by 1 convolution + if (kernel_dim == C && Y->dim(1) == X.dim(1) && Y->dim(2) == X.dim(2)) { + if (bias_multiplier_.size() != N * H * W) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector(1, N * H * W)); + math::Set( + N * H * W, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + math::Gemm( + CblasNoTrans, CblasTrans, N * H * W, M, C, kOne.data(), Xdata, + filter.data(), kZero.data(), Ydata, &device_context_); + math::Gemm( + CblasNoTrans, CblasNoTrans, N * H * W, M, 1, kOne.data(), + bias_multiplier_.data(), bias.data(), kOne.data(), Ydata, + &device_context_); + } else { + if (bias_multiplier_.size() != output_image_size) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector(1, output_image_size)); + math::Set( + output_image_size, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + col_buffer_.Reshape(std::vector{ + Y->dim(1), Y->dim(2), kernel_h_, kernel_w_, C}); + dtype* col_buffer_data = col_buffer_.mutable_data(); + // Im2col, followed by gemm. 
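// Shape sketch for this NHWC path, assuming a hypothetical 5x5x3 input, a 3x3
// kernel, stride 1 and no padding: the output is 3x3 (output_image_size = 9)
// and kernel_dim = 3 * 3 * 3 = 27, so the col buffer is [9 x 27], the filter
// is [M x 27], and the (NoTrans, Trans) GEMM below yields the [9 x M] output
// slice for one image. The second GEMM is a rank-1 update that adds bias[m]
// to column m of every row, using the all-ones bias_multiplier_ vector
// instead of an explicit loop.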
+ for (int image_id = 0; image_id < N; ++image_id) { + math::Im2col( + Xdata, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, stride_h_, stride_w_, col_buffer_data, + &device_context_); + // Weight term + // Wait, is this right....? + math::Gemm( + CblasNoTrans, CblasTrans, output_image_size, M, kernel_dim, + kOne.data(), col_buffer_data, filter.data(), kZero.data(), Ydata, + &device_context_); + // Bias term + math::Gemm( + CblasNoTrans, CblasNoTrans, output_image_size, M, 1, kOne.data(), + bias_multiplier_.data(), bias.data(), kOne.data(), Ydata, + &device_context_); + Xdata += input_offset; + Ydata += output_offset; + } + } + return true; +} + +template +bool ConvGradientOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(INPUT); + auto& filter = Input(FILTER); + auto& bias = Input(BIAS); + auto& dY = Input(OUTPUT_GRAD); + auto* dfilter = Output(FILTER_GRAD); + auto* dbias = Output(BIAS_GRAD); + const int N = X.dim(0), C = X.dim(1), H = X.dim(2), W = X.dim(3); + ConvPoolOpBase::ComputePads(H, W); + DCHECK_EQ(filter.ndim(), 4); + const int M = filter.dim(0); + DCHECK_EQ(filter.dim(1), C); + DCHECK_EQ(filter.dim(2), kernel_h_); + DCHECK_EQ(filter.dim(3), kernel_w_); + DCHECK_EQ(bias.ndim(), 1); + DCHECK_EQ(bias.dim(0), M); + dfilter->ReshapeLike(filter); + dbias->ReshapeLike(bias); + // The dimension of each kernel + const int kernel_dim = C * kernel_h_ * kernel_w_; + // The offset corresponding to a single input image, and a single output + // image. + const int input_offset = C * H * W; + const int output_offset = dY.size() / dY.dim(0); + // The output image size is the spatial size of the output. + const int output_image_size = dY.dim(2) * dY.dim(3); + // The col buffer is stored in CHW order as well - kernel_dim, and the height + // and width. + col_buffer_.Reshape(std::vector{kernel_dim, output_image_size}); + if (bias_multiplier_.size() != output_image_size) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector(1, output_image_size)); + math::Set( + output_image_size, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + const dtype* Xdata = X.data(); + const dtype* filter_data = filter.data(); + const dtype* dYdata = dY.data(); + dtype* col_buffer_data = col_buffer_.mutable_data(); + dtype* dfilter_data = dfilter->mutable_data(); + dtype* dbias_data = dbias->mutable_data(); + // Pre-setting the gradients to zero. + math::Set(dfilter->size(), 0, dfilter_data, + &device_context_); + math::Set(dbias->size(), 0, dbias_data, + &device_context_); + for (int image_id = 0; image_id < N; ++image_id) { + // When we compute the gradient with respect to the filters, we need to do + // im2col to allow gemm-type computation. + math::Im2col( + Xdata, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, stride_h_, stride_w_, col_buffer_data, + &device_context_); + // Gradient with respect to filter. + math::Gemm( + CblasNoTrans, CblasTrans, M, kernel_dim, output_image_size, + kOne.data(), dYdata + output_offset * image_id, col_buffer_data, + kOne.data(), dfilter_data, &device_context_); + // Gradient with respect to bias + math::Gemv( + CblasNoTrans, M, output_image_size, kOne.data(), + dYdata + output_offset * image_id, bias_multiplier_.data(), + kOne.data(), dbias_data, &device_context_); + Xdata += input_offset; + } + if (OutputSize() == 3) { + // Compute the gradient w.r.t. the input. 
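// The input gradient is the transpose of the forward GEMM followed by col2im:
// for each image, filter^T ([kernel_dim x M]) times the output gradient slice
// ([M x output_image_size]) rebuilds the col buffer
// ([kernel_dim x output_image_size]), and Col2im then accumulates the
// overlapping patch columns back into the [C x H x W] input gradient.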
+ auto *dX = Output(INPUT_GRAD); + dX->ReshapeLike(X); + dtype* dXdata = dX->mutable_data(); + for (int image_id = 0; image_id < N; ++image_id) { + // Compute gradient into col_buffer. + math::Gemm( + CblasTrans, CblasNoTrans, kernel_dim, output_image_size, M, + kOne.data(), filter_data, dYdata + output_offset * image_id, + kZero.data(), col_buffer_data, &device_context_); + math::Col2im( + col_buffer_data, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, + stride_h_, stride_w_, dXdata, &device_context_); + dXdata += input_offset; + } + } + return true; +} + +template +bool ConvGradientOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(INPUT); + auto& filter = Input(FILTER); + auto& bias = Input(BIAS); + auto& dY = Input(OUTPUT_GRAD); + auto* dfilter = Output(FILTER_GRAD); + auto* dbias = Output(BIAS_GRAD); + const int N = X.dim(0), H = X.dim(1), W = X.dim(2), C = X.dim(3); + ConvPoolOpBase::ComputePads(H, W); + DCHECK_EQ(filter.ndim(), 4); + const int M = filter.dim(0); + DCHECK_EQ(filter.dim(1), kernel_h_); + DCHECK_EQ(filter.dim(2), kernel_w_); + DCHECK_EQ(filter.dim(3), C); + DCHECK_EQ(bias.ndim(), 1); + DCHECK_EQ(bias.dim(0), M); + dfilter->ReshapeLike(filter); + dbias->ReshapeLike(bias); + // The dimension of each kernel + const int kernel_dim = kernel_h_ * kernel_w_ * C; + // The offset corresponding to a single input image, and a single output + // image. + const int input_offset = H * W * C; + const int output_offset = dY.size() / dY.dim(0); + // The output image size is the spatial size of the output. + const int output_image_size = dY.dim(1) * dY.dim(2); + // The col buffer is stored in CHW order as well - kernel_dim, and the height + // and width. + col_buffer_.Reshape(std::vector{output_image_size, kernel_dim}); + if (bias_multiplier_.size() != output_image_size) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector(1, output_image_size)); + math::Set( + output_image_size, static_cast(1), + bias_multiplier_.mutable_data(), &device_context_); + } + const dtype* Xdata = X.data(); + const dtype* const filter_data = filter.data(); + const dtype* const dYdata = dY.data(); + dtype* col_buffer_data = col_buffer_.mutable_data(); + dtype* dfilter_data = dfilter->mutable_data(); + dtype* dbias_data = dbias->mutable_data(); + // Pre-setting the gradients to zero. + math::Set(dfilter->size(), 0, dfilter_data, + &device_context_); + math::Set(dbias->size(), 0, dbias_data, + &device_context_); + for (int image_id = 0; image_id < N; ++image_id) { + // When we compute the gradient with respect to the filters, we need to do + // im2col to allow gemm-type computation. + math::Im2col( + Xdata, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, stride_h_, stride_w_, col_buffer_data, + &device_context_); + // Gradient with respect to filter. + math::Gemm( + CblasTrans, CblasNoTrans, M, kernel_dim, output_image_size, + kOne.data(), dYdata + output_offset * image_id, col_buffer_data, + kOne.data(), dfilter_data, &device_context_); + // Gradient with respect to bias + math::Gemv( + CblasTrans, output_image_size, M, kOne.data(), + dYdata + output_offset * image_id, bias_multiplier_.data(), + kOne.data(), dbias_data, &device_context_); + Xdata += input_offset; + } + if (OutputSize() == 3) { + // Compute the gradient w.r.t. the input. 
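// Same idea as the NCHW branch above, but with the NHWC col buffer layout of
// [output_image_size x kernel_dim]: the (NoTrans, NoTrans) GEMM below
// multiplies the output gradient slice ([output_image_size x M]) by the
// filter ([M x kernel_dim]) to rebuild the col buffer, and Col2im scatters it
// back into the [H x W x C] input gradient.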
+ auto *dX = Output(INPUT_GRAD); + dX->ReshapeLike(X); + dtype* dXdata = dX->mutable_data(); + for (int image_id = 0; image_id < N; ++image_id) { + // Compute gradient into col_buffer. + math::Gemm( + CblasNoTrans, CblasNoTrans, output_image_size, kernel_dim, M, + kOne.data(), dYdata + output_offset * image_id, filter_data, + kZero.data(), col_buffer_data, &device_context_); + math::Col2im( + col_buffer_data, C, H, W, kernel_h_, kernel_w_, + pad_t_, pad_l_, pad_b_, pad_r_, + stride_h_, stride_w_, dXdata, &device_context_); + dXdata += input_offset; + } + } + return true; +} +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_CONV_OP_IMPL_H_ diff --git a/caffe2/operators/conv_pool_op_base.h b/caffe2/operators/conv_pool_op_base.h new file mode 100644 index 00000000000..964551304c7 --- /dev/null +++ b/caffe2/operators/conv_pool_op_base.h @@ -0,0 +1,222 @@ +#ifndef CAFFE2_OPERATORS_CONV_POOL_OP_BASE_H_ +#define CAFFE2_OPERATORS_CONV_POOL_OP_BASE_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/proto/caffe2_legacy.pb.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +// This macro is here just to allow us to experiment with padding values that +// determines, when we have an odd number of pads, which side gets the one +// additional pad value, the head side, or the tail side. Setting it to false +// will enable the distbelief behavior, and setting it to true will enable +// a behavior more consistent with Caffe and CuDNN. +const bool PAD_HEAD_MORE = false; + +namespace caffe2 { + +template +class ConvPoolOpBase : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + ConvPoolOpBase(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + legacy_pad_(static_cast( + OperatorBase::GetSingleArgument( + "legacy_pad", LegacyPadding::NOTSET))), + pad_(OperatorBase::GetSingleArgument("pad", 0)), + pad_t_(OperatorBase::GetSingleArgument("pad_t", 0)), + pad_l_(OperatorBase::GetSingleArgument("pad_l", 0)), + pad_b_(OperatorBase::GetSingleArgument("pad_b", 0)), + pad_r_(OperatorBase::GetSingleArgument("pad_r", 0)), + kernel_h_(OperatorBase::GetSingleArgument( + "kernel_h", OperatorBase::GetSingleArgument("kernel", 0))), + kernel_w_(OperatorBase::GetSingleArgument( + "kernel_w", OperatorBase::GetSingleArgument("kernel", 0))), + stride_h_(OperatorBase::GetSingleArgument( + "stride_h", OperatorBase::GetSingleArgument("stride", 1))), + stride_w_(OperatorBase::GetSingleArgument( + "stride_w", OperatorBase::GetSingleArgument("stride", 1))), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NHWC"))) { + CHECK_GT(kernel_h_, 0); + CHECK_GT(kernel_w_, 0); + // For the padding, they should either be the legacy padding strategy + // (VALID or SAME), or an explicit, non-negative value. + if (legacy_pad_ != LegacyPadding::NOTSET) { + CHECK(!OperatorBase::HasArgument("pad") && + !OperatorBase::HasArgument("pad_t") && + !OperatorBase::HasArgument("pad_l") && + !OperatorBase::HasArgument("pad_b") && + !OperatorBase::HasArgument("pad_r")) + << "If you use legacy padding, you should not specify any specific " + "padding values."; + } else if (OperatorBase::HasArgument("pad")) { + // if pad is set, it overrides the individual values. + pad_t_ = pad_; + pad_l_ = pad_; + pad_b_ = pad_; + pad_t_ = pad_; + } + CHECK_GE(pad_, 0); + CHECK_GE(pad_t_, 0); + CHECK_GE(pad_l_, 0); + CHECK_GE(pad_b_, 0); + CHECK_GE(pad_r_, 0); + CHECK_GT(stride_h_, 0); + CHECK_GT(stride_w_, 0); + } + + // Sets the output size. 
The output channel is manually provided since + // it may not be identical to the input channels. + // This function can be used in the forward functions to obtain the output + // sizes. + void SetOutputSize(const Tensor& input, + Tensor* output, + int output_channel) { + DCHECK_EQ(input.ndim(), 4); + DCHECK_GT(input.size(), 0); + int N = input.dim(0); + bool channel_first; + int C, H, W; + switch (order_) { + case StorageOrder::NHWC: + channel_first = false; + H = input.dim(1); + W = input.dim(2); + C = input.dim(3); + break; + case StorageOrder::NCHW: + // Old Caffe order. + channel_first = true; + C = input.dim(1); + H = input.dim(2); + W = input.dim(3); + break; + default: + LOG(FATAL) << "Unknown Storage order: " << order_; + } + CHECK_GE(H, kernel_h_); + CHECK_GE(W, kernel_w_); + int output_height, output_width; + ComputeSizeAndPad(H, stride_h_, kernel_h_, + &pad_t_, &pad_b_, &output_height); + ComputeSizeAndPad(W, stride_w_, kernel_w_, + &pad_l_, &pad_r_, &output_width); + if (channel_first) { + output->Reshape( + std::vector{N, output_channel, output_height, output_width}); + } else { + output->Reshape( + std::vector{N, output_height, output_width, output_channel}); + } + DVLOG(2) << "In: N " << N << " C " << C << " H " << H << " W " << W; + DVLOG(2) << "Out: C " << output_channel << " H " << output_height + << " W " << output_width; + } + + // ComputePads could be used in backward functions to figure out the padding + // values for the given input. + void ComputePads(const int height, const int width) { + if (legacy_pad_ != LegacyPadding::NOTSET) { + int output_unused; + ComputeSizeAndPad(height, stride_h_, kernel_h_, + &pad_t_, &pad_b_, &output_unused); + ComputeSizeAndPad(width, stride_w_, kernel_w_, + &pad_l_, &pad_r_, &output_unused); + } + } + + bool RunOnDevice() override { + switch (order_) { + case StorageOrder::NHWC: + DVLOG(2) << "Running NHWC"; + return RunOnDeviceWithOrderNHWC(); + case StorageOrder::NCHW: + DVLOG(2) << "Running NCHW"; + return RunOnDeviceWithOrderNCHW(); + default: + LOG(FATAL) << "Unknown storage order: " << order_; + } + // To suppress old compiler warnings + return true; + } + + // The actual function that does the computation, if the different + // storage order leads to different implementations. + virtual bool RunOnDeviceWithOrderNHWC() { NOT_IMPLEMENTED; return false; } + virtual bool RunOnDeviceWithOrderNCHW() { NOT_IMPLEMENTED; return false; } + + virtual ~ConvPoolOpBase() {} + + protected: + int pad_t_; + int pad_l_; + int pad_b_; + int pad_r_; + int kernel_h_; + int kernel_w_; + int stride_h_; + int stride_w_; + StorageOrder order_; + + inline void ComputeSizeAndPad( + const int in_size, const int stride, const int kernel, + int* pad_head, int* pad_tail, int* out_size) { + if (legacy_pad_ == LegacyPadding::NOTSET) { + // We will just use the direct padding head and tail values, but we + // will verify that they are non-negative. 
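// Worked example for the legacy (SAME/VALID) branch in the else case below,
// assuming SAME padding with in_size = 10, kernel = 3 and stride = 2:
// legacy_target_size = ceil(10 / 2) = 5 and
// pad_needed = (5 - 1) * 2 + 3 - 10 = 1; with PAD_HEAD_MORE == false the head
// gets pad_needed / 2 = 0 and the tail gets the remaining 1, so
// out_size = (10 + 1 - 3) / 2 + 1 = 5 as intended. VALID padding with the
// same shape and stride 1 gives legacy_target_size = 8 and pad_needed = 0.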
+ CHECK_GE(*pad_head, 0); + CHECK_GE(*pad_tail, 0); + *out_size = static_cast( + static_cast(in_size + *pad_head + *pad_tail - kernel) / stride + + 1); + } else { + int legacy_target_size; + switch (legacy_pad_) { + case LegacyPadding::VALID: + legacy_target_size = + std::ceil(static_cast(in_size - kernel + 1) / stride); + break; + case LegacyPadding::SAME: + legacy_target_size = std::ceil(static_cast(in_size) / stride); + break; + default: + LOG(FATAL) << "Unsupported raw pad value."; + } + int pad_needed = (legacy_target_size - 1) * stride + kernel - in_size; + // In legacy padding, if there is an odd padding value, we will need + // to pad more on the tail side. + if (PAD_HEAD_MORE) { + *pad_head = (pad_needed + 1) / 2; + } else { + *pad_head = pad_needed / 2; + } + *pad_tail = pad_needed - *pad_head; + *out_size = static_cast( + static_cast(in_size + pad_needed - kernel) / stride + 1); + } + } + + private: + LegacyPadding legacy_pad_; + int pad_; + DISABLE_COPY_AND_ASSIGN(ConvPoolOpBase); +}; + +#define USE_CONV_POOL_BASE_FUNCTIONS \ + USE_OPERATOR_BASE_FUNCTIONS; \ + using ConvPoolOpBase::pad_t_; \ + using ConvPoolOpBase::pad_l_; \ + using ConvPoolOpBase::pad_b_; \ + using ConvPoolOpBase::pad_r_; \ + using ConvPoolOpBase::kernel_h_; \ + using ConvPoolOpBase::kernel_w_; \ + using ConvPoolOpBase::stride_h_; \ + using ConvPoolOpBase::stride_w_; \ + using ConvPoolOpBase::order_ + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_CONV_POOL_OP_BASE_H_ diff --git a/caffe2/operators/cross_entropy_op.cc b/caffe2/operators/cross_entropy_op.cc new file mode 100644 index 00000000000..a7a16e33d80 --- /dev/null +++ b/caffe2/operators/cross_entropy_op.cc @@ -0,0 +1,58 @@ +#include "caffe2/operators/cross_entropy_op.h" + +namespace caffe2 { + +template <> +bool LabelCrossEntropyOp::RunOnDevice() { + auto& X = Input(0); + auto& label = OperatorBase::Input >(1); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + Y->Reshape(std::vector{N}); + const auto* Xdata = X.data(); + const auto* labeldata = label.data(); + auto* Ydata = Y->mutable_data(); + for (int i = 0; i < N; ++i) { + DCHECK_LT(labeldata[i], D); + Ydata[i] = -log(std::max(Xdata[i * D + labeldata[i]], kLOG_THRESHOLD())); + } + return true; +} + +template <> +bool LabelCrossEntropyGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& label = OperatorBase::Input >(1); + auto& dY = Input(2); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + DCHECK_EQ(dY.ndim(), 1); + DCHECK_EQ(dY.dim(0), N); + dX->ReshapeLike(X); + math::Set(dX->size(), 0.f, dX->mutable_data(), + &device_context_); + const float* Xdata = X.data(); + const float* dYdata = dY.data(); + const int* labeldata = label.data(); + float* dXdata = dX->mutable_data(); + for (int i = 0; i < N; ++i) { + DCHECK_LT(labeldata[i], D); + dXdata[i * D + labeldata[i]] = + - dYdata[i] / std::max(Xdata[i * D + labeldata[i]], kLOG_THRESHOLD()); + } + return true; +} + +REGISTER_CPU_OPERATOR(LabelCrossEntropy, + LabelCrossEntropyOp) +REGISTER_CPU_OPERATOR(LabelCrossEntropyGradient, + LabelCrossEntropyGradientOp) +} // namespace caffe2 diff --git a/caffe2/operators/cross_entropy_op.cu b/caffe2/operators/cross_entropy_op.cu new file mode 100644 index 00000000000..597be4dac9c --- /dev/null +++ b/caffe2/operators/cross_entropy_op.cu @@ -0,0 +1,70 @@ +#include "caffe2/core/context_gpu.h" +#include 
"caffe2/operators/cross_entropy_op.h" + +namespace caffe2 { + +namespace { +__global__ void LabelCrossEntropyKernel( + const int N, const int D, const float* Xdata, const int* labeldata, + const float log_threshold, float* Ydata) { + CUDA_1D_KERNEL_LOOP(i, N) { + Ydata[i] = -logf(max(Xdata[i * D + labeldata[i]], log_threshold)); + } +} +__global__ void LabelCrossEntropyGradientKernel( + const int N, const int D, const float* Xdata, const int* labeldata, + const float* dYdata, const float log_threshold, float* dXdata) { + CUDA_1D_KERNEL_LOOP(i, N) { + int idx = i * D + labeldata[i]; + dXdata[idx] = - dYdata[i] / max(Xdata[idx], log_threshold); + } +} +} // namespace + +template <> +bool LabelCrossEntropyOp::RunOnDevice() { + auto& X = Input(0); + auto& label = OperatorBase::Input >(1); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + Y->Reshape(std::vector(1, N)); + LabelCrossEntropyKernel<<>>( + N, D, X.data(), label.data(), kLOG_THRESHOLD(), Y->mutable_data()); + return true; +} + +template <> +bool LabelCrossEntropyGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& label = OperatorBase::Input >(1); + auto& dY = Input(2); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(label.ndim(), 1); + DCHECK_EQ(label.dim(0), N); + DCHECK_EQ(dY.ndim(), 1); + DCHECK_EQ(dY.dim(0), N); + dX->ReshapeLike(X); + math::Set( + dX->size(), 0.f, dX->mutable_data(), &device_context_); + LabelCrossEntropyGradientKernel<<>>( + N, D, X.data(), label.data(), dY.data(), kLOG_THRESHOLD(), + dX->mutable_data()); + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(LabelCrossEntropy, + LabelCrossEntropyOp) +REGISTER_CUDA_OPERATOR(LabelCrossEntropyGradient, + LabelCrossEntropyGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/cross_entropy_op.h b/caffe2/operators/cross_entropy_op.h new file mode 100644 index 00000000000..e3aea92c0d2 --- /dev/null +++ b/caffe2/operators/cross_entropy_op.h @@ -0,0 +1,44 @@ +#ifndef CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_ +#define CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class LabelCrossEntropyOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + static constexpr dtype kLOG_THRESHOLD() { return 1e-20; } + // Input: X, label + // Output: Y + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(LabelCrossEntropyOp); +}; + +template +class LabelCrossEntropyGradientOp final + : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(LabelCrossEntropyGradientOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + // Input: X, label, dY + // Ouptut: dX. There is no gradient with respect to the label. 
+ static constexpr dtype kLOG_THRESHOLD() { return 1e-20; } + INPUT_OUTPUT_STATS(3, 3, 1, 1); + DISABLE_COPY_AND_ASSIGN(LabelCrossEntropyGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_CROSS_ENTROPY_OP_H_ diff --git a/caffe2/operators/db.cc b/caffe2/operators/db.cc new file mode 100644 index 00000000000..07da3b625d0 --- /dev/null +++ b/caffe2/operators/db.cc @@ -0,0 +1,9 @@ +#include "caffe2/operators/db.h" + +namespace caffe2 { +namespace db { + +DEFINE_REGISTRY(Caffe2DBRegistry, DB, const string&, Mode); + +} // namespace db +} // namespace caffe2 diff --git a/caffe2/operators/depth_split_op.cc b/caffe2/operators/depth_split_op.cc new file mode 100644 index 00000000000..9fc945a6fa5 --- /dev/null +++ b/caffe2/operators/depth_split_op.cc @@ -0,0 +1,9 @@ +#include "caffe2/operators/depth_split_op.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(DepthSplit, DepthSplitOp<float, CPUContext>) +REGISTER_CPU_OPERATOR(DepthConcat, DepthConcatOp<float, CPUContext>) +} // namespace +} // namespace caffe2 + diff --git a/caffe2/operators/depth_split_op.cu b/caffe2/operators/depth_split_op.cu new file mode 100644 index 00000000000..1c778b1e73d --- /dev/null +++ b/caffe2/operators/depth_split_op.cu @@ -0,0 +1,10 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/depth_split_op.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(DepthSplit, DepthSplitOp<float, CUDAContext>) +REGISTER_CUDA_OPERATOR(DepthConcat, DepthConcatOp<float, CUDAContext>) +} // namespace +} // namespace caffe2 + diff --git a/caffe2/operators/depth_split_op.h b/caffe2/operators/depth_split_op.h new file mode 100644 index 00000000000..7d7d4cb5fba --- /dev/null +++ b/caffe2/operators/depth_split_op.h @@ -0,0 +1,141 @@ +#ifndef CAFFE2_OPERATORS_DEPTH_SPLIT_OP_H_ +#define CAFFE2_OPERATORS_DEPTH_SPLIT_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/core/types.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +template <typename dtype, class DeviceContext> +class DepthSplitOp final : public Operator<dtype, DeviceContext> { + public: + USE_OPERATOR_BASE_FUNCTIONS; + DepthSplitOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<dtype, DeviceContext>(operator_def, ws), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument<string>("order", "NHWC"))) {} + bool RunOnDevice() override; + + protected: + StorageOrder order_; + // Input: X, dimensions + // The dimensions are stored in CPU. + INPUT_OUTPUT_STATS(2, 2, 1, INT_MAX); + DISABLE_COPY_AND_ASSIGN(DepthSplitOp); +}; + +template <typename dtype, class DeviceContext> +class DepthConcatOp final : public Operator<dtype, DeviceContext> { + public: + DepthConcatOp(const OperatorDef& operator_def, Workspace* ws) + : Operator<dtype, DeviceContext>(operator_def, ws), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument<string>("order", "NHWC"))) {} + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + StorageOrder order_; + // Input: a number of tensors. Output: Y, dimensions + // The dimensions are stored in CPU. + INPUT_OUTPUT_STATS(1, INT_MAX, 2, 2); + DISABLE_COPY_AND_ASSIGN(DepthConcatOp); +}; + + +// Implementations +template <typename dtype, class DeviceContext> +bool DepthSplitOp<dtype, DeviceContext>::RunOnDevice() { + auto& input = Input(0); + auto& dimensions = + OperatorBase::Input<Tensor<int, CPUContext> >(1); + const int* dim_data = dimensions.data(); + DCHECK_EQ(dimensions.size(), OutputSize()); + DCHECK_EQ(std::accumulate(dim_data, dim_data + OutputSize(), 0), + (order_ == StorageOrder::NCHW ? 
input.dim(1) : input.dim(3))); + int input_offset = 0; + for (int i = 0; i < OutputSize(); ++i) { + auto* output = Output(i); + int M, N, lda; + switch (order_) { + case StorageOrder::NCHW: + output->Reshape(vector{ + input.dim(0), dim_data[i], input.dim(2), input.dim(3)}); + M = input.dim(0); + N = dim_data[i] * input.dim(2) * input.dim(3); + lda = input.size() / input.dim(0); + break; + case StorageOrder::NHWC: + output->Reshape(vector{ + input.dim(0), input.dim(1), input.dim(2), dim_data[i]}); + M = input.dim(0) * input.dim(1) * input.dim(2); + N = dim_data[i]; + lda = input.dim(3); + break; + default: + LOG(FATAL) << "Unsupported storage order: " << order_; + } + math::CopyMatrix( + M, N, input.data() + input_offset, lda, output->mutable_data(), N, + &device_context_); + input_offset += N; + } + return true; +} + +template +bool DepthConcatOp::RunOnDevice() { + auto* output = Output(0); + Tensor* dimensions = + OperatorBase::Output >(1); + dimensions->Reshape(vector(1, InputSize())); + int* dim_data = dimensions->mutable_data(); + int output_channels = 0; + for (int i = 0; i < InputSize(); ++i) { + dim_data[i] = + (order_ == StorageOrder::NCHW ? Input(i).dim(1) : Input(i).dim(3)); + output_channels += dim_data[i]; + } + auto& input_zero = Input(0); + output->Reshape(vector{ + input_zero.dim(0), + order_ == StorageOrder::NCHW ? output_channels : input_zero.dim(1), + order_ == StorageOrder::NCHW ? input_zero.dim(2) : input_zero.dim(2), + order_ == StorageOrder::NCHW ? input_zero.dim(3) : output_channels}); + int output_offset = 0; + for (int i = 0; i < InputSize(); ++i) { + auto& input = Input(i); + int M, N, ldb; + switch (order_) { + case StorageOrder::NCHW: + CHECK_EQ(input.dim(0), output->dim(0)); + CHECK_EQ(input.dim(2), output->dim(2)); + CHECK_EQ(input.dim(3), output->dim(3)); + M = input.dim(0); + N = input.size() / M; + ldb = output->size() / output->dim(0); + break; + case StorageOrder::NHWC: + CHECK_EQ(input.dim(0), output->dim(0)); + CHECK_EQ(input.dim(1), output->dim(1)); + CHECK_EQ(input.dim(2), output->dim(2)); + M = input.dim(0) * input.dim(1) * input.dim(2); + N = input.dim(3); + ldb = output->dim(3); + break; + default: + LOG(FATAL) << "Unsupported storage order: " << order_; + } + math::CopyMatrix( + M, N, input.data(), N, output->mutable_data() + output_offset, ldb, + &device_context_); + output_offset += N; + } + return true; +} + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_DEPTH_SPLIT_OP_H_ diff --git a/caffe2/operators/dropout_op.cc b/caffe2/operators/dropout_op.cc new file mode 100644 index 00000000000..84059efaafc --- /dev/null +++ b/caffe2/operators/dropout_op.cc @@ -0,0 +1,52 @@ +#include "caffe2/operators/dropout_op.h" + +namespace caffe2 { + +template <> +bool DropoutOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + Tensor* mask = + OperatorBase::Output >(1); + Y->Reshape(X.dims()); + mask->Reshape(X.dims()); + DCHECK_GT(X.size(), 0); + float scale = 1. / (1. - ratio_); + // mask=true means keep, and mask=false means not keep, so we will + // generate probability depending on 1-ratio. + std::bernoulli_distribution dist(1. 
- ratio_); + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + bool* mask_data = mask->mutable_data(); + auto& gen = device_context_.RandGenerator(); + for (int i = 0; i < X.size(); ++i) { + mask_data[i] = dist(gen); + Ydata[i] = Xdata[i] * scale * mask_data[i]; + } + return true; +} + +template <> +bool DropoutGradientOp::RunOnDevice() { + auto& dY = Input(0); + const Tensor& mask = + OperatorBase::Input >(1); + auto* dX = Output(0); + DCHECK_GT(dY.size(), 0); + DCHECK_EQ(dY.size(), mask.size()); + dX->Reshape(dY.dims()); + const float* dYdata = dY.data(); + const bool* mask_data = mask.data(); + float* dXdata = dX->mutable_data(); + for (int i = 0; i < dY.size(); ++i) { + dXdata[i] = dYdata[i] * mask_data[i]; + } + return true; +} + + +namespace { +REGISTER_CPU_OPERATOR(Dropout, DropoutOp) +REGISTER_CPU_OPERATOR(DropoutGrad, DropoutGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/dropout_op.cu b/caffe2/operators/dropout_op.cu new file mode 100644 index 00000000000..ff0c56c06f2 --- /dev/null +++ b/caffe2/operators/dropout_op.cu @@ -0,0 +1,68 @@ +#include "caffe2/operators/dropout_op.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +namespace { +__global__ void DropoutKernel(const int N, const float ratio, + const float* Xdata, float* Ydata, + bool* maskdata) { + const float scale = 1. / (1. - ratio); + CUDA_1D_KERNEL_LOOP(i, N) { + maskdata[i] = (Ydata[i] > ratio); + Ydata[i] = Xdata[i] * scale * maskdata[i]; + } +} +} // namespace + +template <> +bool DropoutOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + auto* mask = OperatorBase::Output >(1); + Y->Reshape(X.dims()); + mask->Reshape(X.dims()); + DCHECK_GT(X.size(), 0); + // We do a simple trick here: since curand cannot generate random + // boolean numbers, we will generate into dY and write the result to + // mask. 
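// Concretely: curand fills the output buffer Ydata with uniform(0, 1)
// samples, and the kernel then sets mask[i] = (Ydata[i] > ratio) before
// overwriting Ydata[i] = Xdata[i] * mask[i] / (1 - ratio). This matches the
// CPU path's bernoulli(1 - ratio) keep probability, and the 1 / (1 - ratio)
// factor is the usual inverted-dropout scaling: the expected activation is
// unchanged, so no extra scaling is needed at test time (e.g. with
// ratio = 0.5 the surviving activations are doubled).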
+ float* Ydata = Y->mutable_data(); + CURAND_CHECK(curandGenerateUniform( + device_context_.curand_generator(), Ydata, X.size())); + DropoutKernel<<>>( + X.size(), ratio_, X.data(), Ydata, mask->mutable_data()); + return true; +} + +namespace { +__global__ void DropoutGradientKernel(const int N, const float* dYdata, + const bool* maskdata, float* dXdata) { + CUDA_1D_KERNEL_LOOP(i, N) { + dXdata[i] = dYdata[i] * maskdata[i]; + } +} +} // namespace + +template <> +bool DropoutGradientOp::RunOnDevice() { + auto& dY = Input(0); + auto& mask = + OperatorBase::Input >(1); + auto* dX = Output(0); + DCHECK_GT(dY.size(), 0); + DCHECK_EQ(dY.size(), mask.size()); + dX->Reshape(dY.dims()); + DropoutGradientKernel<<>>( + dY.size(), dY.data(), mask.data(), dX->mutable_data()); + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(Dropout, DropoutOp) +REGISTER_CUDA_OPERATOR(DropoutGrad, DropoutGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/dropout_op.h b/caffe2/operators/dropout_op.h new file mode 100644 index 00000000000..5f2c35a61f5 --- /dev/null +++ b/caffe2/operators/dropout_op.h @@ -0,0 +1,53 @@ +#ifndef CAFFE2_OPERATORS_DROPOUT_OP_H_ +#define CAFFE2_OPERATORS_DROPOUT_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class DropoutOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + DropoutOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + ratio_(OperatorBase::GetSingleArgument("ratio", 0.5)) { + DCHECK_GT(ratio_, 0); + DCHECK_LT(ratio_, 1); + } + + bool RunOnDevice(); + + protected: + float ratio_; + // Input: X; Output: Y, mask. + INPUT_OUTPUT_STATS(1, 1, 2, 2); + DISABLE_COPY_AND_ASSIGN(DropoutOp); +}; + +template +class DropoutGradientOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + DropoutGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + ratio_(OperatorBase::GetSingleArgument("ratio", 0.5)) { + DCHECK_GT(ratio_, 0); + DCHECK_LT(ratio_, 1); + } + + bool RunOnDevice(); + + protected: + float ratio_; + // Input: dY, mask; Output: dX + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(DropoutGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_DROPOUT_OP_H_ diff --git a/caffe2/operators/elementwise_op.cc b/caffe2/operators/elementwise_op.cc new file mode 100644 index 00000000000..ca51aaf0617 --- /dev/null +++ b/caffe2/operators/elementwise_op.cc @@ -0,0 +1,12 @@ +#include "caffe2/operators/elementwise_op.h" + +namespace caffe2 { +namespace { + +REGISTER_CPU_OPERATOR(Add, AddOp) +REGISTER_CPU_OPERATOR(Sub, SubOp) +REGISTER_CPU_OPERATOR(Mul, MulOp) +REGISTER_CPU_OPERATOR(Div, DivOp) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/elementwise_op.h b/caffe2/operators/elementwise_op.h new file mode 100644 index 00000000000..78f74506640 --- /dev/null +++ b/caffe2/operators/elementwise_op.h @@ -0,0 +1,54 @@ +#ifndef CAFFE2_OPERATORS_ELEMENTWISE_OP_H_ +#define CAFFE2_OPERATORS_ELEMENTWISE_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class BinaryElementwiseOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(BinaryElementwiseOp); + + bool RunOnDevice() final { + auto& input0 = Input(0); + auto& input1 = Input(1); + auto* output 
= Output(0); + CHECK_EQ(input0.size(), input1.size()); + output->ReshapeLike(input0); + Functor()(input0.size(), input0.data(), input1.data(), + output->mutable_data(), &device_context_); + return true; + } + + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(BinaryElementwiseOp); +}; + + +#define CAFFE2_BINARY_FUNCTOR_WRAPPER(name) \ +template \ +struct name##Functor { \ + inline void operator()(const int n, const dtype* x, const dtype* y, \ + dtype* output, DeviceContext* device_context) { \ + math::name(n, x, y, output, device_context); \ + } \ +}; \ +template \ +using name##Op = \ + BinaryElementwiseOp > + + +CAFFE2_BINARY_FUNCTOR_WRAPPER(Add); +CAFFE2_BINARY_FUNCTOR_WRAPPER(Sub); +CAFFE2_BINARY_FUNCTOR_WRAPPER(Mul); +CAFFE2_BINARY_FUNCTOR_WRAPPER(Div); +#undef CAFFE2_BINARY_FUNCTOR_WRAPPER + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ELEMENTWISE_OP_H_ diff --git a/caffe2/operators/elementwise_op_gpu.cc b/caffe2/operators/elementwise_op_gpu.cc new file mode 100644 index 00000000000..448aea6d510 --- /dev/null +++ b/caffe2/operators/elementwise_op_gpu.cc @@ -0,0 +1,13 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/elementwise_op.h" + +namespace caffe2 { +namespace { + +REGISTER_CUDA_OPERATOR(Add, AddOp) +REGISTER_CUDA_OPERATOR(Sub, SubOp) +REGISTER_CUDA_OPERATOR(Mul, MulOp) +REGISTER_CUDA_OPERATOR(Div, DivOp) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/filler_op.cc b/caffe2/operators/filler_op.cc new file mode 100644 index 00000000000..bd24a75ba95 --- /dev/null +++ b/caffe2/operators/filler_op.cc @@ -0,0 +1,25 @@ +#include "caffe2/operators/filler_op.h" + +namespace caffe2 { + +template <> +bool RangeFillOp::Fill( + Tensor* output) { + float* data = output->mutable_data(); + for (int i = 0; i < output->size(); ++i) { + data[i] = i; + } + return true; +} + +namespace { + +REGISTER_CPU_OPERATOR(UniformFill, UniformFillOp) +REGISTER_CPU_OPERATOR(ConstantFill, ConstantFillOp) +REGISTER_CPU_OPERATOR(GivenTensorFill, GivenTensorFillOp) +REGISTER_CPU_OPERATOR(GaussianFill, GaussianFillOp) +REGISTER_CPU_OPERATOR(XavierFill, XavierFillOp) +REGISTER_CPU_OPERATOR(RangeFill, RangeFillOp) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/filler_op.cu b/caffe2/operators/filler_op.cu new file mode 100644 index 00000000000..9fc000c1a96 --- /dev/null +++ b/caffe2/operators/filler_op.cu @@ -0,0 +1,34 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/filler_op.h" + +namespace caffe2 { + +namespace { +__global__ void FillRangeKernel(const int n, float* data) { + CUDA_1D_KERNEL_LOOP(index, n) { + data[index] = index; + } +} +} + +template <> +bool RangeFillOp::Fill( + Tensor* output) { + int N = output->size(); + FillRangeKernel<<>>( + N, output->mutable_data()); + return true; +} + +namespace { + +REGISTER_CUDA_OPERATOR(UniformFill, UniformFillOp) +REGISTER_CUDA_OPERATOR(ConstantFill, ConstantFillOp) +REGISTER_CUDA_OPERATOR(GivenTensorFill, GivenTensorFillOp) +REGISTER_CUDA_OPERATOR(GaussianFill, GaussianFillOp) +REGISTER_CUDA_OPERATOR(XavierFill, XavierFillOp) +REGISTER_CUDA_OPERATOR(RangeFill, RangeFillOp) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/filler_op.h b/caffe2/operators/filler_op.h new file mode 100644 index 00000000000..c26fb251b7a --- /dev/null +++ b/caffe2/operators/filler_op.h @@ -0,0 +1,185 @@ +#ifndef CAFFE2_OPERATORS_FILLER_OP_H_ +#define CAFFE2_OPERATORS_FILLER_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include 
"caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class FillerOp : public Operator { + public: + FillerOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + shape_(OperatorBase::GetRepeatedArgument("shape")), + run_once_(OperatorBase::GetSingleArgument("run_once", true)), + already_run_(false) {} + virtual ~FillerOp() {} + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + if (run_once_ && !already_run_) { + already_run_ = true; + auto* output = Operator::Output(0); + if (InputSize()) { + if (shape_.size() != 0) { + LOG(ERROR) << "Cannot set the shape argument and pass in an input at " + "the same time."; + return false; + } + output->ReshapeLike(Input(0)); + } else { + output->Reshape(shape_); + } + return Fill(output); + } + return true; + } + + virtual bool Fill(Tensor* output) = 0; + + protected: + vector shape_; + bool run_once_; + bool already_run_; + // FillerOp takes in either zero or one input. If the number of input is + // 1, the shape will be identical to that of the input at run time, and + // in that case the "shape" parameter should not be set. + INPUT_OUTPUT_STATS(0, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(FillerOp); +}; + +template +class UniformFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + UniformFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws), + min_(OperatorBase::template GetSingleArgument("min", 0)), + max_(OperatorBase::template GetSingleArgument("max", 1)) { + DCHECK_LT(min_, max_) << "Max value should be bigger than min value."; + } + + bool Fill(Tensor* output) override { + math::RandUniform( + output->size(), min_, max_, + output->mutable_data(), &device_context_); + return true; + } + + private: + dtype min_; + dtype max_; + DISABLE_COPY_AND_ASSIGN(UniformFillOp); +}; + +template +class ConstantFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + ConstantFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws), + value_(OperatorBase::template GetSingleArgument("value", 0)) {} + + bool Fill(Tensor* output) override { + math::Set( + output->size(), value_, output->mutable_data(), &device_context_); + return true; + } + + private: + dtype value_; + DISABLE_COPY_AND_ASSIGN(ConstantFillOp); +}; + +template +class GivenTensorFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + GivenTensorFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws) { + auto source_values = OperatorBase::template GetRepeatedArgument( + "values"); + for (float& f : source_values) { + values_.push_back(static_cast(f)); + } + } + + bool Fill(Tensor* output) override { + DCHECK_EQ(output->size(), values_.size()) + << "output size: " << output->size() << " given size: " + << values_.size(); + device_context_.template Copy( + output->mutable_data(), values_.data(), output->size()); + return true; + } + + private: + vector values_; + DISABLE_COPY_AND_ASSIGN(GivenTensorFillOp); +}; + +template +class GaussianFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + GaussianFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws), + mean_(OperatorBase::template GetSingleArgument("mean", 0)), + std_(OperatorBase::template GetSingleArgument("std", 1)) { + DCHECK_GT(std_, 0) + << "Standard deviation should be nonnegative."; + } + + bool Fill(Tensor* output) override { + math::RandGaussian( + output->size(), mean_, 
std_, output->mutable_data(), + &device_context_); + return true; + } + + private: + dtype mean_; + dtype std_; + DISABLE_COPY_AND_ASSIGN(GaussianFillOp); +}; + +template +class XavierFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + XavierFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws) {} + + bool Fill(Tensor* output) override { + const int fan_in = output->size() / output->dim(0); + dtype scale = sqrt(dtype(3) / fan_in); + math::RandUniform( + output->size(), -scale, scale, + output->mutable_data(), &device_context_); + return true; + } + + DISABLE_COPY_AND_ASSIGN(XavierFillOp); +}; + + +// This is mostly used just as a debugging purpose stuff: it fills a tensor +// sequentially with values 0, 1, 2..., which can then be used to check e.g. +// reshape operations by allowing one to read the indices more easily. +template +class RangeFillOp final : public FillerOp { + public: + USE_OPERATOR_BASE_FUNCTIONS; + RangeFillOp(const OperatorDef& operator_def, Workspace* ws) + : FillerOp(operator_def, ws) {} + + bool Fill(Tensor* output) override; + DISABLE_COPY_AND_ASSIGN(RangeFillOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_FILLER_OP_H_ diff --git a/caffe2/operators/fully_connected_op.cc b/caffe2/operators/fully_connected_op.cc new file mode 100644 index 00000000000..fed3171a4fb --- /dev/null +++ b/caffe2/operators/fully_connected_op.cc @@ -0,0 +1,10 @@ +#include "caffe2/operators/fully_connected_op.h" + +namespace caffe2 { +namespace { + +REGISTER_CPU_OPERATOR(FC, FullyConnectedOp); +REGISTER_CPU_OPERATOR(FCGradient, FullyConnectedGradientOp); + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/fully_connected_op.h b/caffe2/operators/fully_connected_op.h new file mode 100644 index 00000000000..1bb68f18339 --- /dev/null +++ b/caffe2/operators/fully_connected_op.h @@ -0,0 +1,147 @@ +#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ +#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +// This is Caffe's InnerProductOp, with a name that fits its purpose better. +template +class FullyConnectedOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + FullyConnectedOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + kOne(static_cast(1), &device_context_), + kZero(static_cast(0), &device_context_) {} + ~FullyConnectedOp() {} + + bool RunOnDevice() final { + const auto& X = Input(0); + const auto& W = Input(1); + const auto& b = Input(2); + auto* Y = Output(0); + DCHECK_GE(X.ndim(), 2); + DCHECK_GE(W.ndim(), 2); + if (X.ndim() > 2 || W.ndim() > 2) { + VLOG(1) << "Using legacy support for arbitrary input and weight " + << "dimensions."; + } + DCHECK_EQ(b.ndim(), 1); + // batch size + int M = X.dim(0); + // Feature dimension + int K = X.size() / X.dim(0); + // number of outputs. + int N = W.dim(0); + DCHECK_EQ(K, W.size() / W.dim(0)); + DCHECK_EQ(N, b.dim(0)); + Y->Reshape(vector{M, N}); + // W * x + math::Gemm( + CblasNoTrans, CblasTrans, M, N, K, kOne.data(), X.data(), + W.data(), kZero.data(), Y->mutable_data(), &device_context_); + // Add bias term + if (bias_multiplier_.size() != M) { + // If the helper bias multiplier is not M, reshape and fill it with one. 
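+ // The all-ones helper vector turns the bias addition into a rank-1 GEMM:
+ //   Y(M x N) += bias_multiplier(M x 1) * b(1 x N),
+ // i.e. every row of Y = X * W^T gets the same bias vector b added without an
+ // explicit per-row loop. E.g. with M = 2 and b = [0.1, 0.2], both output rows
+ // receive [0.1, 0.2].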
+ bias_multiplier_.Reshape(std::vector{M}); + math::Set( + M, static_cast(1), bias_multiplier_.mutable_data(), + &device_context_); + } + math::Gemm( + CblasNoTrans, CblasNoTrans, M, N, 1, kOne.data(), + bias_multiplier_.data(), b.data(), kOne.data(), + Y->mutable_data(), &device_context_); + return true; + } + + protected: + Tensor bias_multiplier_; + Tensor kOne; + Tensor kZero; + // We force this Op to have 3 inputs, since that is almost always the case in + // deep networks. + INPUT_OUTPUT_STATS(3, 3, 1, 1); + DISABLE_COPY_AND_ASSIGN(FullyConnectedOp); +}; + +template +class FullyConnectedGradientOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + FullyConnectedGradientOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + kOne(static_cast(1), &device_context_), + kZero(static_cast(0), &device_context_) {} + ~FullyConnectedGradientOp() {} + + bool RunOnDevice() final { + const auto& X = Input(0); + const auto& W = Input(1); + const auto& b = Input(2); + const auto& dY = Input(3); + auto* dW = Output(0); + auto* db = Output(1); + dW->ReshapeLike(W); + db->ReshapeLike(b); + DCHECK_GE(X.ndim(), 2); + DCHECK_GE(W.ndim(), 2); + DCHECK_EQ(b.ndim(), 1); + DCHECK_EQ(dY.ndim(), 2); + // batch size + int M = X.dim(0); + // Feature dimension + int K = X.size() / X.dim(0); + // number of outputs. + int N = W.dim(0); + DCHECK_EQ(K, W.size() / W.dim(0)); + DCHECK_EQ(N, b.dim(0)); + DCHECK_EQ(M, dY.dim(0)); + DCHECK_EQ(N, dY.dim(1)); + + // Compute dW + math::Gemm( + CblasTrans, CblasNoTrans, N, K, M, kOne.data(), dY.data(), + X.data(), kZero.data(), dW->mutable_data(), &device_context_); + if (bias_multiplier_.size() != M) { + // If the helper bias multiplier is not M, reshape and fill it with one. + bias_multiplier_.Reshape(std::vector{M}); + math::Set( + M, static_cast(1), bias_multiplier_.mutable_data(), + &device_context_); + } + // Compute dB + math::Gemv( + CblasTrans, M, N, kOne.data(), dY.data(), + bias_multiplier_.data(), kZero.data(), db->mutable_data(), + &device_context_); + // Compute dX if necessary. + if (OutputSize() == 3) { + auto* dX = Output(2); + dX->ReshapeLike(X); + math::Gemm( + CblasNoTrans, CblasNoTrans, M, K, N, kOne.data(), + dY.data(), W.data(), kZero.data(), dX->mutable_data(), + &device_context_); + } + + return true; + } + + protected: + Tensor bias_multiplier_; + Tensor kOne; + Tensor kZero; + + // input: X, W, b, dY + // output: dW, db, and optionally dX. 
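+ // With the forward pass Y = X * W^T + 1 * b^T, the GEMM/GEMV calls above
+ // compute
+ //   dW = dY^T * X   (N x K),
+ //   db = dY^T * 1   (the column sums of dY, length N),
+ //   dX = dY * W     (M x K), emitted only when a third output is requested.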
+ INPUT_OUTPUT_STATS(4, 4, 2, 3); + DISABLE_COPY_AND_ASSIGN(FullyConnectedGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_ diff --git a/caffe2/operators/fully_connected_op_gpu.cc b/caffe2/operators/fully_connected_op_gpu.cc new file mode 100644 index 00000000000..8ee67acd0b5 --- /dev/null +++ b/caffe2/operators/fully_connected_op_gpu.cc @@ -0,0 +1,10 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/fully_connected_op.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(FC, FullyConnectedOp); +REGISTER_CUDA_OPERATOR(FCGradient, + FullyConnectedGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/fully_connected_op_test.cc b/caffe2/operators/fully_connected_op_test.cc new file mode 100644 index 00000000000..eae14599f7f --- /dev/null +++ b/caffe2/operators/fully_connected_op_test.cc @@ -0,0 +1,48 @@ +#include + +#include "caffe2/operators/fully_connected_op.h" +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +DECLARE_string(caffe_test_root); + +namespace caffe2 { + +static void AddConstInput(const std::vector& shape, const float value, + const string& name, Workspace* ws) { + DeviceOption option; + CPUContext context(option); + Blob* blob = ws->CreateBlob(name); + auto* tensor = blob->GetMutable >(); + tensor->Reshape(shape); + math::Set(tensor->size(), value, tensor->mutable_data(), + &context); + return; +} + +TEST(FullyConnectedTest, Test) { + Workspace ws; + OperatorDef def; + def.set_name("test"); + def.set_type("FC"); + def.add_inputs("X"); + def.add_inputs("W"); + def.add_inputs("B"); + def.add_outputs("Y"); + AddConstInput(std::vector{5, 10}, 1., "X", &ws); + AddConstInput(std::vector{6, 10}, 1., "W", &ws); + AddConstInput(std::vector{6}, 0.1, "B", &ws); + unique_ptr op(CreateOperator(def, &ws)); + EXPECT_NE(nullptr, op.get()); + EXPECT_TRUE(op->Run()); + Blob* Yblob = ws.GetBlob("Y"); + EXPECT_NE(nullptr, Yblob); + auto& Y = Yblob->Get >(); + EXPECT_EQ(Y.size(), 5 * 6); + for (int i = 0; i < Y.size(); ++i) { + CHECK_LT(Y.data()[i], 10.11); + CHECK_GT(Y.data()[i], 10.09); + } +} + +} // namespace caffe2 diff --git a/caffe2/operators/l2_distance_op.cc b/caffe2/operators/l2_distance_op.cc new file mode 100644 index 00000000000..1ea15cf9107 --- /dev/null +++ b/caffe2/operators/l2_distance_op.cc @@ -0,0 +1,38 @@ +#include "caffe2/operators/l2_distance_op.h" + +namespace caffe2 { + +template<> +bool SquaredL2DistanceOp::RunOnDevice() { + auto& X = Input(0); + auto& Y = Input(1); + auto* distance = Output(0); + DCHECK_EQ(X.ndim(), Y.ndim()); + for (int i = 0; i < X.ndim(); ++i) { + DCHECK_EQ(X.dim(i), Y.dim(i)); + } + int N = X.dim(0); + int D = X.size() / X.dim(0); + distance->Reshape(std::vector{N}); + float* distance_data = distance->mutable_data(); + for (int i = 0; i < N; ++i) { + float Xscale, Yscale, cross; + math::Dot( + D, X.data(), X.data(), &Xscale, &device_context_); + math::Dot( + D, Y.data(), Y.data(), &Yscale, &device_context_); + math::Dot( + D, X.data(), Y.data(), &cross, &device_context_); + distance_data[i] = (Xscale + Yscale) / 2. 
- cross; + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(SquaredL2Distance, + SquaredL2DistanceOp); +REGISTER_CPU_OPERATOR(SquaredL2DistanceGradient, + SquaredL2DistanceGradientOp); + +} +} // namespace caffe2 diff --git a/caffe2/operators/l2_distance_op.cu b/caffe2/operators/l2_distance_op.cu new file mode 100644 index 00000000000..1108a58d871 --- /dev/null +++ b/caffe2/operators/l2_distance_op.cu @@ -0,0 +1,48 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/l2_distance_op.h" + +namespace caffe2 { + +namespace { +// TODO(Yangqing): This function does very aweful memory access. +// Need improvement. +template +__global__ void SquaredL2DistanceKernel( + const int N, const int D, const dtype* X, const dtype* Y, dtype* distance) { + CUDA_1D_KERNEL_LOOP(i, N) { + distance[i] = 0; + for (int j = 0; j < D; ++j) { + dtype diff = X[i * D + j] - Y[i * D + j]; + distance[i] += diff * diff; + } + distance[i] /= 2; + } +} +} // namespace + +template<> +bool SquaredL2DistanceOp::RunOnDevice() { + auto& X = Input(0); + auto& Y = Input(1); + auto* distance = Output(0); + DCHECK_EQ(X.ndim(), Y.ndim()); + for (int i = 0; i < X.ndim(); ++i) { + DCHECK_EQ(X.dim(i), Y.dim(i)); + } + int N = X.dim(0); + int D = X.size() / X.dim(0); + distance->Reshape(std::vector(1, N)); + SquaredL2DistanceKernel<<>>( + N, D, X.data(), Y.data(), distance->mutable_data()); + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(SquaredL2Distance, + SquaredL2DistanceOp); +REGISTER_CUDA_OPERATOR(SquaredL2DistanceGradient, + SquaredL2DistanceGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/l2_distance_op.h b/caffe2/operators/l2_distance_op.h new file mode 100644 index 00000000000..3e3d4753552 --- /dev/null +++ b/caffe2/operators/l2_distance_op.h @@ -0,0 +1,72 @@ +#ifndef CAFFE2_OPERATORS_L2_DISTANCE_OP_H_ +#define CAFFE2_OPERATORS_L2_DISTANCE_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +template +class SquaredL2DistanceOp : public Operator { + public: + SquaredL2DistanceOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws) {} + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override; + + protected: + // Input: X, Y; Output: Distance + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(SquaredL2DistanceOp); +}; + +template +class SquaredL2DistanceGradientOp final + : public Operator { + public: + SquaredL2DistanceGradientOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws) {} + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + auto& X = Input(0); + auto& Y = Input(1); + auto& dDistance = Input(2); + auto* dX = Output(0); + auto* dY = Output(1); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + DCHECK_EQ(Y.ndim(), 2); + DCHECK_EQ(Y.dim(0), N); + DCHECK_EQ(Y.dim(1), D); + DCHECK_EQ(dDistance.ndim(), 1); + DCHECK_EQ(dDistance.dim(0), N); + dX->ReshapeLike(X); + dY->ReshapeLike(Y); + math::Sub( + X.size(), X.data(), Y.data(), dX->mutable_data(), &device_context_); + for (int i = 0; i < N; ++i) { + math::Scale( + D, dDistance.data() + i, dX->data() + i * D, + dX->mutable_data() + i * D, &device_context_); + } + // The gradient of the other side is basically the negative. 
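+ // Since the forward op computes distance_i = ||x_i - y_i||^2 / 2, the
+ // row-wise gradients are
+ //   dX_i = dDistance_i * (x_i - y_i)   and   dY_i = -dX_i,
+ // which is why dY below is simply dX scaled by -1.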
+ const Tensor gNegativeOne(-1, &device_context_); + math::Scale( + X.size(), gNegativeOne.data(), dX->data(), dY->mutable_data(), + &device_context_); + return true; + } + + protected: + // Input: X, Y, dDistance; Output: dX, dY + INPUT_OUTPUT_STATS(3, 3, 2, 2); + DISABLE_COPY_AND_ASSIGN(SquaredL2DistanceGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_L2_DISTANCE_OP_H_ diff --git a/caffe2/operators/load_save_op.cc b/caffe2/operators/load_save_op.cc new file mode 100644 index 00000000000..c52d10589eb --- /dev/null +++ b/caffe2/operators/load_save_op.cc @@ -0,0 +1,8 @@ +#include "caffe2/operators/load_save_op.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(LoadFloatTensor, LoadFloatTensorOp); +REGISTER_CPU_OPERATOR(SaveFloatTensor, SaveFloatTensorOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/load_save_op.cu b/caffe2/operators/load_save_op.cu new file mode 100644 index 00000000000..2e824417edf --- /dev/null +++ b/caffe2/operators/load_save_op.cu @@ -0,0 +1,9 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/load_save_op.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(LoadFloatTensor, LoadFloatTensorOp); +REGISTER_CUDA_OPERATOR(SaveFloatTensor, SaveFloatTensorOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/load_save_op.h b/caffe2/operators/load_save_op.h new file mode 100644 index 00000000000..8e257dcd2d7 --- /dev/null +++ b/caffe2/operators/load_save_op.h @@ -0,0 +1,91 @@ +#ifndef CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ +#define CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "caffe2/utils/proto_utils.h" +#include "glog/logging.h" + +namespace caffe2 { + +// LoadFloatTensorOp is a very simple operator that loads a TensorProto stored +// on disk. The TensorProto should only be stored in float form. +template +class LoadFloatTensorOp final : public Operator { + public: + LoadFloatTensorOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + filename_(OperatorBase::GetSingleArgument("filename", "")) { + CHECK_GT(filename_.size(), 0) << "Must specify an input file."; + } + + bool RunOnDevice() override { + TensorProtos protos; + CHECK(ReadProtoFromFile(filename_, &protos)); + // TODO(Yangqing): Add capability to allow loading a subset of the protos. + CHECK_EQ(protos.protos_size(), OperatorBase::OutputSize()) + << "Inconsistent number of tensors."; + int i = 0; + for (const auto& proto : protos.protos()) { + CHECK_GT(proto.dims_size(), 0); + CHECK_EQ(proto.data_type(), TensorProto::FLOAT); + auto* output = OperatorBase::Output >(i); + output->Reshape(vector(proto.dims().begin(), proto.dims().end())); + CHECK_EQ(output->size(), proto.float_data_size()); + this->device_context_.template Copy( + output->mutable_data(), proto.float_data().data(), output->size()); + VLOG(1) << "Loaded tensor " << this->def().outputs(i); + ++i; + } + return true; + } + + private: + string filename_; + INPUT_OUTPUT_STATS(0, 0, 1, INT_MAX); + DISABLE_COPY_AND_ASSIGN(LoadFloatTensorOp); +}; + +// SaveFloatTensorOp is a very simple operator that loads a TensorProto stored +// on disk. The TensorProto should only be stored in float form. 
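+// Concretely, each input tensor is serialized as one TensorProto (FLOAT data
+// plus dims) inside a single TensorProtos message written to "filename", so
+// the resulting file can be read back by LoadFloatTensorOp with the outputs
+// listed in the same order.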
+template +class SaveFloatTensorOp final : public Operator { + public: + SaveFloatTensorOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + filename_(OperatorBase::GetSingleArgument("filename", "")) {} + + bool RunOnDevice() override { + TensorProtos protos; + for (int i = 0; i < OperatorBase::InputSize(); ++i) { + auto& input = OperatorBase::Input >(i); + auto* proto = protos.add_protos(); + proto->set_data_type(TensorProto::FLOAT); + proto->set_name(OperatorBase::def().inputs(i)); + for (int dim : input.dims()) { + proto->add_dims(dim); + } + // Note(Yangqing): there is no way in protobuffer to resize a repeated + // field, so we have to do reserve and insert dummy zeros. + proto->mutable_float_data()->Reserve(input.size()); + for (int i = 0; i < input.size(); ++i) { + proto->add_float_data(0); + } + this->device_context_.template Copy( + proto->mutable_float_data()->mutable_data(), + input.data(), input.size()); + } + WriteProtoToBinaryFile(protos, filename_); + return true; + } + + private: + string filename_; + INPUT_OUTPUT_STATS(1, INT_MAX, 0, 0); + DISABLE_COPY_AND_ASSIGN(SaveFloatTensorOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_LOAD_SAVE_OP_H_ diff --git a/caffe2/operators/local_response_normalization_op.cc b/caffe2/operators/local_response_normalization_op.cc new file mode 100644 index 00000000000..e14270a879d --- /dev/null +++ b/caffe2/operators/local_response_normalization_op.cc @@ -0,0 +1,236 @@ +#include "caffe2/operators/local_response_normalization_op.h" + +namespace caffe2 { + +template<> +bool LRNOp::RunOnDeviceWithOrderNCHW() { + // Note(Yangqing): this one is copied from my Caffe implementation. + auto& X = Input(0); + auto* Y = Output(0); + auto* scale = Output(1); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int C = X.dim(1); + const int H = X.dim(2); + const int W = X.dim(3); + const int image_size = C * H * W; + const float* Xdata = X.data(); + Y->ReshapeLike(X); + scale->ReshapeLike(X); + float* Ydata = Y->mutable_data(); + float* scale_data = scale->mutable_data(); + math::Set(X.size(), bias_, scale_data, &device_context_); + Tensor padded_square( + std::vector{C + size_ - 1, H, W}); + float* padded_square_data = padded_square.mutable_data(); + math::Set(padded_square.size(), 0., padded_square_data, + &device_context_); + const float alpha_over_size = alpha_ / size_; + // go through the images + for (int n = 0; n < N; ++n) { + // compute the padded square + math::Sqr(image_size, Xdata + image_size * n, + padded_square_data + pre_pad_ * H * W, + &device_context_); + // Create the first channel scale + for (int c = 0; c < size_; ++c) { + math::Axpy( + H * W, &alpha_over_size, padded_square_data + c * H * W, + scale_data + image_size * n, &device_context_); + } + for (int c = 1; c < C; ++c) { + float* this_scale_slice = scale_data + n * image_size + c * H * W; + // copy previous scale + device_context_.Copy( + this_scale_slice, this_scale_slice - H * W, H * W); + // add head + math::Axpy( + H * W, &alpha_over_size, padded_square_data + (c + size_ - 1) * H * W, + this_scale_slice, &device_context_); + // subtract tail + // negative_aos is in order to cope with math::Axpy's requirement. 
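+ // Net effect of this add-head / subtract-tail pair: for every channel c,
+ //   scale(n, c, h, w) = bias + (alpha / size) * sum of X(n, c', h, w)^2
+ // over the `size` channels in the window around c, maintained incrementally
+ // instead of recomputing the full windowed sum for each channel.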
+ const float negative_aos = -alpha_over_size; + math::Axpy( + H * W, &negative_aos, padded_square_data + (c - 1) * H * W, + this_scale_slice, &device_context_); + } + } + math::Powx( + X.size(), scale_data, -beta_, Ydata, &device_context_); + math::Mul(X.size(), Ydata, Xdata, Ydata, &device_context_); + return true; +} + +template<> +bool LRNOp::RunOnDeviceWithOrderNHWC() { + // Note(Yangqing): This one is copied from my Decaf implementation. How many + // variants have I written...? + auto& X = Input(0); + auto* Y = Output(0); + auto* scale = Output(1); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int H = X.dim(1); + const int W = X.dim(2); + const int C = X.dim(3); + const int num_rows = N * H * W; + const float* Xdata = X.data(); + Y->ReshapeLike(X); + scale->ReshapeLike(X); + float* Ydata = Y->mutable_data(); + float* scale_data = scale->mutable_data(); + + Tensor padded_square(std::vector(1, C + size_ - 1)); + float* padded_square_data = padded_square.mutable_data(); + math::Set(padded_square.size(), 0., padded_square_data, + &device_context_); + const float alpha_over_size = alpha_ / size_; + + for (int n = 0; n < num_rows; ++n) { + for (int c = 0; c < C; ++c) { + padded_square_data[c + pre_pad_] = + Xdata[n * C + c] * Xdata[n * C + c] * alpha_over_size; + } + float accum_scale = 0.; + for (int i = 0; i < size_ - 1; ++i) { + accum_scale += padded_square_data[i]; + } + for (int c = 0; c < C; ++c) { + accum_scale += padded_square_data[c + size_ - 1]; + scale_data[n * C + c] = bias_ + accum_scale; + accum_scale -= padded_square_data[c]; + } + } + math::Powx( + X.size(), scale_data, -beta_, Ydata, &device_context_); + math::Mul(X.size(), Ydata, Xdata, Ydata, &device_context_); + return true; +} + +template <> +bool LRNGradientOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto& Y = Input(1); + auto& scale = Input(2); + auto& dY = Input(3); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int C = X.dim(1); + const int H = X.dim(2); + const int W = X.dim(3); + const int image_size = C * H * W; + // Loosely checking the size, assuming that the shapes will be the same as + // long as the sizes check out. + DCHECK_EQ(X.size(), Y.size()); + DCHECK_EQ(X.size(), scale.size()); + DCHECK_EQ(X.size(), dY.size()); + dX->ReshapeLike(X); + + const float* Xdata = X.data(); + const float* Ydata = Y.data(); + const float* scale_data = scale.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + + Tensor padded_ratio( + std::vector{C + size_ - 1, H, W}); + float* padded_ratio_data = padded_ratio.mutable_data(); + math::Set(padded_ratio.size(), 0., padded_ratio_data, + &device_context_); + Tensor accum_ratio(std::vector{H, W}); + float* accum_ratio_data = accum_ratio.mutable_data(); + + + const float cache_ratio = 2. 
* alpha_ * beta_ / size_; + const int inverse_pre_pad = size_ - (size_ + 1) / 2; + + int offset = 0; + for (int n = 0; n < N; ++n) { + // first, compute diff_i * y_i / s_i + math::Mul( + image_size, dYdata + offset, Ydata + offset, + padded_ratio_data + inverse_pre_pad * H * W, &device_context_); + math::Div( + image_size, padded_ratio_data + inverse_pre_pad * H * W, + scale_data + offset, + padded_ratio_data + inverse_pre_pad * H * W, &device_context_); + // Now, compute the accumulated ratios and the bottom diff + math::Set(accum_ratio.size(), 0., accum_ratio_data, + &device_context_); + for (int c = 0; c < size_ - 1; ++c) { + static const float kOne = 1.; + math::Axpy(H * W, &(kOne), + padded_ratio_data + c * H * W, + accum_ratio_data, &device_context_); + } + for (int c = 0; c < C; ++c) { + for (int hw = 0; hw < H * W; ++hw) { + accum_ratio_data[hw] += padded_ratio_data[(c + size_ - 1) * H * W + hw]; + dXdata[offset] = + dYdata[offset] * pow(scale_data[offset], -beta_) - + cache_ratio * accum_ratio_data[hw] * Xdata[offset]; + accum_ratio_data[hw] -= padded_ratio_data[c * H * W + hw]; + ++offset; + } + } + } + return true; +} + +template <> +bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto& Y = Input(1); + auto& scale = Input(2); + auto& dY = Input(3); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int H = X.dim(1); + const int W = X.dim(2); + const int C = X.dim(3); + // Loosely checking the size, assuming that the shapes will be the same as + // long as the sizes check out. + DCHECK_EQ(X.size(), Y.size()); + DCHECK_EQ(X.size(), scale.size()); + DCHECK_EQ(X.size(), dY.size()); + dX->ReshapeLike(X); + Tensor padded_ratio(std::vector(1, C + size_ - 1)); + float* padded_ratio_data = padded_ratio.mutable_data(); + math::Set(padded_ratio.size(), 0., padded_ratio_data, + &device_context_); + // the ratio 2*alpha*beta/size + const float cache_ratio = 2. 
* alpha_ * beta_ / size_; + const int num_rows = N * H * W; + const float* Xdata = X.data(); + const float* Ydata = Y.data(); + const float* scale_data = scale.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + for (int n = 0; n < num_rows; ++n) { + const int offset = n * C; + for (int c = 0; c < C; ++c) { + padded_ratio_data[c + pre_pad_] = + Ydata[offset + c] * dYdata[offset + c] / scale_data[offset + c]; + } + float accum_ratio = 0.; + for (int c = 0; c < size_ - 1; ++c) { + accum_ratio += padded_ratio_data[c]; + } + for (int c = 0; c < C; ++c) { + accum_ratio += padded_ratio_data[c + size_ - 1]; + dXdata[offset + c] = + dYdata[offset + c] * pow(scale_data[offset + c], -beta_) - + cache_ratio * Xdata[offset + c] * accum_ratio; + accum_ratio -= padded_ratio_data[c]; + } + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(LRN, LRNOp); +REGISTER_CPU_OPERATOR(LRNGradient, LRNGradientOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/local_response_normalization_op.cu b/caffe2/operators/local_response_normalization_op.cu new file mode 100644 index 00000000000..50371cf5996 --- /dev/null +++ b/caffe2/operators/local_response_normalization_op.cu @@ -0,0 +1,292 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/local_response_normalization_op.h" + +namespace caffe2 { + +namespace { +template +__global__ void LRNFillScaleNCHW(const int nthreads, const T* in, + const int num, const int channels, const int height, + const int width, const int size, const T alpha_over_size, + const T bias, T* scale) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < nthreads) { + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int n = index / width / height; + int offset = (n * channels * height + h) * width + w; + int step = height * width; + in += offset; + scale += offset; + int head = 0; + int pre_pad = (size - 1) / 2; + int post_pad = size - pre_pad - 1; + T accum_scale = 0; + // fill the scale at [n, :, h, w] + // accumulate values + while (head < post_pad) { + accum_scale += in[head * step] * in[head * step]; + ++head; + } + // until we reach size, nothing needs to be subtracted + while (head < size) { + accum_scale += in[head * step] * in[head * step]; + scale[(head - post_pad) * step] = bias + accum_scale * alpha_over_size; + ++head; + } + // both add and subtract + while (head < channels) { + accum_scale += in[head * step] * in[head * step]; + accum_scale -= in[(head - size) * step] * in[(head - size) * step]; + scale[(head - post_pad) * step] = bias + accum_scale * alpha_over_size; + ++head; + } + // subtract only + while (head < channels + post_pad) { + accum_scale -= in[(head - size) * step] * in[(head - size) * step]; + scale[(head - post_pad) * step] = bias + accum_scale * alpha_over_size; + ++head; + } + } +} + +template +__global__ void LRNFillScaleNHWC(const int nthreads, const T* in, + const int num, const int height, const int width, + const int channels, const int size, const T alpha_over_size, + const T bias, T* scale) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < nthreads) { + int c = index % channels; + int pre_pad = (size - 1) / 2; + scale[index] = 0; + for (int i = 0; i < size; ++i) { + int raw_idx = c + i - pre_pad; + if (raw_idx >= 0 && raw_idx < channels) { + scale[index] += in[index + i - pre_pad] * in[index + i - pre_pad]; + } + } + scale[index] = bias + scale[index] * alpha_over_size; + } +} + +// TODO(Yangqing): check if 
it would be faster to just put it into the previous +// kernel. +template +__global__ void LRNComputeOutput(const int nthreads, const T* in, + const T* scale, const T negative_beta, T* out) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < nthreads) { + out[index] = in[index] * pow(scale[index], negative_beta); + } +} + +template +__global__ void LRNComputeDiffNCHW(const int nthreads, const T* bottom_data, + const T* top_data, const T* scale, const T* top_diff, + const int num, const int channels, const int height, + const int width, const int size, const T negative_beta, + const T cache_ratio, + T* bottom_diff) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < nthreads) { + // find out the local offset + int w = index % width; + int h = (index / width) % height; + int n = index / width / height; + int offset = (n * channels * height + h) * width + w; + int step = height * width; + bottom_data += offset; + top_data += offset; + scale += offset; + top_diff += offset; + bottom_diff += offset; + int head = 0; + int pre_pad = size - (size + 1) / 2; + int post_pad = size - pre_pad - 1; + T accum_ratio = 0; + // accumulate values + while (head < post_pad) { + accum_ratio += top_diff[head * step] * top_data[head * step] / + scale[head * step]; + ++head; + } + // until we reach size, nothing needs to be subtracted + while (head < size) { + accum_ratio += top_diff[head * step] * top_data[head * step] / + scale[head * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + // both add and subtract + while (head < channels) { + accum_ratio += top_diff[head * step] * top_data[head * step] / + scale[head * step]; + accum_ratio -= top_diff[(head - size) * step] * + top_data[(head - size) * step] / scale[(head - size) * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + // subtract only + while (head < channels + post_pad) { + accum_ratio -= top_diff[(head - size) * step] * + top_data[(head - size) * step] / scale[(head - size) * step]; + bottom_diff[(head - post_pad) * step] = top_diff[(head - post_pad) * step] + * pow(scale[(head - post_pad) * step], negative_beta) - cache_ratio * + bottom_data[(head - post_pad) * step] * accum_ratio; + ++head; + } + } +} + +// This local response normalization gradient does one sum per output location +// and does not use the running trick for 1-d convolution: thus it might not be +// the fastest implementation. 
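+// Both the NCHW and NHWC gradient kernels evaluate
+//   dX_i = dY_i * scale_i^(-beta)
+//          - (2 * alpha * beta / size) * X_i * sum_j (dY_j * Y_j / scale_j),
+// where the sum runs over the channels whose LRN window contains channel i.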
+template +__global__ void LRNComputeDiffNHWC(const int nthreads, const T* bottom_data, + const T* top_data, const T* scale, const T* top_diff, + const int num, const int height, const int width, const int channels, + const int size, const T negative_beta, const T cache_ratio, + T* bottom_diff) { + int index = threadIdx.x + blockIdx.x * blockDim.x; + if (index < nthreads) { + // find out the local channel offset + int c = index % channels; + int pre_pad = size / 2; + T accum_ratio = 0; + for (int i = -pre_pad; i < size - pre_pad; ++i) { + if (c + i >= 0 && c + i < channels) { + accum_ratio += top_diff[index + i] * top_data[index + i] / + scale[index + i]; + } + } + bottom_diff[index] = top_diff[index] * pow(scale[index], negative_beta) - + cache_ratio * bottom_data[index] * accum_ratio; + } +} +} // namespace + + + +template<> +bool LRNOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + auto* scale = Output(1); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int C = X.dim(1); + const int H = X.dim(2); + const int W = X.dim(3); + const float* Xdata = X.data(); + Y->ReshapeLike(X); + scale->ReshapeLike(X); + float* Ydata = Y->mutable_data(); + float* scale_data = scale->mutable_data(); + + int n_threads = N * H * W; + LRNFillScaleNCHW<<>>( + n_threads, Xdata, N, C, H, W, size_, alpha_ / size_, bias_, scale_data); + n_threads = X.size(); + LRNComputeOutput<<>>( + n_threads, Xdata, scale_data, -beta_, Ydata); + return true; +} + +template<> +bool LRNOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + auto* scale = Output(1); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int H = X.dim(1); + const int W = X.dim(2); + const int C = X.dim(3); + const float* Xdata = X.data(); + Y->ReshapeLike(X); + scale->ReshapeLike(X); + float* Ydata = Y->mutable_data(); + float* scale_data = scale->mutable_data(); + + int n_threads = X.size(); + LRNFillScaleNHWC<<>>( + n_threads, Xdata, N, H, W, C, size_, alpha_ / size_, bias_, scale_data); + LRNComputeOutput<<>>( + n_threads, Xdata, scale_data, -beta_, Ydata); + return true; +} + +template <> +bool LRNGradientOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto& Y = Input(1); + auto& scale = Input(2); + auto& dY = Input(3); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0); + const int C = X.dim(1); + const int H = X.dim(2); + const int W = X.dim(3); + // Loosely checking the size, assuming that the shapes will be the same as + // long as the sizes check out. + DCHECK_EQ(X.size(), Y.size()); + DCHECK_EQ(X.size(), scale.size()); + DCHECK_EQ(X.size(), dY.size()); + dX->ReshapeLike(X); + + const float* Xdata = X.data(); + const float* Ydata = Y.data(); + const float* scale_data = scale.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + + int n_threads = N * H * W; + LRNComputeDiffNCHW<<>>( + n_threads, Xdata, Ydata, scale_data, dYdata, N, C, H, W, size_, -beta_, + 2.f * alpha_ * beta_ / size_, dXdata); + return true; +} + +template <> +bool LRNGradientOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto& Y = Input(1); + auto& scale = Input(2); + auto& dY = Input(3); + auto* dX = Output(0); + DCHECK_EQ(X.ndim(), 4); + // Loosely checking the size, assuming that the shapes will be the same as + // long as the sizes check out. 
+ DCHECK_EQ(X.size(), Y.size()); + DCHECK_EQ(X.size(), scale.size()); + DCHECK_EQ(X.size(), dY.size()); + dX->ReshapeLike(X); + + LRNComputeDiffNHWC<<>>( + X.size(), X.data(), Y.data(), scale.data(), dY.data(), + X.dim(0), X.dim(1), X.dim(2), X.dim(3), size_, -beta_, + 2.f * alpha_ * beta_ / size_, dX->mutable_data()); + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(LRN, LRNOp); +REGISTER_CUDA_OPERATOR(LRNGradient, LRNGradientOp); +} + +} // namespace caffe2 diff --git a/caffe2/operators/local_response_normalization_op.h b/caffe2/operators/local_response_normalization_op.h new file mode 100644 index 00000000000..d7022c8d059 --- /dev/null +++ b/caffe2/operators/local_response_normalization_op.h @@ -0,0 +1,94 @@ +#ifndef CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_ +#define CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class LRNOpBase : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + LRNOpBase(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + size_(OperatorBase::GetSingleArgument("size", 0)), + alpha_(OperatorBase::GetSingleArgument("alpha", 0)), + beta_(OperatorBase::GetSingleArgument("beta", 0)), + bias_(OperatorBase::GetSingleArgument("bias", 1)), + order_(StringToStorageOrder( + OperatorBase::GetSingleArgument("order", "NHWC"))), + pre_pad_((size_ - 1) / 2) { + DCHECK_GT(size_, 0); + DCHECK_EQ(size_ % 2, 1); + DCHECK_GT(alpha_, 0); + DCHECK_GT(beta_, 0); + } + + bool RunOnDevice() override { + switch (order_) { + case StorageOrder::NHWC: + return RunOnDeviceWithOrderNHWC(); + case StorageOrder::NCHW: + return RunOnDeviceWithOrderNCHW(); + default: + LOG(FATAL) << "Unknown storage order: " << order_; + } + // To suppress old compiler warnings + return true; + } + + virtual bool RunOnDeviceWithOrderNCHW() = 0; + virtual bool RunOnDeviceWithOrderNHWC() = 0; + + protected: + const int size_; + const float alpha_; + const float beta_; + const float bias_; + const StorageOrder order_; + const int pre_pad_; + // Input: X; Output: Y, scale. + INPUT_OUTPUT_STATS(1, 1, 2, 2); + DISABLE_COPY_AND_ASSIGN(LRNOpBase); +}; + +template +class LRNOp final : public LRNOpBase { + public: + USE_OPERATOR_BASE_FUNCTIONS; + LRNOp(const OperatorDef& operator_def, Workspace* ws) + : LRNOpBase(operator_def, ws) {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + protected: + // Input: X; Output: Y, scale. 
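+ // The second output caches scale = bias + alpha/size * (windowed sum of X^2)
+ // so that LRNGradientOp can reuse it instead of recomputing the windows.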
+ OUTPUT_TAGS(OUTPUT, SCALE); + INPUT_OUTPUT_STATS(1, 1, 2, 2); + DISABLE_COPY_AND_ASSIGN(LRNOp); +}; + +template +class LRNGradientOp final : public LRNOpBase { + public: + USE_OPERATOR_BASE_FUNCTIONS; + LRNGradientOp(const OperatorDef& operator_def, Workspace* ws) + : LRNOpBase(operator_def, ws) {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + protected: + // Input: X, Y, scale, dY; Output: dX + INPUT_TAGS(INPUT, OUTPUT, SCALE, OUTPUT_GRAD); + INPUT_OUTPUT_STATS(4, 4, 1, 1); + DISABLE_COPY_AND_ASSIGN(LRNGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_LOCAL_RESPONSE_NORMALIZATION_OP_H_ diff --git a/caffe2/operators/loss_op.cc b/caffe2/operators/loss_op.cc new file mode 100644 index 00000000000..9798c7cebda --- /dev/null +++ b/caffe2/operators/loss_op.cc @@ -0,0 +1,10 @@ +#include "caffe2/operators/loss_op.h" + +namespace caffe2 { +namespace { + +REGISTER_CPU_OPERATOR(AveragedLoss, AveragedLoss) +REGISTER_CPU_OPERATOR(WeightedSumLoss, WeightedSumLoss) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/loss_op.h b/caffe2/operators/loss_op.h new file mode 100644 index 00000000000..dbfc4993acd --- /dev/null +++ b/caffe2/operators/loss_op.h @@ -0,0 +1,66 @@ +#ifndef CAFFE2_OPERATORS_LOSS_OP_H_ +#define CAFFE2_OPERATORS_LOSS_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +// AveragedLoss takes in the input and produces two outputs: one being the loss +// value, and one being the gradient. +template +class AveragedLoss final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(AveragedLoss); + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + auto& X = Input(0); + auto* Loss = Output(0); + auto* dX = Output(1); + Loss->Reshape(std::vector{1}); + dX->ReshapeLike(X); + math::Set( + dX->size(), static_cast(1.) 
/ X.size(), dX->mutable_data(), + &device_context_); + math::Dot( + X.size(), X.data(), dX->data(), Loss->mutable_data(), &device_context_); + return true; + } + + protected: + INPUT_OUTPUT_STATS(1, 1, 2, 2); + DISABLE_COPY_AND_ASSIGN(AveragedLoss); +}; + +template +class WeightedSumLoss final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(WeightedSumLoss); + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + auto& X = Input(0); + auto& W = Input(1); + DCHECK_EQ(X.size(), W.size()); + auto* Loss = Output(0); + auto* dX = Output(1); + Loss->Reshape(std::vector{1}); + math::Dot( + X.size(), X.data(), W.data(), Loss->mutable_data(), &device_context_); + dX->ReshapeLike(X); + dX->ShareData(W); + return true; + } + + protected: + INPUT_OUTPUT_STATS(2, 2, 2, 2); + DISABLE_COPY_AND_ASSIGN(WeightedSumLoss); +}; + + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_LOSS_OP_H_ diff --git a/caffe2/operators/loss_op_gpu.cc b/caffe2/operators/loss_op_gpu.cc new file mode 100644 index 00000000000..ace7be820aa --- /dev/null +++ b/caffe2/operators/loss_op_gpu.cc @@ -0,0 +1,11 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/loss_op.h" + +namespace caffe2 { +namespace { + +REGISTER_CUDA_OPERATOR(AveragedLoss, AveragedLoss) +REGISTER_CUDA_OPERATOR(WeightedSumLoss, WeightedSumLoss) + +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/maxpool_op.cc b/caffe2/operators/maxpool_op.cc new file mode 100644 index 00000000000..866586002b6 --- /dev/null +++ b/caffe2/operators/maxpool_op.cc @@ -0,0 +1,146 @@ +#include "caffe2/operators/maxpool_op.h" + +namespace caffe2 { + +using std::max; +using std::min; + +template <> +bool MaxPoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + Tensor* index = + OperatorBase::template Output >(1); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(1)); + index->ReshapeLike(*Y); + + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + int* index_data = index->mutable_data(); + math::Set( + Y->size(), std::numeric_limits::lowest(), Ydata, &device_context_); + // The main loop + int channels = X.dim(1); + int height = X.dim(2); + int width = X.dim(3); + int pooled_height = Y->dim(2); + int pooled_width = Y->dim(3); + for (int n = 0; n < X.dim(0); ++n) { + for (int c = 0; c < channels; ++c) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + const int pool_index = ph * pooled_width + pw; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = h * width + w; + if (Xdata[input_index] > Ydata[pool_index]) { + Ydata[pool_index] = Xdata[input_index]; + index_data[pool_index] = c * height * width + h * width + w; + } + } + } + } + } + // Do offset. 
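+ // Advance the raw pointers by one channel plane: Xdata by H * W input
+ // elements and Ydata / index_data by one pooled plane, so the indices stored
+ // above remain offsets into the current image rather than the whole batch.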
+ Xdata += height * width; + Ydata += pooled_height * pooled_width; + index_data += pooled_height * pooled_width; + } + } + return true; +} + +template <> +bool MaxPoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + Tensor* index = + OperatorBase::template Output >(1); + int height = X.dim(1); + int width = X.dim(2); + int channels = X.dim(3); + ConvPoolOpBase::SetOutputSize(X, Y, channels); + index->ReshapeLike(*Y); + + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + int* index_data = index->mutable_data(); + math::Set( + Y->size(), std::numeric_limits::lowest(), Ydata, &device_context_); + // The main loop + int pooled_height = Y->dim(1); + int pooled_width = Y->dim(2); + for (int n = 0; n < X.dim(0); ++n) { + for (int ph = 0; ph < pooled_height; ++ph) { + for (int pw = 0; pw < pooled_width; ++pw) { + int hstart = ph * stride_h_ - pad_t_; + int wstart = pw * stride_w_ - pad_l_; + int hend = min(hstart + kernel_h_, height); + int wend = min(wstart + kernel_w_, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + // compute max in range X[n, hstart:hend, wstart:wend, :] + const int pool_index = (ph * pooled_width + pw) * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + const int input_index = (h * width + w) * channels; + for (int c = 0; c < channels; ++c) { + if (Xdata[input_index + c] > Ydata[pool_index + c]) { + Ydata[pool_index + c] = Xdata[input_index + c]; + index_data[pool_index + c] = input_index + c; + } + } + } + } + } + } + // Do offset. + Xdata += X.size() / X.dim(0); + Ydata += Y->size() / Y->dim(0); + index_data += Y->size() / Y->dim(0); + } + return true; +} + +template <> +bool MaxPoolGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& dY = Input(1); + const Tensor& maxid = + OperatorBase::template Input >(2); + DCHECK_EQ(maxid.size(), dY.size()); + auto* dX = Output(0); + // TODO(Yangqing): Add shape checks. + dX->ReshapeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &device_context_); + const float* dYdata = dY.data(); + const int* maxid_data = maxid.data(); + float* dXdata = dX->mutable_data(); + // Since we have recorded all the indices, we just need to run a simple + // assignment pass. 
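+ // Concretely, for each image n and each pooled element i the recorded
+ // per-image index says which input element produced the max, so we scatter
+ //   dX[n][maxid[i]] += dY[n][i];
+ // overlapping pooling windows thus accumulate into the same input cell.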
+ const int single_input_size = X.size() / X.dim(0); + const int single_output_size = dY.size() / dY.dim(0); + for (int n = 0; n < dY.dim(0); ++n) { + for (int i = 0; i < single_output_size; ++i) { + // DCHECK_LT(maxid_data[i], single_input_size); + dXdata[maxid_data[i]] += dYdata[i]; + } + dXdata += single_input_size; + maxid_data += single_output_size; + dYdata += single_output_size; + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(MaxPool, MaxPoolOp) +REGISTER_CPU_OPERATOR(MaxPoolGradient, MaxPoolGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/maxpool_op.cu b/caffe2/operators/maxpool_op.cu new file mode 100644 index 00000000000..d8bee0fc7e7 --- /dev/null +++ b/caffe2/operators/maxpool_op.cu @@ -0,0 +1,153 @@ +#include + +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/maxpool_op.h" + +namespace caffe2 { + +namespace { +template +__global__ void MaxPoolForwardNCHW(const int nthreads, const dtype* bottom_data, + const int channels, const int height, + const int width, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, dtype* top_data, + int* mask) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int pw = index % pooled_width; + int ph = (index / pooled_width) % pooled_height; + int c = (index / pooled_width / pooled_height) % channels; + int n = index / pooled_width / pooled_height / channels; + int hstart = ph * stride_h - pad_t; + int wstart = pw * stride_w - pad_l; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype maxval = -FLT_MAX; + int maxidx = -1; + bottom_data += n * channels * height * width; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = c * height * width + h * width + w; + if (bottom_data[idx] > maxval) { + maxidx = idx; + maxval = bottom_data[idx]; + } + } + } + top_data[index] = maxval; + mask[index] = maxidx; + } +} + +template +__global__ void MaxPoolForwardNHWC(const int nthreads, const dtype* bottom_data, + const int height, const int width, + const int channels, const int pooled_height, const int pooled_width, + const int kernel_h, const int kernel_w, const int stride_h, + const int stride_w, const int pad_t, const int pad_l, dtype* top_data, + int* mask) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int n = index; + int c = n % channels; + n /= channels; + int wstart = (n % pooled_width) * stride_w - pad_l; + n /= pooled_width; + int hstart = (n % pooled_height) * stride_h - pad_t; + n /= pooled_height; + int hend = min(hstart + kernel_h, height); + int wend = min(wstart + kernel_w, width); + hstart = max(hstart, 0); + wstart = max(wstart, 0); + dtype maxval = -FLT_MAX; + int maxidx = -1; + bottom_data += n * height * width * channels; + for (int h = hstart; h < hend; ++h) { + for (int w = wstart; w < wend; ++w) { + int idx = (h * width + w) * channels + c; + if (bottom_data[idx] > maxval) { + maxidx = idx; + maxval = bottom_data[idx]; + } + } + } + top_data[index] = maxval; + mask[index] = maxidx; + } +} + +template +__global__ void MaxPoolBackward( + const int nthreads, const dtype* top_diff, const int* mask, + const int top_offset, const int bottom_offset, dtype* bottom_diff) { + CUDA_1D_KERNEL_LOOP(index, nthreads) { + int image_id = (index / top_offset); + atomicAdd(bottom_diff + image_id * bottom_offset + mask[index], + top_diff[index]); + } +} + +} 
// namespace + +template <> +bool MaxPoolOp::RunOnDeviceWithOrderNCHW() { + auto& X = Input(0); + auto* Y = Output(0); + Tensor* maxid = + OperatorBase::template Output >(1); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(1)); + maxid->ReshapeLike(*Y); + int output_size = Y->size(); + MaxPoolForwardNCHW<<>>( + output_size, X.data(), X.dim(1), X.dim(2), X.dim(3), + Y->dim(2), Y->dim(3), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, Y->mutable_data(), maxid->mutable_data()); + return true; +} + +template <> +bool MaxPoolOp::RunOnDeviceWithOrderNHWC() { + auto& X = Input(0); + auto* Y = Output(0); + Tensor* maxid = + OperatorBase::template Output >(1); + ConvPoolOpBase::SetOutputSize(X, Y, X.dim(3)); + maxid->ReshapeLike(*Y); + int output_size = Y->size(); + MaxPoolForwardNHWC<<>>( + output_size, X.data(), X.dim(1), X.dim(2), X.dim(3), + Y->dim(1), Y->dim(2), kernel_h_, kernel_w_, stride_h_, stride_w_, + pad_t_, pad_l_, Y->mutable_data(), maxid->mutable_data()); + return true; +} + + +template <> +bool MaxPoolGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& dY = Input(1); + const Tensor& maxid = + OperatorBase::template Input >(2); + auto* dX = Output(0); + // TODO(Yangqing): Add shape checks. + dX->ReshapeLike(X); + math::Set( + X.size(), 0, dX->mutable_data(), &device_context_); + MaxPoolBackward<<>>( + dY.size(), dY.data(), maxid.data(), dY.size() / dY.dim(0), + X.size() / X.dim(0), dX->mutable_data()); + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(MaxPool, MaxPoolOp) +REGISTER_CUDA_OPERATOR(MaxPoolGradient, MaxPoolGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/maxpool_op.h b/caffe2/operators/maxpool_op.h new file mode 100644 index 00000000000..806d473e889 --- /dev/null +++ b/caffe2/operators/maxpool_op.h @@ -0,0 +1,51 @@ +#ifndef CAFFE2_OPERATORS_MAXPOOL_OP_H_ +#define CAFFE2_OPERATORS_MAXPOOL_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/operators/conv_pool_op_base.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +// MaxPool will produce the max values as well as the indices of the original +// input that leads to the max value. Note that the indices are PER IMAGE, +// meaning that if you compute the offset in the original raw data buffer, you +// will need to deal with the number of images and channels accordingly. 
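+// For example, in NCHW storage an index value of c * H * W + h * W + w refers
+// to element (c, h, w) of the image currently being processed, not an offset
+// into the full batch.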
+template +class MaxPoolOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + MaxPoolOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws) {} + ~MaxPoolOp() {} + + bool RunOnDeviceWithOrderNCHW() override; + bool RunOnDeviceWithOrderNHWC() override; + + // Input: X + // Output: Y, index + INPUT_OUTPUT_STATS(1, 1, 2, 2); + DISABLE_COPY_AND_ASSIGN(MaxPoolOp); +}; + +template +class MaxPoolGradientOp final : public ConvPoolOpBase { + public: + USE_CONV_POOL_BASE_FUNCTIONS; + MaxPoolGradientOp(const OperatorDef& operator_def, Workspace* ws) + : ConvPoolOpBase(operator_def, ws) {} + ~MaxPoolGradientOp() {} + + bool RunOnDevice() override; + + // Input: X, dY, index + // Output: dX + INPUT_OUTPUT_STATS(3, 3, 1, 1); + DISABLE_COPY_AND_ASSIGN(MaxPoolGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_MAXPOOL_OP_H_ diff --git a/caffe2/operators/order_switch_ops.cc b/caffe2/operators/order_switch_ops.cc new file mode 100644 index 00000000000..fc7beadb194 --- /dev/null +++ b/caffe2/operators/order_switch_ops.cc @@ -0,0 +1,52 @@ +#include "caffe2/operators/order_switch_ops.h" + +namespace caffe2 { + +template <> +bool NHWC2NCHWOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0), H = X.dim(1), W = X.dim(2), C = X.dim(3); + Y->Reshape(std::vector{N, C, H, W}); + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + for (int n = 0; n < N; ++n) { + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + for (int c = 0; c < C; ++c) { + Ydata[((n * C + c) * H + h) * W + w] = *(Xdata++); + } + } + } + } + return true; +} + +template <> +bool NCHW2NHWCOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0), C = X.dim(1), H = X.dim(2), W = X.dim(3); + Y->Reshape(std::vector{N, H, W, C}); + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + for (int n = 0; n < N; ++n) { + for (int c = 0; c < C; ++c) { + for (int h = 0; h < H; ++h) { + for (int w = 0; w < W; ++w) { + Ydata[((n * H + h) * W + w) * C + c] = *(Xdata++); + } + } + } + } + return true; +} + + +namespace { +REGISTER_CPU_OPERATOR(NHWC2NCHW, NHWC2NCHWOp) +REGISTER_CPU_OPERATOR(NCHW2NHWC, NCHW2NHWCOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/order_switch_ops.cu b/caffe2/operators/order_switch_ops.cu new file mode 100644 index 00000000000..b73b929cfef --- /dev/null +++ b/caffe2/operators/order_switch_ops.cu @@ -0,0 +1,57 @@ +#include "caffe2/operators/order_switch_ops.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +__global__ void NHWC2NCHWKernel(const int N, const int HW, const int C, + const float* X, float* Y) { + CUDA_1D_KERNEL_LOOP(i, N * HW * C) { + const int c = i % C; + const int hw = i / C % HW; + const int n = i / C / HW; + Y[(n * C + c) * HW + hw] = X[i]; + } +} + +__global__ void NCHW2NHWCKernel(const int N, const int C, const int HW, + const float* X, float* Y) { + CUDA_1D_KERNEL_LOOP(i, N * C * HW) { + const int hw = i % HW; + const int c = i / HW % C; + const int n = i / C / HW; + Y[(n * HW + hw) * C + c] = X[i]; + } +} + +template <> +bool NHWC2NCHWOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0), H = X.dim(1), W = X.dim(2), C = X.dim(3); + Y->Reshape(std::vector{N, C, H, W}); + NHWC2NCHWKernel<<>>( + N, H * W, C, X.data(), Y->mutable_data()); + return true; +} + +template <> +bool 
NCHW2NHWCOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 4); + const int N = X.dim(0), C = X.dim(1), H = X.dim(2), W = X.dim(3); + Y->Reshape(std::vector{N, H, W, C}); + NCHW2NHWCKernel<<>>( + N, C, H * W, X.data(), Y->mutable_data()); + return true; +} + + +namespace { +REGISTER_CUDA_OPERATOR(NHWC2NCHW, NHWC2NCHWOp) +REGISTER_CUDA_OPERATOR(NCHW2NHWC, NCHW2NHWCOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/order_switch_ops.h b/caffe2/operators/order_switch_ops.h new file mode 100644 index 00000000000..8f37d69c691 --- /dev/null +++ b/caffe2/operators/order_switch_ops.h @@ -0,0 +1,38 @@ +#ifndef CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_ +#define CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_ + +#include "caffe2/core/operator.h" + +namespace caffe2 { + +// Note(Yangqing): I think it is possible to do a more general swapaxes operator +// but I am a little afraid of going down that general path. Only implementing +// the two actually needed ones here. + +template +class NHWC2NCHWOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(NHWC2NCHWOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(NHWC2NCHWOp); +}; + +template +class NCHW2NHWCOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(NCHW2NHWCOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(NCHW2NHWCOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_ORDER_SWITCH_OPS_H_ diff --git a/caffe2/operators/prefetch_op.h b/caffe2/operators/prefetch_op.h new file mode 100644 index 00000000000..8f5fa5ecbd3 --- /dev/null +++ b/caffe2/operators/prefetch_op.h @@ -0,0 +1,82 @@ +#ifndef CAFFE2_OPERATORS_PREFETCH_OP_H_ +#define CAFFE2_OPERATORS_PREFETCH_OP_H_ + +#include // NOLINT + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class PrefetchOperator; + +namespace internal { +// We define a prefetch function so that the prefetch function can call virtual +// member functions of the prefetch operator. +template +void PrefetchFunc(PrefetchOperator* op) { + op->prefetch_success_ = op->Prefetch(); +} +} + +// PrefetchOperator is an operator that prefetches the next batch. It should +// almost always be used to read things from disk, so I am setting the input to +// zero blobs. +template +class PrefetchOperator : public OperatorBase { + public: + PrefetchOperator(const OperatorDef& operator_def, Workspace* ws) + : OperatorBase(operator_def, ws), + device_context_(operator_def.device_option()), + prefetch_success_(false) { + device_context_.SwitchToDevice(); + } + virtual ~PrefetchOperator() {} + + bool Run() final { + if (prefetch_thread_ == nullptr) { + VLOG(1) << "Starting a new prefetch thread."; + prefetch_thread_.reset( + new std::thread( + internal::PrefetchFunc, this)); + } + // Join the last prefetch thread. 
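+ // Overall pattern: the thread started on the previous Run() call (or just
+ // above on the first call) has been producing the next batch in the
+ // background; we wait for it, copy its result into the output blobs, and
+ // then immediately kick off prefetching for the batch after that.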
+ VLOG(1) << "Waiting for the prefetch thread."; + prefetch_thread_->join(); + + if (!prefetch_success_) { + LOG(ERROR) << "Prefetching failed."; + return false; + } + VLOG(1) << "Copy prefetched result."; + if (!CopyPrefetched()) { + LOG(ERROR) << "Error when copying prefetched data."; + return false; + } + prefetch_success_ = false; + VLOG(1) << "Starting a new prefetch thread."; + prefetch_thread_.reset( + new std::thread( + internal::PrefetchFunc, this)); + return true; + } + + // You will need to implement this instead of the Run function. + virtual bool Prefetch() = 0; + virtual bool CopyPrefetched() = 0; + friend void internal::PrefetchFunc( + PrefetchOperator*); + + protected: + DeviceContext device_context_; + unique_ptr prefetch_thread_; + bool prefetch_success_; + + INPUT_OUTPUT_STATS(0, 0, 1, INT_MAX); + DISABLE_COPY_AND_ASSIGN(PrefetchOperator); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_PREFETCH_OP_H_ diff --git a/caffe2/operators/relu_op.cc b/caffe2/operators/relu_op.cc new file mode 100644 index 00000000000..4b3fbefcafb --- /dev/null +++ b/caffe2/operators/relu_op.cc @@ -0,0 +1,40 @@ +#include "caffe2/operators/relu_op.h" + +namespace caffe2 { + +template <> +bool ReluOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_GT(X.size(), 0); + Y->ReshapeLike(X); + const float* Xdata = X.data(); + float* Ydata = Y->mutable_data(); + for (int i = 0; i < X.size(); ++i) { + Ydata[i] = std::max(Xdata[i], 0.f); + } + return true; +} + +template <> +bool ReluGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_GT(X.size(), 0); + DCHECK_EQ(dY.size(), X.size()); + dX->ReshapeLike(X); + const float* Xdata = X.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + for (int i = 0; i < X.size(); ++i) { + dXdata[i] = dYdata[i] * (Xdata[i] > 0); + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(Relu, ReluOp) +REGISTER_CPU_OPERATOR(ReluGradient, ReluGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/relu_op.cu b/caffe2/operators/relu_op.cu new file mode 100644 index 00000000000..0b349ffa496 --- /dev/null +++ b/caffe2/operators/relu_op.cu @@ -0,0 +1,52 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/relu_op.h" + +namespace caffe2 { +namespace { +template +__global__ void ReluKernel(const int N, const dtype* X, dtype* Y) { + CUDA_1D_KERNEL_LOOP(i, N) { + Y[i] = X[i] > 0 ? 
X[i] : 0; + } +} + +template +__global__ void ReluGradientKernel(const int N, const dtype* X, const dtype* dY, + dtype* dX) { + CUDA_1D_KERNEL_LOOP(i, N) { + dX[i] = dY[i] * (X[i] > 0); + } +} +} // namespace + +template <> +bool ReluOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_GT(X.size(), 0); + Y->ReshapeLike(X); + ReluKernel<<>>( + X.size(), X.data(), Y->mutable_data()); + return true; +} + +template <> +bool ReluGradientOp::RunOnDevice() { + auto& X = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_GT(X.size(), 0); + DCHECK_EQ(dY.size(), X.size()); + dX->ReshapeLike(X); + ReluGradientKernel<<>>( + X.size(), X.data(), dY.data(), dX->mutable_data()); + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(Relu, ReluOp) +REGISTER_CUDA_OPERATOR(ReluGradient, ReluGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/relu_op.h b/caffe2/operators/relu_op.h new file mode 100644 index 00000000000..2d8ac1aaa7c --- /dev/null +++ b/caffe2/operators/relu_op.h @@ -0,0 +1,40 @@ +#ifndef CAFFE2_OPERATORS_RELU_OP_H_ +#define CAFFE2_OPERATORS_RELU_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class ReluOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(ReluOp); + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice(); + + protected: + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(ReluOp); +}; + +template +class ReluGradientOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(ReluGradientOp); + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice(); + + protected: + // Input: X, dY; Output: dX + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(ReluGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_RELU_OP_H_ diff --git a/caffe2/operators/softmax_op.cc b/caffe2/operators/softmax_op.cc new file mode 100644 index 00000000000..c7c6fa09a9f --- /dev/null +++ b/caffe2/operators/softmax_op.cc @@ -0,0 +1,95 @@ +#include "caffe2/operators/softmax_op.h" + +namespace caffe2 { + +// Implementation for the CPU context. +template <> +bool SoftmaxOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + Y->ReshapeLike(X); + // First, get scales + if (scale_.size() != N) { + scale_.Reshape(std::vector{N}); + } + if (sum_multiplier_.size() != D) { + sum_multiplier_.Reshape(std::vector{D}); + math::Set(D, 1.f, sum_multiplier_.mutable_data(), + &device_context_); + } + math::RowwiseMax(N, D, X.data(), scale_.mutable_data(), + &device_context_); + // Put the intermediate result X - max(X) into Y + device_context_.template Copy( + Y->mutable_data(), X.data(), X.size()); + // Subtract the scale + static const float kMinusOne = -1.; + static const float kOne = 1.; + static const float kZero = 0; + math::Gemm(CblasNoTrans, CblasNoTrans, N, D, 1, + &kMinusOne, scale_.data(), sum_multiplier_.data(), &kOne, + Y->mutable_data(), &device_context_); + // Exponentiation + math::Exp(Y->size(), Y->data(), Y->mutable_data(), + &device_context_); + math::Gemv(CblasNoTrans, N, D, &kOne, Y->data(), + sum_multiplier_.data(), &kZero, + scale_.mutable_data(), &device_context_); + // Do division + // TODO(Yangqing): maybe implement it more beautifully? 
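  // In equations, for each row x of X the calls above compute, in order:
  //   m   = max_j x_j                  (RowwiseMax into scale_)
  //   y_j = x_j - m                    (Copy, then Gemm with the -1/+1
  //                                     constants and the all-ones
  //                                     sum_multiplier_)
  //   y_j = exp(y_j)                   (Exp)
  //   s   = sum_j y_j                  (Gemv into scale_)
  // and the loop below finishes with y_j /= s, i.e. the numerically stable
  // softmax exp(x_j - m) / sum_k exp(x_k - m).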
+ float* output = Y->mutable_data(); + const float* scale = scale_.data(); + for (int i = 0; i < N; ++i) { + for (int j = 0; j < D; ++j) { + output[i * D + j] /= scale[i]; + } + } + return true; +} + +// Implementation for the CPU context. +template <> +bool SoftmaxGradientOp::RunOnDevice() { + auto& Y = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_EQ(Y.ndim(), 2); + int N = Y.dim(0); + int D = Y.dim(1); + DCHECK_EQ(dY.dim(0), N); + DCHECK_EQ(dY.dim(1), D); + // First, get scales + if (scale_.size() != N) { + scale_.Reshape(std::vector{N}); + } + if (sum_multiplier_.size() != D) { + sum_multiplier_.Reshape(std::vector{D}); + math::Set(D, 1.f, sum_multiplier_.mutable_data(), + &device_context_); + } + dX->Reshape(std::vector{N, D}); + const float* Ydata = Y.data(); + const float* dYdata = dY.data(); + float* dXdata = dX->mutable_data(); + device_context_.Copy(dXdata, dYdata, Y.size()); + float* scaledata = scale_.mutable_data(); + for (int i = 0; i < N; ++i) { + math::Dot(D, Ydata + i * D, dYdata + i * D, + scaledata + i, &device_context_); + } + const float kMinusOne = -1.; + const float kOne = 1.; + math::Gemm(CblasNoTrans, CblasNoTrans, N, D, 1, &kMinusOne, + scaledata, sum_multiplier_.data(), &kOne, + dXdata, &device_context_); + math::Mul(Y.size(), dXdata, Ydata, dXdata, + &device_context_); + return true; +} + +REGISTER_CPU_OPERATOR(Softmax, SoftmaxOp) +REGISTER_CPU_OPERATOR(SoftmaxGradient, SoftmaxGradientOp) +} // namespace caffe2 diff --git a/caffe2/operators/softmax_op.cu b/caffe2/operators/softmax_op.cu new file mode 100644 index 00000000000..7a614ed562d --- /dev/null +++ b/caffe2/operators/softmax_op.cu @@ -0,0 +1,128 @@ +#include + +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/softmax_op.h" + + +namespace caffe2 { + +#define SOFTMAX_NUM_THREADS 128 + +namespace { +// The softmax kernel. This kernel has to be called with the number of threads +// per block being no more than SOFTMAX_NUM_THREADS. +__global__ void softmax_kernel( + const int dim, const float* data, float* out) { + // For the softmax kernel, each block is a data example. + data += blockIdx.x * dim; + out += blockIdx.x * dim; + const int idx = threadIdx.x; + __shared__ float reduction_buffer[SOFTMAX_NUM_THREADS]; + float tmp; + + // A two-level reduction to get the max. + tmp = -FLT_MAX; + for (int i = idx; i < dim; i += blockDim.x) { + tmp = fmaxf(data[i], tmp); + } + reduction_buffer[idx] = tmp; + __syncthreads(); + if (idx == 0) { + tmp = reduction_buffer[0]; + for (int i = 1; i < blockDim.x; ++i) { + tmp = fmaxf(reduction_buffer[i], tmp); + } + reduction_buffer[0] = tmp; + } + __syncthreads(); + // compute sum with a two-level reduction. + float maxval = reduction_buffer[0]; + reduction_buffer[idx] = 0; + for (int i = idx; i < dim; i += blockDim.x) { + tmp = __expf(data[i] - maxval); + reduction_buffer[idx] += tmp; + out[i] = tmp; + } + __syncthreads(); + if (idx == 0) { + tmp = reduction_buffer[0]; + for (int i = 1; i < blockDim.x; ++i) { + tmp += reduction_buffer[i]; + } + reduction_buffer[0] = tmp; + } + __syncthreads(); + // Compute the softmax; + tmp = reduction_buffer[0]; + for (int i = idx; i < dim; i += blockDim.x) { + out[i] /= tmp; + } +} + +// The softmax gradient kernel. This kernel has to be called with the number of +// threads per block being no more than SOFTMAX_NUM_THREADS. 
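// The math being implemented: for one row, the softmax Jacobian is
//   dY_i / dX_j = Y_i * (delta_ij - Y_j),
// so the backward pass reduces to
//   dX_i = Y_i * (dY_i - sum_j dY_j * Y_j).
// The two-level reduction computes the per-row inner product sum_j dY_j * Y_j
// once per block, and the final loop applies the elementwise formula.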
+__global__ void softmax_gradient_kernel( + const int dim, const float* Y, const float* dY, float* dX) { + Y += blockIdx.x * dim; + dY += blockIdx.x * dim; + dX += blockIdx.x * dim; + const int idx = threadIdx.x; + __shared__ float reduction_buffer[SOFTMAX_NUM_THREADS]; + float tmp; + + // A two-level reduction to compute the inner products. + tmp = 0; + for (int i = idx; i < dim; i += blockDim.x) { + tmp += dY[i] * Y[i]; + } + reduction_buffer[idx] = tmp; + __syncthreads(); + if (idx == 0) { + tmp = reduction_buffer[0]; + for (int i = 1; i < blockDim.x; ++i) tmp += reduction_buffer[i]; + reduction_buffer[0] = tmp; + } + __syncthreads(); + // Compute gradient. + tmp = reduction_buffer[0]; + for (int i = idx; i < dim; i += blockDim.x) { + dX[i] = Y[i] * (dY[i] - tmp); + } +} +} // namespace + +// Implementation for the CPU context. +template <> +bool SoftmaxOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + Y->ReshapeLike(X); + softmax_kernel<<>>( + D, X.data(), Y->mutable_data()); + return true; +} + +// Implementation for the CPU context. +template <> +bool SoftmaxGradientOp::RunOnDevice() { + auto& Y = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_EQ(Y.ndim(), 2); + int N = Y.dim(0); + int D = Y.dim(1); + DCHECK_EQ(dY.dim(0), N); + DCHECK_EQ(dY.dim(1), D); + dX->ReshapeLike(Y); + softmax_gradient_kernel<<>>( + D, Y.data(), dY.data(), dX->mutable_data()); + return true; +} + +REGISTER_CUDA_OPERATOR(Softmax, SoftmaxOp) +REGISTER_CUDA_OPERATOR(SoftmaxGradient, SoftmaxGradientOp) +} // namespace caffe2 diff --git a/caffe2/operators/softmax_op.h b/caffe2/operators/softmax_op.h new file mode 100644 index 00000000000..1114188908d --- /dev/null +++ b/caffe2/operators/softmax_op.h @@ -0,0 +1,42 @@ +#ifndef CAFFE2_OPERATORS_SOFTMAX_OP_H_ +#define CAFFE2_OPERATORS_SOFTMAX_OP_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +template +class SoftmaxOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(SoftmaxOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + Tensor scale_; + Tensor sum_multiplier_; + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(SoftmaxOp); +}; + +template +class SoftmaxGradientOp final : public Operator { + public: + USE_SIMPLE_CTOR_DTOR(SoftmaxGradientOp); + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + protected: + Tensor scale_; + Tensor sum_multiplier_; + // Input: Y, dY. 
Output: dX + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(SoftmaxGradientOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_SOFTMAX_OP_H_ diff --git a/caffe2/operators/softmax_op_cudnn.cc b/caffe2/operators/softmax_op_cudnn.cc new file mode 100644 index 00000000000..4470e849bd7 --- /dev/null +++ b/caffe2/operators/softmax_op_cudnn.cc @@ -0,0 +1,99 @@ +#include "caffe2/core/common_cudnn.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/core/types.h" +#include "caffe2/operators/softmax_op.h" + +namespace caffe2 { + +namespace { +const int NUM_DESCRIPTORS = 2; +const int GRADIENT_NUM_DESCRIPTORS = 3; +const int BOTTOM_DESC_ID = 0; +const int TOP_DESC_ID = 1; +const int TOP_GRADIENT_DESC_ID = 2; +} // namespace + + +class CuDNNSoftmaxOp final : public Operator { + public: + explicit CuDNNSoftmaxOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws), + cudnn_wrapper_(&device_context_) {} + bool RunOnDevice() override; + + protected: + CuDNNWrapper cudnn_wrapper_; + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(CuDNNSoftmaxOp); +}; + + +class CuDNNSoftmaxGradientOp final : public Operator { + public: + explicit CuDNNSoftmaxGradientOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws), + cudnn_wrapper_(&device_context_) {} + bool RunOnDevice() override; + + protected: + CuDNNWrapper cudnn_wrapper_; + // Input: Y, dY. Output: dX + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(CuDNNSoftmaxGradientOp); +}; + +bool CuDNNSoftmaxOp::RunOnDevice() { + auto& X = Input(0); + auto* Y = Output(0); + DCHECK_EQ(X.ndim(), 2); + int N = X.dim(0); + int D = X.dim(1); + Y->ReshapeLike(X); + const float alpha = 1.0; + const float beta = 0.0; + vector dims{N, D, 1, 1}; + cudnn_wrapper_.cudnnSetNumTensorDescriptors(NUM_DESCRIPTORS); + CUDNN_CHECK(cudnnSoftmaxForward(cudnn_wrapper_.cudnn_handle(), + CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, + cudnn_wrapper_.cudnnGetTensor4dDesc( + BOTTOM_DESC_ID, CUDNN_TENSOR_NCHW, dims, nullptr), + X.data(), &beta, + cudnn_wrapper_.cudnnGetTensor4dDesc( + TOP_DESC_ID, CUDNN_TENSOR_NCHW, dims, nullptr), + Y->mutable_data())); + return true; +} + +bool CuDNNSoftmaxGradientOp::RunOnDevice() { + auto& Y = Input(0); + auto& dY = Input(1); + auto* dX = Output(0); + DCHECK_EQ(Y.ndim(), 2); + int N = Y.dim(0); + int D = Y.dim(1); + DCHECK_EQ(dY.dim(0), N); + DCHECK_EQ(dY.dim(1), D); + dX->ReshapeLike(Y); + const float alpha = 1.0; + const float beta = 0.0; + cudnn_wrapper_.cudnnSetNumTensorDescriptors(GRADIENT_NUM_DESCRIPTORS); + vector dims{N, D, 1, 1}; + CUDNN_CHECK(cudnnSoftmaxBackward(cudnn_wrapper_.cudnn_handle(), + CUDNN_SOFTMAX_ACCURATE, CUDNN_SOFTMAX_MODE_INSTANCE, &alpha, + cudnn_wrapper_.cudnnGetTensor4dDesc( + TOP_DESC_ID, CUDNN_TENSOR_NCHW, dims, nullptr), + Y.data(), + cudnn_wrapper_.cudnnGetTensor4dDesc( + TOP_GRADIENT_DESC_ID, CUDNN_TENSOR_NCHW, dims, nullptr), + dY.data(), &beta, + cudnn_wrapper_.cudnnGetTensor4dDesc( + BOTTOM_DESC_ID, CUDNN_TENSOR_NCHW, dims, nullptr), + dX->mutable_data())); + return true; +} + +namespace { +REGISTER_CUDNN_OPERATOR(Softmax, CuDNNSoftmaxOp) +REGISTER_CUDNN_OPERATOR(SoftmaxGradient, CuDNNSoftmaxGradientOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/summarize_op.cc b/caffe2/operators/summarize_op.cc new file mode 100644 index 00000000000..f85ee9024fe --- /dev/null +++ b/caffe2/operators/summarize_op.cc @@ -0,0 +1,48 @@ +#include "caffe2/operators/summarize_op.h" + +namespace caffe2 { + +template<> +bool 
SummarizeOp::RunOnDevice() { + auto& X = Input(0); + const int N = X.size(); + DCHECK_GT(N, 0); + const float* Xdata = X.data(); + float mean = 0; + float max = Xdata[0]; + float min = Xdata[0]; + for (int i = 0; i < N; ++i) { + mean += Xdata[i]; + max = std::max(max, Xdata[i]); + min = std::min(min, Xdata[i]); + } + mean /= N; + // We will simply do a two-pass. More efficient solutions can be written but + // I'll keep code simple for now. + float standard_deviation = 0; + for (int i = 0; i < N; ++i) { + float diff = Xdata[i] - mean; + standard_deviation += diff * diff; + } + // Unbiased or biased? Let's do unbiased now. + standard_deviation = N == 1 ? 0 : std::sqrt(standard_deviation / (N - 1)); + if (to_file_) { + (*log_file_) << min << " " << max << " " << mean << " " + << standard_deviation << std::endl; + } + if (OutputSize()) { + auto* Y = Output(0); + Y->Reshape(std::vector{NUM_STATS}); + float* Ydata = Y->mutable_data(); + Ydata[MIN_IDX] = min; + Ydata[MAX_IDX] = max; + Ydata[MEAN_IDX] = mean; + Ydata[STD_IDX] = standard_deviation; + } + return true; +} + +namespace { +REGISTER_CPU_OPERATOR(Summarize, SummarizeOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/summarize_op.cu b/caffe2/operators/summarize_op.cu new file mode 100644 index 00000000000..080028822fd --- /dev/null +++ b/caffe2/operators/summarize_op.cu @@ -0,0 +1,112 @@ +#include +#include +#include +#include + +#include "caffe2/operators/summarize_op.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { + +namespace { + +// structure used to accumulate the moments and other statistical properties +// encountered so far. +template +struct SummaryStatsData { + T n; + T min; + T max; + T mean; + T M2; + + // initialize to the identity element + void initialize() { + n = mean = M2 = 0; + min = std::numeric_limits::max(); + max = std::numeric_limits::min(); + } + + T variance() { return (n == 1 ? 0 : M2 / (n - 1)); } +}; + +// stats_unary_op is a functor that takes in a value x and +// returns a variace_data whose mean value is initialized to x. +template +struct summary_stats_unary_op { + __host__ __device__ SummaryStatsData operator()(const T& x) const { + SummaryStatsData result; + result.n = 1; + result.min = x; + result.max = x; + result.mean = x; + result.M2 = 0; + return result; + } +}; + +// summary_stats_binary_op is a functor that accepts two SummaryStatsData +// structs and returns a new SummaryStatsData which are an +// approximation to the summary_stats for +// all values that have been agregated so far +template +struct summary_stats_binary_op + : public thrust::binary_function&, + const SummaryStatsData&, + SummaryStatsData > { + __host__ __device__ SummaryStatsData operator()( + const SummaryStatsData& x, const SummaryStatsData & y) const { + SummaryStatsData result; + T n = x.n + y.n; + T delta = y.mean - x.mean; + T delta2 = delta * delta; + result.n = n; + result.min = thrust::min(x.min, y.min); + result.max = thrust::max(x.max, y.max); + result.mean = x.mean + delta * y.n / n; + result.M2 = x.M2 + y.M2; + result.M2 += delta2 * x.n * y.n / n; + return result; + } +}; + +} // namespace + +template<> +bool SummarizeOp::RunOnDevice() { + auto& X = Input(0); + const int N = X.size(); + DCHECK_GT(N, 0); + + // TODO(Yangqing): Any better way to avoid having to const cast? 
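  // The transform_reduce below combines per-element summaries with the
  // standard pairwise mean/variance update (cf. Chan et al.): merging
  // (n_a, mean_a, M2_a) with (n_b, mean_b, M2_b) gives
  //   n    = n_a + n_b
  //   d    = mean_b - mean_a
  //   mean = mean_a + d * n_b / n
  //   M2   = M2_a + M2_b + d * d * n_a * n_b / n
  // and the unbiased variance is recovered at the end as M2 / (n - 1),
  // matching the (N - 1) convention of the CPU implementation.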
+ thrust::device_ptr Xdata(const_cast(X.data())); + summary_stats_unary_op unary_op; + summary_stats_binary_op binary_op; + SummaryStatsData init; + init.initialize(); + // compute summary statistics + SummaryStatsData result = thrust::transform_reduce( +#if CUDA_VERSION >= 7000 + thrust::cuda::par.on(device_context_.cuda_stream()), +#endif // CUDA_VERSION + Xdata, Xdata + N, unary_op, init, binary_op); + float standard_deviation = std::sqrt(result.variance()); + if (to_file_) { + (*log_file_) << result.min << " " << result.max << " " << result.mean << " " + << standard_deviation << std::endl; + } + if (OutputSize()) { + auto* Y = OperatorBase::Output >(0); + Y->Reshape(std::vector{4}); + float output_buffer[NUM_STATS] = {result.min, result.max, result.mean, + standard_deviation}; + device_context_.Copy( + Y->mutable_data(), output_buffer, NUM_STATS); + } + return true; +} + +namespace { +REGISTER_CUDA_OPERATOR(Summarize, SummarizeOp) +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/summarize_op.h b/caffe2/operators/summarize_op.h new file mode 100644 index 00000000000..01d94db41ee --- /dev/null +++ b/caffe2/operators/summarize_op.h @@ -0,0 +1,58 @@ +#ifndef CAFFE2_OPERATORS_SUMMARIZE_OP_H_ +#define CAFFE2_OPERATORS_SUMMARIZE_OP_H_ + +#include + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" + +namespace caffe2 { + +constexpr char kSummaryzeOpExtension[] = ".summary"; + +// Accumulate operator accumulates the input tensor to the output tensor. If the +// output tensor already has the right size, we add to it; otherwise, we first +// initialize the output tensor to all zeros, and then do accumulation. Any +// further calls to the operator, given that no one else fiddles with the output +// in the interim, will do simple accumulations. +template +class SummarizeOp final : public Operator { + public: + SummarizeOp(const OperatorDef& def, Workspace* ws) + : Operator(def, ws), + to_file_(OperatorBase::GetSingleArgument("to_file", 0)) { + if (to_file_) { + // We will output to file instead of printing on screen. + const string& target_folder = ws->RootFolder(); + // We will write each individual tensor to its individual file. + log_file_.reset(new std::ofstream( + target_folder + "/" + def.inputs(0) + kSummaryzeOpExtension, + std::ofstream::out | std::ofstream::trunc)); + CHECK(log_file_->good()) + << "Failed to open summarize file for tensor " << def.inputs(0) + << ". rdstate() = " << log_file_->rdstate(); + } + } + ~SummarizeOp() { if (to_file_) log_file_->close(); } + USE_OPERATOR_BASE_FUNCTIONS; + bool RunOnDevice() override; + + static constexpr int MIN_IDX = 0; + static constexpr int MAX_IDX = 1; + static constexpr int MEAN_IDX = 2; + static constexpr int STD_IDX = 3; + static constexpr int NUM_STATS = 4; + + protected: + bool to_file_; + std::unique_ptr log_file_; + // Input: X; output: if set, a summarized vector of shape 4, with the values + // being min, max, mean and std respectively. 
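  // Worked example (illustrative): for X = {1, 2, 3, 4} the output would be
  // {1, 4, 2.5, ~1.291} -- min, max, mean, and unbiased std -- since the sum
  // of squared deviations is 5 and sqrt(5 / 3) ~= 1.291.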
+ INPUT_OUTPUT_STATS(1, 1, 0, 1); + DISABLE_COPY_AND_ASSIGN(SummarizeOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_SUMMARIZE_OP_H_ diff --git a/caffe2/operators/tensor_protos_db_input.cc b/caffe2/operators/tensor_protos_db_input.cc new file mode 100644 index 00000000000..6459d6d76dd --- /dev/null +++ b/caffe2/operators/tensor_protos_db_input.cc @@ -0,0 +1,7 @@ +#include "caffe2/operators/tensor_protos_db_input.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(TensorProtosDBInput, TensorProtosDBInput); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/tensor_protos_db_input.h b/caffe2/operators/tensor_protos_db_input.h new file mode 100644 index 00000000000..69d4a33f8a1 --- /dev/null +++ b/caffe2/operators/tensor_protos_db_input.h @@ -0,0 +1,193 @@ +#ifndef CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ +#define CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ + +#include + +#include "caffe2/core/db.h" +#include "caffe2/operators/prefetch_op.h" + +namespace caffe2 { + +// tensor protos db input is the simplest input that basically reads +// things from a db where each key-value pair stores a TensorProtos object. +// These tensorprotos should have the same size, and they will be grouped into +// batches of the given size. The output will simply be tensors of float data. +template +class TensorProtosDBInput final + : public PrefetchOperator { + public: + using OperatorBase::OutputSize; + using PrefetchOperator::prefetch_thread_; + explicit TensorProtosDBInput(const OperatorDef& operator_def, Workspace* ws); + ~TensorProtosDBInput() { + if (prefetch_thread_.get() != nullptr) { + prefetch_thread_->join(); + } + } + + bool Prefetch() override; + bool CopyPrefetched() override; + + private: + unique_ptr db_; + unique_ptr cursor_; + // Prefetch will always just happen on the CPU side. + vector > prefetched_blobs_; + vector data_types_; + int batch_size_; + string db_name_; + string db_type_; + DISABLE_COPY_AND_ASSIGN(TensorProtosDBInput); +}; + +template +TensorProtosDBInput::TensorProtosDBInput( + const OperatorDef& operator_def, Workspace* ws) + : PrefetchOperator(operator_def, ws), + batch_size_( + OperatorBase::template GetSingleArgument("batch_size", 0)), + db_name_( + OperatorBase::template GetSingleArgument("db", "")), + db_type_(OperatorBase::template GetSingleArgument( + "db_type", "leveldb")) { + CHECK_GT(batch_size_, 0) << "Batch size should be nonnegative."; + CHECK_GT(db_name_.size(), 0) << "Must provide a leveldb name."; + + db_.reset(db::CreateDB(db_type_, db_name_, db::READ)); + cursor_.reset(db_->NewCursor()); + cursor_->SeekToFirst(); + + // Now, we want to read a data point to initialize the contents. 
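  // Concretely, taking the MNIST minidb from the unit test as an example:
  // each DB value holds two protos -- an image stored as BYTE (dims presumably
  // {1, 28, 28, 1}) and an INT32 label -- so after dims[0] is overwritten with
  // batch_size below, the prefetched blobs become a float tensor of shape
  // {batch_size, 28, 28, 1} and an int tensor of shape {batch_size}. Only the
  // first record is parsed here, purely to discover shapes and data types.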
+ TensorProtos protos; + protos.ParseFromString(cursor_->value()); + CHECK_EQ(protos.protos_size(), OutputSize()); + prefetched_blobs_.resize(protos.protos_size()); + data_types_.resize(protos.protos_size()); + VLOG(1) << "Figuring data types."; + for (int i = 0; i < protos.protos_size(); ++i) { + vector dims; + for (const int dim : protos.protos(i).dims()) { + dims.push_back(dim); + } + dims[0] = batch_size_; + prefetched_blobs_[i].reset(new Blob()); + Blob* blob = prefetched_blobs_[i].get(); + data_types_[i] = protos.protos(i).data_type(); + switch (data_types_[i]) { + case TensorProto::FLOAT: + VLOG(1) << "Output " << i << ": float"; + blob->GetMutable >()->Reshape(dims); + break; + case TensorProto::INT32: + VLOG(1) << "Output " << i << ": int"; + blob->GetMutable >()->Reshape(dims); + break; + case TensorProto::BYTE: + VLOG(1) << "Output " << i << ": byte -> float"; + // TODO(Yangqing): What type should I use here? Float? + blob->GetMutable >()->Reshape(dims); + break; + case TensorProto::STRING: + LOG(FATAL) << "Not expecting string."; + } + } + cursor_->SeekToFirst(); +} + +template +bool TensorProtosDBInput::Prefetch() { + for (int item_id = 0; item_id < batch_size_; ++item_id) { + // LOG(INFO) << "Prefetching item " << item_id; + // process data + TensorProtos protos; + protos.ParseFromString(cursor_->value()); + // TODO(Yangqing): do we want to do anything to sanity check the data? + for (int i = 0; i < protos.protos_size(); ++i) { + const TensorProto& proto = protos.protos(i); + Blob* blob = prefetched_blobs_[i].get(); + switch (proto.data_type()) { + case TensorProto::FLOAT: + { + DCHECK((blob->IsType >())); + auto* tensor = blob->GetMutable >(); + int single_size = proto.float_data_size(); + CHECK_EQ(single_size * batch_size_, tensor->size()); + memcpy(tensor->mutable_data() + single_size * item_id, + proto.float_data().data(), single_size * sizeof(float)); + break; + } + case TensorProto::INT32: + { + DCHECK((blob->IsType >())); + auto* tensor = blob->GetMutable >(); + int single_size = proto.int32_data_size(); + CHECK_EQ(single_size * batch_size_, tensor->size()); + int* dst_pointer = tensor->mutable_data() + single_size * item_id; + for (int i = 0; i < single_size; ++i) { + dst_pointer[i] = proto.int32_data(i); + } + break; + } + case TensorProto::BYTE: + { + DCHECK((blob->IsType >())); + auto* tensor = blob->GetMutable >(); + const string& src_data = proto.byte_data(); + int single_size = src_data.size(); + CHECK_EQ(single_size * batch_size_, tensor->size()); + float* dst_pointer = tensor->mutable_data() + single_size * item_id; + for (int i = 0; i < single_size; ++i) { + dst_pointer[i] = + static_cast(static_cast(src_data[i])) / 256.f; + } + break; + } + default: + LOG(ERROR) << "Unknown input data type: " << proto.data_type(); + return false; + } + } + cursor_->Next(); + if (!cursor_->Valid()) { + cursor_->SeekToFirst(); + } + } + return true; +} + +template +bool TensorProtosDBInput::CopyPrefetched() { + for (int i = 0; i < OutputSize(); ++i) { + switch (data_types_[i]) { + case TensorProto::FLOAT: + case TensorProto::BYTE: + { + auto* output = OperatorBase::Output >(i); + auto& input = + prefetched_blobs_[i]->template Get >(); + output->ReshapeLike(input); + this->device_context_.template Copy( + output->mutable_data(), input.data(), input.size()); + break; + } + case TensorProto::INT32: + { + auto* output = OperatorBase::Output >(i); + auto& input = + prefetched_blobs_[i]->template Get >(); + output->ReshapeLike(input); + this->device_context_.template Copy( + 
output->mutable_data(), input.data(), input.size()); + break; + } + case TensorProto::STRING: + LOG(FATAL) << "Not expecting string."; + } + } + return true; +} + + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_TENSOR_PROTOS_DB_INPUT_H_ diff --git a/caffe2/operators/tensor_protos_db_input_gpu.cc b/caffe2/operators/tensor_protos_db_input_gpu.cc new file mode 100644 index 00000000000..816eabe64a3 --- /dev/null +++ b/caffe2/operators/tensor_protos_db_input_gpu.cc @@ -0,0 +1,9 @@ +#include "caffe2/core/common_gpu.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/tensor_protos_db_input.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(TensorProtosDBInput, TensorProtosDBInput); +} // namespace +} // namespace caffe2 diff --git a/caffe2/operators/tensor_protos_db_input_test.cc b/caffe2/operators/tensor_protos_db_input_test.cc new file mode 100644 index 00000000000..2b2b52451bb --- /dev/null +++ b/caffe2/operators/tensor_protos_db_input_test.cc @@ -0,0 +1,81 @@ +#include + +#include "caffe2/operators/tensor_protos_db_input.h" +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +DECLARE_string(caffe_test_root); + +const char* kTestDBPath = "/data/mnist/mnist-train-minidb"; + +namespace caffe2 { + +const int kNumItems = 51200; +const int kLabelsToCheck = 12; +const int kLabels[] = {5, 0, 4, 1, 9, 2, 1, 3, 1, 4, 3, 5}; + +static void TestMNISTLoad(const int batch_size) { + Workspace ws; + OperatorDef def; + def.set_name("test"); + def.set_type("TensorProtosDBInput"); + def.add_outputs("data"); + def.add_outputs("label"); + auto* batch_arg = def.add_args(); + batch_arg->set_name("batch_size"); + batch_arg->set_i(batch_size); + auto* db_arg = def.add_args(); + db_arg->set_name("db"); + db_arg->set_s(FLAGS_caffe_test_root + string(kTestDBPath)); + auto* db_type_arg = def.add_args(); + db_type_arg->set_name("db_type"); + db_type_arg->set_s("minidb"); + unique_ptr op(CreateOperator(def, &ws)); + EXPECT_NE(nullptr, op.get()); + for (int iter = 0; iter < kNumItems / batch_size; ++iter) { + EXPECT_TRUE(op->Run()); + // Inspect the result + auto* data_blob = ws.GetBlob("data"); + EXPECT_TRUE((data_blob->IsType >())); + auto* label_blob = ws.GetBlob("label"); + EXPECT_TRUE((label_blob->IsType >())); + auto& data_tensor = data_blob->Get >(); + auto& label_tensor = label_blob->Get >(); + EXPECT_EQ(data_tensor.ndim(), 4); + EXPECT_EQ(data_tensor.dim(0), batch_size); + EXPECT_EQ(data_tensor.dim(1), 28); + EXPECT_EQ(data_tensor.dim(2), 28); + EXPECT_EQ(data_tensor.dim(3), 1); + EXPECT_EQ(label_tensor.ndim(), 1); + EXPECT_EQ(label_tensor.dim(0), batch_size); + /* + // Visualization just for inspection purpose. 
+ int idx = 0; + for (int b = 0; b < batch_size; ++b) { + for (int row = 0; row < 28; ++row) { + for (int col = 0; col < 28; ++col) { + std::cout << (data_tensor.data()[idx++] > 128) << " "; + } + std::cout << std::endl; + } + std::cout << std::endl << std::endl; + } + std::cout << "label: " << label_tensor.data()[0] << std::endl; + */ + for (int i = 0; i < batch_size; ++i) { + if (iter * batch_size + i < kLabelsToCheck) { + EXPECT_EQ(label_tensor.data()[i], kLabels[iter * batch_size + i]); + } + } + } +} + +TEST(TensorProtosDBInputTest, TestLoadBatchOne) { + TestMNISTLoad(1); +} + +TEST(TensorProtosDBInputTest, TestLoadBatch64) { + TestMNISTLoad(64); +} + +} // namespace caffe2 diff --git a/caffe2/operators/utility_ops.cc b/caffe2/operators/utility_ops.cc new file mode 100644 index 00000000000..576bfb90489 --- /dev/null +++ b/caffe2/operators/utility_ops.cc @@ -0,0 +1,21 @@ +#include "caffe2/operators/utility_ops.h" + +namespace caffe2 { +namespace { + +REGISTER_CPU_OPERATOR(Free, FreeOp); +REGISTER_CPU_OPERATOR(Print, PrintOp); +REGISTER_CPU_OPERATOR(PrintInt, PrintOp); +REGISTER_CPU_OPERATOR(Flatten, FlattenOp); +REGISTER_CPU_OPERATOR(Alias, AliasOp); +REGISTER_CPU_OPERATOR(ReshapeLike, ReshapeLikeOp); +REGISTER_CPU_OPERATOR(Split, SplitOp); +REGISTER_CPU_OPERATOR(Sum, SumOp); +REGISTER_CPU_OPERATOR(WeightedSum, WeightedSumOp); +REGISTER_CPU_OPERATOR(Copy, CopyOp); + + +} // namespace +} // namespace caffe2 + + diff --git a/caffe2/operators/utility_ops.h b/caffe2/operators/utility_ops.h new file mode 100644 index 00000000000..d6d18827977 --- /dev/null +++ b/caffe2/operators/utility_ops.h @@ -0,0 +1,283 @@ +#ifndef CAFFE2_OPERATORS_UTILITY_OPS_H_ +#define CAFFE2_OPERATORS_UTILITY_OPS_H_ + +#include +#include + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/utils/math.h" +#include "glog/logging.h" + +namespace caffe2 { + +const char kPrintFileExtension[] = ".log"; + +// FreeOp frees the content of the output blob. We allow it to take in input +// blobs purely for the reason that it can "wait" on the input blobs to be +// produced by some of the earlier operators before it is used. +class FreeOp : public OperatorBase { + public: + USE_SIMPLE_BASE_CTOR_DTOR(FreeOp); + + bool Run() final { + for (Blob* output : Outputs()) { + output->Reset(); + } + return true; + } + + INPUT_OUTPUT_STATS(0, INT_MAX, 1, INT_MAX); + DISABLE_COPY_AND_ASSIGN(FreeOp); +}; + +template +class PrintOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + PrintOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), + to_file_(OperatorBase::GetSingleArgument("to_file", 0)), + limit_(OperatorBase::GetSingleArgument("limit", 0)) { + if (limit_ == 0) { + limit_ = INT_MAX; + } + if (to_file_) { + // We will output to file instead of printing on screen. + const string& target_folder = ws->RootFolder(); + // We will write each individual tensor to its individual file. + log_files_.resize(def().inputs_size()); + for (int i = 0; i < def().inputs_size(); ++i) { + log_files_[i].reset(new std::ofstream( + target_folder + "/" + def().inputs(i) + kPrintFileExtension, + std::ofstream::out | std::ofstream::trunc)); + CHECK(log_files_[i]->good()) + << "Failed to open PrintOp file for tensor " << def().inputs(i) + << ". 
rdstate() = " << log_files_[i]->rdstate(); + } + } + } + + ~PrintOp() { + for (auto& log_file : log_files_) { + log_file->close(); + } + } + + bool RunOnDevice() final { + Tensor temp_tensor; + for (int input_id = 0; input_id < InputSize(); ++input_id) { + auto& input = Input(input_id); + DCHECK_GT(input.size(), 0); + temp_tensor.ReshapeLike(input); + device_context_.template Copy( + temp_tensor.mutable_data(), input.data(), input.size()); + std::stringstream dims_stream; + for (const int dim : input.dims()) { + dims_stream << dim << ","; + } + std::stringstream values_stream; + int total_count = std::min(temp_tensor.size(), limit_); + const dtype* temp_tensor_data = temp_tensor.data(); + for (int i = 0; i < total_count - 1; ++i) { + values_stream << temp_tensor_data[i] << ","; + } + // We do not add a comma after the last item. + values_stream << temp_tensor_data[total_count - 1]; + if (to_file_) { + // Also log to file. + auto& stream = *log_files_[input_id]; + stream << values_stream.str(); + stream << std::endl; + } else { + // Log to console. + LOG(INFO) << "Tensor " << def().inputs(input_id) + << " (" << dims_stream.str() << "): " << values_stream.str(); + } + } + return true; + } + + private: + bool to_file_; + int limit_; + vector > log_files_; + INPUT_OUTPUT_STATS(1, INT_MAX, 0, 0); + DISABLE_COPY_AND_ASSIGN(PrintOp); +}; + +template +class AliasOp final : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(AliasOp); + + bool RunOnDevice() final { + auto& input = Input(0); + DCHECK_GT(input.size(), 0); + if (Output(0) == &input) { + // If one calls an AliasOp but in fact it is in-place (input and output + // are the same tensor), we will simply skip. + return true; + } else { + Output(0)->ReshapeLike(input); + Output(0)->ShareData(input); + } + return true; + } + + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(AliasOp); +}; + +template +class FlattenOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(FlattenOp); + + bool RunOnDevice() final { + auto& input = Input(0); + DCHECK_GT(input.size(), 0); + Output(0)->Reshape( + std::vector{input.dim(0), input.size() / input.dim(0)}); + Output(0)->ShareData(input); + return true; + } + + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(FlattenOp); +}; + +// Output shares the data of input(0), but reshapes it like input(1). 
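// For example, with input(0) of shape {6, 4} and input(1) of shape {2, 3, 4},
// the output aliases input(0)'s 24 elements but reports shape {2, 3, 4}; the
// DCHECK below requires the two sizes to agree.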
+template +class ReshapeLikeOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(ReshapeLikeOp); + + bool RunOnDevice() final { + auto* output = Output(0); + DCHECK_EQ(Input(0).size(), Input(1).size()); + output->ReshapeLike(Input(1)); + output->ShareData(Input(0)); + return true; + } + + INPUT_OUTPUT_STATS(2, 2, 1, 1); + DISABLE_COPY_AND_ASSIGN(ReshapeLikeOp); +}; + +template +class SplitOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(SplitOp); + + bool RunOnDevice() final { + const auto& input = Input(0); + for (int i = 0; i < OutputSize(); ++i) { + auto* output = Output(i); + output->ReshapeLike(input); + output->ShareData(input); + } + return true; + } + + INPUT_OUTPUT_STATS(1, 1, 1, INT_MAX); + DISABLE_COPY_AND_ASSIGN(SplitOp); +}; + +template +class SumOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(SumOp); + + bool RunOnDevice() final { + auto& input = Input(0); + auto* output = Output(0); + output->ReshapeLike(input); + device_context_.template Copy( + output->mutable_data(), input.data(), input.size()); + for (int i = 1; i < InputSize(); ++i) { + math::Add(output->size(), output->data(), Input(i).data(), + output->mutable_data(), &device_context_); + } + return true; + } + + INPUT_OUTPUT_STATS(1, INT_MAX, 1, 1); + DISABLE_COPY_AND_ASSIGN(SumOp); +}; + +// WeightedSumOp computes the weighted sum of several tensors. The input should +// be in the form X_0, weight_0, X_1, weight_1, ... where X_i all have the same +// shape, and weight_i are size 1 tensors that specifies the weight of each +// vector. Note that if one wants to do in-place computation, it could only be +// done with X_0 also as the output, but not other X_i. +template +class WeightedSumOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(WeightedSumOp); + + bool RunOnDevice() final { + DCHECK_EQ(InputSize() % 2, 0); + auto& X0 = Input(0); + auto& weight0 = Input(1); + DCHECK_GT(X0.size(), 0); + DCHECK_EQ(weight0.size(), 1); + int size = X0.size(); + auto* output = Output(0); + output->ReshapeLike(X0); + math::Scale( + size, weight0.data(), X0.data(), output->mutable_data(), + &device_context_); + for (int i = 2; i < InputSize(); i += 2) { + auto& X = Input(i); + // Do a check: if the input is the same as output, we have a problem - + // in-place update should always only happen with the zeroth input. + if (&X == output) { + LOG(ERROR) << "Input #" << i << " is the same as output. 
" + << "If you want to do in-place updates, put the output as " + << "input #0."; + return false; + } + auto& weight = Input(i + 1); + DCHECK_EQ(X.size(), size); + DCHECK_EQ(weight.size(), 1); + math::Axpy( + size, weight.data(), X.data(), output->mutable_data(), + &device_context_); + } + return true; + } + + INPUT_OUTPUT_STATS(2, INT_MAX, 1, 1); + DISABLE_COPY_AND_ASSIGN(WeightedSumOp); +}; + +template +class CopyOp : public Operator { + public: + USE_OPERATOR_BASE_FUNCTIONS; + USE_SIMPLE_CTOR_DTOR(CopyOp); + + bool RunOnDevice() final { + auto& input = OperatorBase::Input >(0); + auto* output = OperatorBase::Output >(0); + output->ReshapeLike(input); + this->device_context_.template Copy( + output->mutable_data(), input.data(), input.size()); + return true; + } + + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(CopyOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_OPERATORS_UTILITY_OPS_H_ diff --git a/caffe2/operators/utility_ops_gpu.cc b/caffe2/operators/utility_ops_gpu.cc new file mode 100644 index 00000000000..b92264f18fa --- /dev/null +++ b/caffe2/operators/utility_ops_gpu.cc @@ -0,0 +1,24 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/operators/utility_ops.h" + +namespace caffe2 { +namespace { + +REGISTER_CUDA_OPERATOR(Free, FreeOp); +REGISTER_CUDA_OPERATOR(Print, PrintOp); +REGISTER_CUDA_OPERATOR(PrintInt, PrintOp); +REGISTER_CUDA_OPERATOR(Flatten, FlattenOp); +REGISTER_CUDA_OPERATOR(Alias, FlattenOp); +REGISTER_CUDA_OPERATOR(ReshapeLike, ReshapeLikeOp); +REGISTER_CUDA_OPERATOR(Split, SplitOp); +REGISTER_CUDA_OPERATOR(Sum, SumOp); +REGISTER_CUDA_OPERATOR(WeightedSum, WeightedSumOp); +REGISTER_CUDA_OPERATOR(CopyGPUToCPU, + CopyOp); +REGISTER_CUDA_OPERATOR(CopyCPUToGPU, + CopyOp); + +} // namespace +} // namespace caffe2 + + diff --git a/caffe2/proto/BREW b/caffe2/proto/BREW new file mode 100644 index 00000000000..b8a4a6e263f --- /dev/null +++ b/caffe2/proto/BREW @@ -0,0 +1,18 @@ +# Build file for the caffe2 protocol buffers. + +proto_library( + name = 'caffe2_proto', + srcs = Glob(['*.proto']), + deps = [ + "//third_party/google:protobuf", + ] +) + +filegroup( + name = "caffe2_proto_py", + srcs = ["__init__.py"], + deps = [ + "//caffe2:caffe2_python", + ] +) + diff --git a/caffe2/proto/__init__.py b/caffe2/proto/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/caffe2/proto/caffe2.proto b/caffe2/proto/caffe2.proto new file mode 100644 index 00000000000..7d6f20fa226 --- /dev/null +++ b/caffe2/proto/caffe2.proto @@ -0,0 +1,117 @@ +syntax = "proto2"; + +package caffe2; + +// option optimize_for = LITE_RUNTIME; + +message TensorProto { + // The dimensions in the tensor. + repeated int32 dims = 1; + enum DataType { + FLOAT = 1; + INT32 = 2; + BYTE = 3; + STRING = 4; + } + optional DataType data_type = 2 [default = FLOAT]; + repeated float float_data = 3 [packed = true]; + repeated int32 int32_data = 4 [packed = true]; + optional bytes byte_data = 5; + repeated bytes string_data = 6; + optional string name = 7; +} + +message TensorProtos { + repeated TensorProto protos = 1; +} + +message Argument { + optional string name = 1; + optional float f = 2; + optional int32 i = 3; + optional string s = 4; + repeated float floats = 5; + repeated int32 ints = 6; + repeated string strings = 7; +} + +enum DeviceType { + CPU = 0; // In default, we will use CPU. + CUDA = 1; // CUDA, with custom kernels. +} + +message DeviceOption { + // Options that need to be carried out before running the execution. 
+ optional DeviceType device_type = 1 [ default = CPU ]; + // the cuda gpu id. If the device is not CUDA, this field will simply be + // ignored. + optional int32 cuda_gpu_id = 2; + // The random seed to start the device random number generator with. + optional uint32 random_seed = 3; +} + +message OperatorDef { + repeated string inputs = 1; // the name of the input blobs + repeated string outputs = 2; // the name of output top blobs + optional string name = 3; // the layer name + optional string type = 4; // the layer type + + repeated Argument args = 5; + + optional DeviceOption device_option = 6; + + // For most networks, don't do extensions. Instead, pack the parameters into + // the three categories listed above, and document them clearly in the source + // code. + extensions 1000 to max; +} + +message NetDef { + optional string name = 1; // the network's name + repeated OperatorDef operators = 2; // a bunch of operators. + + // net_type and net_args are implementation-specific parameters that we want + // to pass to specialized implementations. If you do not care about this, you + // don't need to set them. + optional string net_type = 3; // the type of network that we run this with. + // the number of workers, if the operators in the network is to be carried out + // in parallel. + optional int32 num_workers = 4; + // The device option for the network. If a network has a specific device + // option and one of its operators does not have it set, we will copy over the + // device option to the operator. This allows us to basically avoid putting + // device options at every operator. + optional DeviceOption device_option = 5; +} + +// ExecutionStep is actually a sort-of-hacky way we simulate iteration right +// now. +message ExecutionStep { + // ExecutionStep should either contain a set of substeps, or a set of + // network names to run in this execution step. They should NOT both be set + // at the same time. + optional string name = 1; + repeated ExecutionStep substeps = 2; + repeated string networks = 3; + optional int32 iterations = 4; +} + +message PlanDef { + // All the networks that are used in this execution. Note that networks should + // be orderd in the way they are executed, i.e. for a layer in a network, all + // its input blobs should already have been initialized by the layers or + // networks defined before it. + optional string name = 1; + repeated NetDef networks = 2; + repeated ExecutionStep execution_steps = 3; +} + +// ClientDef is a model we use to ship a pre-trained model. This contains two +// parts basically: one set of parameters, and one network. 
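// Presumably init_net is run once to fill in the parameter blobs, after which
// main_net is run repeatedly, reading the blob named by `input` and writing
// the blob named by `output`. A hypothetical instance (names invented for
// illustration):
//   name: "my_classifier"
//   init_net { name: "init" ... }
//   main_net { name: "main" ... }
//   input: "data"
//   output: "softmax"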
+message SimpleClientDef { + optional string name = 1; + optional NetDef init_net = 2; + optional NetDef main_net = 3; + optional string input = 4; + optional string output = 5; +} \ No newline at end of file diff --git a/caffe2/proto/caffe2_legacy.proto b/caffe2/proto/caffe2_legacy.proto new file mode 100644 index 00000000000..146fc8ae5c1 --- /dev/null +++ b/caffe2/proto/caffe2_legacy.proto @@ -0,0 +1,9 @@ +syntax = "proto2"; + +package caffe2; + +enum LegacyPadding { + NOTSET = 0; + VALID = 1; + SAME = 2; +} \ No newline at end of file diff --git a/caffe2/sgd/BREW b/caffe2/sgd/BREW new file mode 100644 index 00000000000..55f0706f998 --- /dev/null +++ b/caffe2/sgd/BREW @@ -0,0 +1,27 @@ +cc_library( + name = "sgd_ops", + srcs = [ + "iter_op.cc", + "learning_rate_op.cc", + ], + hdrs = [ + "learning_rate_functors.h", + "learning_rate_op.h" + ], + deps = [ + "//caffe2/core:core", + ], + whole_archive = True, +) + +cuda_library( + name = "sgd_ops_gpu", + srcs = [ + "learning_rate_op.cu", + ], + deps = [ + ":sgd_ops", + "//caffe2/core:core_gpu", + ], + whole_archive = True, +) \ No newline at end of file diff --git a/caffe2/sgd/iter_op.cc b/caffe2/sgd/iter_op.cc new file mode 100644 index 00000000000..24cdeaa0f20 --- /dev/null +++ b/caffe2/sgd/iter_op.cc @@ -0,0 +1,30 @@ +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +// IterOp runs an iteration counter. I cannot think of a case where we would +// need to access the iter variable on device, so this will always produce an +// int value as its output. +class IterOp final : public OperatorBase { + public: + IterOp(const OperatorDef& operator_def, Workspace* ws) + : OperatorBase(operator_def, ws), iter_(-1) {} + + bool Run() override { + iter_++; + *OperatorBase::Output(0) = iter_; + return true; + } + + private: + int iter_; + INPUT_OUTPUT_STATS(0, 0, 1, 1); + DISABLE_COPY_AND_ASSIGN(IterOp); +}; + +namespace { +REGISTER_CPU_OPERATOR(Iter, IterOp) +REGISTER_CUDA_OPERATOR(Iter, IterOp) +} +} // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_functors.h b/caffe2/sgd/learning_rate_functors.h new file mode 100644 index 00000000000..917e5f68549 --- /dev/null +++ b/caffe2/sgd/learning_rate_functors.h @@ -0,0 +1,63 @@ +#ifndef CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ +#define CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ + +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" + +namespace caffe2 { + +template +class LearningRateFunctor { + public: + virtual dtype operator()(const int iter) const = 0; +}; + +// Fixed: not changing the learning rate at all. 
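// (All functors here return a multiplier that LearningRateOp applies on top
// of base_lr; Fixed always returns 1, so the effective rate stays at base_lr.
// As an illustrative example of the step policy below: with base_lr = 0.01,
// gamma = 0.1 and stepsize = 100, iterations 0-99 use 0.01, iterations
// 100-199 use 0.001, and so on.)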
+template +class FixedLearningRate : public LearningRateFunctor { + public: + dtype operator()(const int iter) const override { return 1.; } +}; + +// Step: return gamma ^ (floor(iter / step)) +template +class StepLearningRate : public LearningRateFunctor { + public: + StepLearningRate(const int stepsize, const dtype gamma) + : stepsize_(stepsize), gamma_(gamma) {} + dtype operator()(const int iter) const override { + return std::pow(gamma_, static_cast(iter / stepsize_)); + } + + int stepsize_; + dtype gamma_; +}; + +// Exp: return gamma ^ iter +template +class ExpLearningRate : public LearningRateFunctor { + public: + explicit ExpLearningRate(const dtype gamma) : gamma_(gamma) {} + dtype operator()(const int iter) const override { + return std::pow(gamma_, static_cast(iter)); + } + + dtype gamma_; +}; + +// Inv: return (1 + gamma * iter) ^ (-power) +template +class InvLearningRate : public LearningRateFunctor { + public: + InvLearningRate(const dtype gamma, const dtype power) + : gamma_(gamma), power_(power) {} + dtype operator()(const int iter) const override { + return std::pow(dtype(1) + gamma_ * iter, -power_); + } + dtype gamma_; + dtype power_; +}; + +} // namespace caffe2 + +#endif // CAFFE2_SGD_LEARNING_RATE_FUNCTORS_H_ diff --git a/caffe2/sgd/learning_rate_op.cc b/caffe2/sgd/learning_rate_op.cc new file mode 100644 index 00000000000..2be60f4867d --- /dev/null +++ b/caffe2/sgd/learning_rate_op.cc @@ -0,0 +1,7 @@ +#include "caffe2/sgd/learning_rate_op.h" + +namespace caffe2 { +namespace { +REGISTER_CPU_OPERATOR(LearningRate, LearningRateOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_op.cu b/caffe2/sgd/learning_rate_op.cu new file mode 100644 index 00000000000..49461a7674d --- /dev/null +++ b/caffe2/sgd/learning_rate_op.cu @@ -0,0 +1,8 @@ +#include "caffe2/core/context_gpu.h" +#include "caffe2/sgd/learning_rate_op.h" + +namespace caffe2 { +namespace { +REGISTER_CUDA_OPERATOR(LearningRate, LearningRateOp); +} // namespace +} // namespace caffe2 diff --git a/caffe2/sgd/learning_rate_op.h b/caffe2/sgd/learning_rate_op.h new file mode 100644 index 00000000000..0a89dacdd6b --- /dev/null +++ b/caffe2/sgd/learning_rate_op.h @@ -0,0 +1,68 @@ +#ifndef CAFFE2_SGD_LEARNING_RATE_OP_H_ +#define CAFFE2_SGD_LEARNING_RATE_OP_H_ + +#include +#include +#include "caffe2/core/context.h" +#include "caffe2/core/operator.h" +#include "caffe2/sgd/learning_rate_functors.h" + +namespace caffe2 { + +template +class LearningRateOp final : public Operator { + public: + LearningRateOp(const OperatorDef& operator_def, Workspace* ws) + : Operator(operator_def, ws), functor_(nullptr), + base_lr_( + OperatorBase::template GetSingleArgument("base_lr", FLT_MAX)) { + CHECK_NE(base_lr_, FLT_MAX) << "Base learning rate must be set."; + const string policy = OperatorBase::GetSingleArgument("policy", ""); + CHECK(policy.size()) << "Must specify a learning rate policy."; + if (policy == "fixed") { + functor_.reset(new FixedLearningRate()); + } else if (policy == "step") { + int stepsize = + OperatorBase::template GetSingleArgument("stepsize", 0); + dtype gamma = OperatorBase::template GetSingleArgument("gamma", 0); + DCHECK_GT(stepsize, 0); + DCHECK_GT(gamma, 0); + functor_.reset(new StepLearningRate(stepsize, gamma)); + } else if (policy == "exp") { + dtype gamma = OperatorBase::template GetSingleArgument("gamma", 0); + DCHECK_GT(gamma, 0); + functor_.reset(new ExpLearningRate(gamma)); + } else if (policy == "inv") { + dtype gamma = OperatorBase::template GetSingleArgument("gamma", 0); + 
dtype power = OperatorBase::template GetSingleArgument("power", 0); + DCHECK_GT(gamma, 0); + DCHECK_GT(power, 0); + functor_.reset(new InvLearningRate(gamma, power)); + } else { + LOG(FATAL) << "Unknown learning rate policy: " << policy; + } + } + USE_OPERATOR_BASE_FUNCTIONS; + + bool RunOnDevice() override { + int iter = OperatorBase::Input(0); + dtype learning_rate = base_lr_ * (*functor_)(iter); + // Write to output. + auto* output = Output(0); + output->Reshape(std::vector{1}); + device_context_.template Copy( + Output(0)->mutable_data(), &learning_rate, 1); + return true; + } + + private: + unique_ptr > functor_; + dtype base_lr_; + + INPUT_OUTPUT_STATS(1, 1, 1, 1); + DISABLE_COPY_AND_ASSIGN(LearningRateOp); +}; + +} // namespace caffe2 + +#endif // CAFFE2_SGD_LEARNING_RATE_OP_H_ diff --git a/caffe2/utils/BREW b/caffe2/utils/BREW new file mode 100644 index 00000000000..8a5bf7e6913 --- /dev/null +++ b/caffe2/utils/BREW @@ -0,0 +1,70 @@ +cc_library( + name = "math", + srcs = [ + "math_cpu.cc", + ], + hdrs = [ + "cblas.h", + "math.h", + "mkl_alternate.h", + ], + cflags = [ "-DEIGEN_NO_DEBUG", ], + deps = [ + "//third_party/eigen3:eigen", + "//caffe2/core:core", + ], +) + +cuda_library( + name = "math_gpu", + srcs = [ + "math_gpu.cu", + ], + deps = [ + ":math", + "//caffe2/core:core_gpu", + ], +) + +cc_library( + name = "proto_utils", + srcs = ["proto_utils.cc"], + hdrs = [ + "proto_utils.h", + ], + deps = [ + "//caffe2/proto:caffe2_proto", + "//third_party/glog:glog", + ], +) + +cc_test( + name = "math_test", + srcs = [ + "math_test.cc", + ], + deps = [ + ":math", + "//caffe2/proto:caffe2_proto", + "//gtest:gtest_main", + "//caffe2/core:core", + ], +) + +cc_headers( + name = "simple_queue", + srcs = [ + "simple_queue.h" + ], +) + +cc_test( + name = "simple_queue_test", + srcs = [ + "simple_queue_test.cc", + ], + deps = [ + ":simple_queue", + "//gtest:gtest_main", + ], +) diff --git a/caffe2/utils/cblas.h b/caffe2/utils/cblas.h new file mode 100644 index 00000000000..bb37a0d3684 --- /dev/null +++ b/caffe2/utils/cblas.h @@ -0,0 +1,600 @@ +// This is the exact cblas.h header file, placed here purely in order to get +// the enums. 
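// As a reminder of what these enums encode: a call such as
//   cblas_sgemv(CblasRowMajor, CblasNoTrans, M, N, alpha, A, lda, x, 1,
//               beta, y, 1);
// computes y = alpha * A * x + beta * y for an M x N row-major matrix A with
// leading dimension lda >= N; CblasTrans would use A^T instead. These
// CBLAS_TRANSPOSE values are the ones passed to the math::Gemm / math::Gemv
// wrappers used by the operators above.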
+ +#ifndef CBLAS_H + +#ifndef CBLAS_ENUM_DEFINED_H + #define CBLAS_ENUM_DEFINED_H + enum CBLAS_ORDER {CblasRowMajor=101, CblasColMajor=102 }; + enum CBLAS_TRANSPOSE {CblasNoTrans=111, CblasTrans=112, CblasConjTrans=113, + AtlasConj=114}; + enum CBLAS_UPLO {CblasUpper=121, CblasLower=122}; + enum CBLAS_DIAG {CblasNonUnit=131, CblasUnit=132}; + enum CBLAS_SIDE {CblasLeft=141, CblasRight=142}; +#endif + +#ifndef CBLAS_ENUM_ONLY +#define CBLAS_H +#define CBLAS_INDEX int + +int cblas_errprn(int ierr, int info, char *form, ...); +void cblas_xerbla(int p, const char *rout, const char *form, ...); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS functions (complex are recast as routines) + * =========================================================================== + */ +float cblas_sdsdot(const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY); +double cblas_dsdot(const int N, const float *X, const int incX, const float *Y, + const int incY); +float cblas_sdot(const int N, const float *X, const int incX, + const float *Y, const int incY); +double cblas_ddot(const int N, const double *X, const int incX, + const double *Y, const int incY); +/* + * Functions having prefixes Z and C only + */ +void cblas_cdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_cdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + +void cblas_zdotu_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotu); +void cblas_zdotc_sub(const int N, const void *X, const int incX, + const void *Y, const int incY, void *dotc); + + +/* + * Functions having prefixes S D SC DZ + */ +float cblas_snrm2(const int N, const float *X, const int incX); +float cblas_sasum(const int N, const float *X, const int incX); + +double cblas_dnrm2(const int N, const double *X, const int incX); +double cblas_dasum(const int N, const double *X, const int incX); + +float cblas_scnrm2(const int N, const void *X, const int incX); +float cblas_scasum(const int N, const void *X, const int incX); + +double cblas_dznrm2(const int N, const void *X, const int incX); +double cblas_dzasum(const int N, const void *X, const int incX); + + +/* + * Functions having standard 4 prefixes (S D C Z) + */ +CBLAS_INDEX cblas_isamax(const int N, const float *X, const int incX); +CBLAS_INDEX cblas_idamax(const int N, const double *X, const int incX); +CBLAS_INDEX cblas_icamax(const int N, const void *X, const int incX); +CBLAS_INDEX cblas_izamax(const int N, const void *X, const int incX); + +/* + * =========================================================================== + * Prototypes for level 1 BLAS routines + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (s, d, c, z) + */ +void cblas_sswap(const int N, float *X, const int incX, + float *Y, const int incY); +void cblas_scopy(const int N, const float *X, const int incX, + float *Y, const int incY); +void cblas_saxpy(const int N, const float alpha, const float *X, + const int incX, float *Y, const int incY); +void catlas_saxpby(const int N, const float alpha, const float *X, + const int incX, const float beta, float *Y, const int incY); +void catlas_sset + (const int N, const float alpha, float *X, const int incX); + +void cblas_dswap(const int N, double *X, const int incX, + double *Y, const 
int incY); +void cblas_dcopy(const int N, const double *X, const int incX, + double *Y, const int incY); +void cblas_daxpy(const int N, const double alpha, const double *X, + const int incX, double *Y, const int incY); +void catlas_daxpby(const int N, const double alpha, const double *X, + const int incX, const double beta, double *Y, const int incY); +void catlas_dset + (const int N, const double alpha, double *X, const int incX); + +void cblas_cswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_ccopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_caxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_caxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_cset + (const int N, const void *alpha, void *X, const int incX); + +void cblas_zswap(const int N, void *X, const int incX, + void *Y, const int incY); +void cblas_zcopy(const int N, const void *X, const int incX, + void *Y, const int incY); +void cblas_zaxpy(const int N, const void *alpha, const void *X, + const int incX, void *Y, const int incY); +void catlas_zaxpby(const int N, const void *alpha, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void catlas_zset + (const int N, const void *alpha, void *X, const int incX); + + +/* + * Routines with S and D prefix only + */ +void cblas_srotg(float *a, float *b, float *c, float *s); +void cblas_srotmg(float *d1, float *d2, float *b1, const float b2, float *P); +void cblas_srot(const int N, float *X, const int incX, + float *Y, const int incY, const float c, const float s); +void cblas_srotm(const int N, float *X, const int incX, + float *Y, const int incY, const float *P); + +void cblas_drotg(double *a, double *b, double *c, double *s); +void cblas_drotmg(double *d1, double *d2, double *b1, const double b2, double *P); +void cblas_drot(const int N, double *X, const int incX, + double *Y, const int incY, const double c, const double s); +void cblas_drotm(const int N, double *X, const int incX, + double *Y, const int incY, const double *P); + + +/* + * Routines with S D C Z CS and ZD prefixes + */ +void cblas_sscal(const int N, const float alpha, float *X, const int incX); +void cblas_dscal(const int N, const double alpha, double *X, const int incX); +void cblas_cscal(const int N, const void *alpha, void *X, const int incX); +void cblas_zscal(const int N, const void *alpha, void *X, const int incX); +void cblas_csscal(const int N, const float alpha, void *X, const int incX); +void cblas_zdscal(const int N, const double alpha, void *X, const int incX); + +/* + * Extra reference routines provided by ATLAS, but not mandated by the standard + */ +void cblas_crotg(void *a, void *b, void *c, void *s); +void cblas_zrotg(void *a, void *b, void *c, void *s); +void cblas_csrot(const int N, void *X, const int incX, void *Y, const int incY, + const float c, const float s); +void cblas_zdrot(const int N, void *X, const int incX, void *Y, const int incY, + const double c, const double s); + +/* + * =========================================================================== + * Prototypes for level 2 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const float alpha, 
const float *A, const int lda, + const float *X, const int incX, const float beta, + float *Y, const int incY); +void cblas_sgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const float alpha, + const float *A, const int lda, const float *X, + const int incX, const float beta, float *Y, const int incY); +void cblas_strmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, + float *X, const int incX); +void cblas_stbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); +void cblas_strsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *A, const int lda, float *X, + const int incX); +void cblas_stbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const float *A, const int lda, + float *X, const int incX); +void cblas_stpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const float *Ap, float *X, const int incX); + +void cblas_dgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *X, const int incX, const double beta, + double *Y, const int incY); +void cblas_dgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const double alpha, + const double *A, const int lda, const double *X, + const int incX, const double beta, double *Y, const int incY); +void cblas_dtrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, + double *X, const int incX); +void cblas_dtbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); +void cblas_dtrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *A, const int lda, double *X, + const int incX); +void cblas_dtbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const double *A, const int lda, + double *X, const int incX); +void cblas_dtpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const double *Ap, double *X, const int incX); + +void cblas_cgemv(const enum CBLAS_ORDER Order, 
+ const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_cgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ctrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ctbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ctrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ctbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ctpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + +void cblas_zgemv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *X, const int incX, const void *beta, + void *Y, const int incY); +void cblas_zgbmv(const enum CBLAS_ORDER Order, + const enum CBLAS_TRANSPOSE TransA, const int M, const int N, + const int KL, const int KU, const void *alpha, + const void *A, const int lda, const void *X, + const int incX, const void *beta, void *Y, const int incY); +void cblas_ztrmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, + void *X, const int incX); +void cblas_ztbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); +void cblas_ztrsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *A, const int lda, void *X, + const int incX); +void cblas_ztbsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const int K, const void *A, const int lda, + void *X, const int incX); +void cblas_ztpsv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE TransA, const enum CBLAS_DIAG Diag, + const int N, const void *Ap, void *X, const int incX); + + +/* + * Routines 
with S and D prefixes only + */ +void cblas_ssymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_ssbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const float alpha, const float *A, + const int lda, const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *Ap, + const float *X, const int incX, + const float beta, float *Y, const int incY); +void cblas_sger(const enum CBLAS_ORDER Order, const int M, const int N, + const float alpha, const float *X, const int incX, + const float *Y, const int incY, float *A, const int lda); +void cblas_ssyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *A, const int lda); +void cblas_sspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, float *Ap); +void cblas_ssyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A, + const int lda); +void cblas_sspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const float *X, + const int incX, const float *Y, const int incY, float *A); + +void cblas_dsymv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dsbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const double alpha, const double *A, + const int lda, const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dspmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *Ap, + const double *X, const int incX, + const double beta, double *Y, const int incY); +void cblas_dger(const enum CBLAS_ORDER Order, const int M, const int N, + const double alpha, const double *X, const int incX, + const double *Y, const int incY, double *A, const int lda); +void cblas_dsyr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *A, const int lda); +void cblas_dspr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, double *Ap); +void cblas_dsyr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A, + const int lda); +void cblas_dspr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const double *X, + const int incX, const double *Y, const int incY, double *A); + + +/* + * Routines with C and Z prefixes only + */ +void cblas_chemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const 
int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_chpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_cgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_cher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_chpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const float alpha, const void *X, + const int incX, void *A); +void cblas_cher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_chpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +void cblas_zhemv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhbmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const int K, const void *alpha, const void *A, + const int lda, const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zhpmv(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const void *alpha, const void *Ap, + const void *X, const int incX, + const void *beta, void *Y, const int incY); +void cblas_zgeru(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zgerc(const enum CBLAS_ORDER Order, const int M, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zher(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, const int incX, + void *A, const int lda); +void cblas_zhpr(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const int N, const double alpha, const void *X, + const int incX, void *A); +void cblas_zher2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *A, const int lda); +void cblas_zhpr2(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, const int N, + const void *alpha, const void *X, const int incX, + const void *Y, const int incY, void *Ap); + +/* + * =========================================================================== + * Prototypes for level 3 BLAS + * =========================================================================== + */ + +/* + * Routines with standard 4 prefixes (S, D, C, Z) + */ +void cblas_sgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, 
const int N, + const int K, const float alpha, const float *A, + const int lda, const float *B, const int ldb, + const float beta, float *C, const int ldc); +void cblas_ssymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_ssyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float beta, float *C, const int ldc); +void cblas_ssyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const float *A, const int lda, + const float *B, const int ldb, const float beta, + float *C, const int ldc); +void cblas_strmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); +void cblas_strsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const float alpha, const float *A, const int lda, + float *B, const int ldb); + +void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const double alpha, const double *A, + const int lda, const double *B, const int ldb, + const double beta, double *C, const int ldc); +void cblas_dsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double beta, double *C, const int ldc); +void cblas_dsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const double *A, const int lda, + const double *B, const int ldb, const double beta, + double *C, const int ldc); +void cblas_dtrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); +void cblas_dtrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const double alpha, const double *A, const int lda, + double *B, const int ldb); + +void cblas_cgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_csymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void 
*alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_csyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_csyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ctrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ctrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + +void cblas_zgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_TRANSPOSE TransB, const int M, const int N, + const int K, const void *alpha, const void *A, + const int lda, const void *B, const int ldb, + const void *beta, void *C, const int ldc); +void cblas_zsymm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zsyrk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *beta, void *C, const int ldc); +void cblas_zsyr2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_ztrmm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); +void cblas_ztrsm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const enum CBLAS_TRANSPOSE TransA, + const enum CBLAS_DIAG Diag, const int M, const int N, + const void *alpha, const void *A, const int lda, + void *B, const int ldb); + + +/* + * Routines with prefixes C and Z only + */ +void cblas_chemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_cherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const float alpha, const void *A, const int lda, + const float beta, void *C, const int ldc); +void cblas_cher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const float beta, + 
void *C, const int ldc); +void cblas_zhemm(const enum CBLAS_ORDER Order, const enum CBLAS_SIDE Side, + const enum CBLAS_UPLO Uplo, const int M, const int N, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const void *beta, + void *C, const int ldc); +void cblas_zherk(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const double alpha, const void *A, const int lda, + const double beta, void *C, const int ldc); +void cblas_zher2k(const enum CBLAS_ORDER Order, const enum CBLAS_UPLO Uplo, + const enum CBLAS_TRANSPOSE Trans, const int N, const int K, + const void *alpha, const void *A, const int lda, + const void *B, const int ldb, const double beta, + void *C, const int ldc); + +int cblas_errprn(int ierr, int info, char *form, ...); + +#endif /* end #ifdef CBLAS_ENUM_ONLY */ +#endif \ No newline at end of file diff --git a/caffe2/utils/math.h b/caffe2/utils/math.h new file mode 100644 index 00000000000..b5a03e232fd --- /dev/null +++ b/caffe2/utils/math.h @@ -0,0 +1,135 @@ +#ifndef CAFFE2_UTILS_MATH_H_ +#define CAFFE2_UTILS_MATH_H_ +// This is a simple translation from the old Caffe math interfaces. We aim to +// still keep it simple, so all platforms would be able to support it fairly +// easily. + +extern "C" { +#include "caffe2/utils/cblas.h" +} + +#include "caffe2/core/common.h" +#include "caffe2/core/types.h" + +namespace caffe2 { + +namespace math { + +template +void Exp(const int N, const T* x, T* y, DeviceContext* context); +template +void Log(const int N, const T* x, T* y, DeviceContext* context); +template +void Sqr(const int N, const T* x, T* y, DeviceContext* context); + +template +void Powx(const int N, const T* a, const T b, T* y, DeviceContext* context); + + +template +void Add(const int N, const T* a, const T* b, T* y, DeviceContext* context); +template +void Sub(const int N, const T* a, const T* b, T* y, DeviceContext* context); +template +void Mul(const int N, const T* a, const T* b, T* y, DeviceContext* context); +template +void Div(const int N, const T* a, const T* b, T* y, DeviceContext* context); + + +// Compute the row-wise max of a N*D matrix X, and write it to a N +// dimensional vector y. +template +void RowwiseMax(const int N, const int D, const T* x, T* y, + DeviceContext* context); + +// Compute the column-wise max of a N*D matrix X, and write it to a D +// dimensional vector y. +template +void ColwiseMax(const int N, const int D, const T* x, T* y, + DeviceContext* context); + +// AddToRow and AddToCol adds the corresponding row/col vector x to the matrix y +// of shape MxN. +template +void AddToRow(const int M, const int N, const T* x, T* y, + DeviceContext* context); +template +void AddToCol(const int M, const int N, const T* x, T* y, + DeviceContext* context); + +// Decaf gemm provides a simpler interface to the gemm functions, with the +// limitation that the data has to be contiguous in memory. +template +void Gemm(const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const T* alpha, const T* A, + const T* B, const T* beta, T* C, DeviceContext* context); + +// Gemv always takes in a M*N matrix A, and depending on whether we set TransA +// to Trans, the output is: +// CblasNoTrans: x is an N dim vector and y is an M dim vector. +// CblasTrans: x is an M dim vector and y is an N dim vector. 
+template +void Gemv(const CBLAS_TRANSPOSE TransA, const int M, const int N, + const T* alpha, const T* A, const T* x, const T* beta, + T* y, DeviceContext* context); + +template +void Set(const int N, const T alpha, T* X, DeviceContext* context); + +template +void RandUniform(const int n, const T a, const T b, T* r, + DeviceContext* context); + +template +void RandGaussian(const int n, const T mean, const T std, T* r, + DeviceContext* context); + +// Dot matrix of vector a and b, and writes the result to a single value y. +template +void Dot(const int N, const T* a, const T* b, T* y, DeviceContext* context); + +// Sum of vector x, and writes the result to a single value y. +template +void Sum(const int N, const T* x, T* y, DeviceContext* context); + +// Select does index selection of the rows a N*D matrix x, and gives the N +// dimensional vector y that contains the selected data. +template +void Select(const int N, const int D, const T* x, const int* idx, T* y, + DeviceContext* context); + +template +void Scale(const int N, const T* alpha, const T* x, T* y, + DeviceContext* context); + +template +void Axpy(const int N, const T* alpha, const T* x, T* y, + DeviceContext* context); + +template +void Axpby(const int N, const T* alpha, const T* x, const T* b, T* y, + DeviceContext* context); + +template +void Im2col(const T* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, T* data_col, + DeviceContext* context); + +template +void Col2im(const T* data_col, const int channels, + const int height, const int width, const int patch_h, const int patch_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, T* data_im, + DeviceContext* context); + +template +void CopyMatrix(const int M, const int N, const T* A, const int lda, + T* B, const int ldb, DeviceContext* context); + +} // namespace math +} // namespace caffe2 + + +#endif // CAFFE2_UTILS_MATH_H_ diff --git a/caffe2/utils/math_cpu.cc b/caffe2/utils/math_cpu.cc new file mode 100644 index 00000000000..268fc31eea0 --- /dev/null +++ b/caffe2/utils/math_cpu.cc @@ -0,0 +1,430 @@ +// Implementes the math functions for CPU. 
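Before the CPU implementations, a caller-side sketch of the math.h interface above (illustrative only, not taken from the commit). The explicit template arguments are assumed to be <dtype, DeviceContext> in that order, since the angle-bracketed parameters do not survive in this dump; note that alpha and beta are passed by pointer, and that a beta of zero overwrites the output rather than accumulating into possibly uninitialized memory:

#include "caffe2/core/context.h"
#include "caffe2/proto/caffe2.pb.h"
#include "caffe2/utils/math.h"

void MathInterfaceSketch() {
  caffe2::DeviceOption option;
  caffe2::CPUContext context(option);
  namespace math = caffe2::math;
  const int M = 2, N = 3, K = 4;
  float A[M * K], B[K * N], C[M * N];   // row-major, contiguous
  float A2[M * N], x[N], y[M];
  const float kOne = 1.0f, kZero = 0.0f;
  math::Set<float, caffe2::CPUContext>(M * K, 1.0f, A, &context);
  math::Set<float, caffe2::CPUContext>(K * N, 1.0f, B, &context);
  math::Set<float, caffe2::CPUContext>(M * N, 1.0f, A2, &context);
  math::Set<float, caffe2::CPUContext>(N, 1.0f, x, &context);
  // C(MxN) = 1.0 * A(MxK) * B(KxN) + 0.0 * C
  math::Gemm<float, caffe2::CPUContext>(CblasNoTrans, CblasNoTrans, M, N, K,
                                        &kOne, A, B, &kZero, C, &context);
  // y(M) = 1.0 * A2(MxN) * x(N) + 0.0 * y, per the CblasNoTrans convention above.
  math::Gemv<float, caffe2::CPUContext>(CblasNoTrans, M, N, &kOne, A2, x,
                                        &kZero, y, &context);
}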
+#include + +#include "caffe2/utils/math.h" +#include "caffe2/utils/mkl_alternate.h" +#include "caffe2/core/context.h" +#include "eigen3/Eigen/Core" +#include "eigen3/Eigen/Dense" + +namespace { +// Common Eigen types that we will often use +template +using EigenMatrixMap = + Eigen::Map >; +template +using EigenVectorMap = Eigen::Map >; +template +using ConstEigenMatrixMap = + Eigen::Map >; +template +using ConstEigenVectorMap = + Eigen::Map >; +} // namespace + +namespace caffe2 { +namespace math { + +#define DELEGATE_SIMPLE_UNARY_FUNCTION(dtype, Funcname, OriginalFunc) \ +template <> \ +void Funcname( \ + const int N, const dtype* x, dtype* y, \ + CPUContext* context) { \ + OriginalFunc(N, x, y); \ +} +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Exp, vsExp) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Exp, vdExp) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Log, vsLn) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Log, vdLn) +DELEGATE_SIMPLE_UNARY_FUNCTION(float, Sqr, vsSqr) +DELEGATE_SIMPLE_UNARY_FUNCTION(double, Sqr, vdSqr) + +template <> +void Powx( + const int N, const float* a, float b, float* y, CPUContext* context) { + vsPowx(N, a, b, y); +} + +template <> +void Powx( + const int N, const double* a, double b, double* y, CPUContext* context) { + vdPowx(N, a, b, y); +} + + +#define DELEGATE_SIMPLE_BINARY_FUNCTION(dtype, Funcname, OriginalFunc) \ +template <> \ +void Funcname( \ + const int N, const dtype* a, const dtype* b, dtype* y, \ + CPUContext* context) { \ + OriginalFunc(N, a, b, y); \ +} + +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Add, vsAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Add, vdAdd) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Sub, vsSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Sub, vdSub) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Mul, vsMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Mul, vdMul) +DELEGATE_SIMPLE_BINARY_FUNCTION(float, Div, vsDiv) +DELEGATE_SIMPLE_BINARY_FUNCTION(double, Div, vdDiv) +#undef DELEGATE_SIMPLE_BINARY_FUNCTION + +#define CAFFE2_SPECIALIZED_ROWWISEMAX(T) \ +template <> void RowwiseMax( \ + const int N, const int D, const T* x, T* y, CPUContext* context) { \ + EigenVectorMap(y, N) = \ + ConstEigenMatrixMap(x, D, N).colwise().maxCoeff(); \ +} +CAFFE2_SPECIALIZED_ROWWISEMAX(float) + +#define CAFFE2_SPECIALIZED_COLWISEMAX(T) \ +template <> void ColwiseMax( \ + const int N, const int D, const T* x, T* y, CPUContext* context) { \ + EigenVectorMap(y, D) = \ + ConstEigenMatrixMap(x, D, N).rowwise().maxCoeff(); \ +} +CAFFE2_SPECIALIZED_COLWISEMAX(float) + +// AddToRow and AddToCol adds the corresponding row/col vector x to the matrix y +// of shape M x N. The actual implementation uses eigen which is column major, +// so notice the row/column swap in the actual implementation. +template <> +void AddToRow( + const int M, const int N, const float* x, float* y, CPUContext* context) { + EigenMatrixMap(y, N, M).colwise() += ConstEigenVectorMap(x, N); +} +template <> +void AddToCol( + const int M, const int N, const float* x, float* y, CPUContext* context) { + EigenMatrixMap(y, N, M).rowwise() += + ConstEigenVectorMap(x, M).transpose(); +} + + +// Caffe2 gemm provides a simpler interface to the gemm functions, with the +// limitation that the data has to be contiguous in memory. 
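A note on the Eigen-based Gemm that follows (the sketch below is illustrative, not lifted from the commit): the maps above are column-major, so a row-major M x K buffer mapped as a K x M Eigen matrix reads as A^T. The implementation therefore computes C^T = B^T * A^T into a map of C, which yields row-major C = A * B without any explicit transpose or copy. Stripped of the alpha/beta handling, the core of the CblasNoTrans/CblasNoTrans case is:

#include "eigen3/Eigen/Core"

// A, B, C are row-major and contiguous: A is M x K, B is K x N, C is M x N.
void RowMajorGemmCore(const float* A, const float* B, float* C,
                      int M, int N, int K) {
  Eigen::Map<const Eigen::MatrixXf> At(A, K, M);  // column-major view = A^T
  Eigen::Map<const Eigen::MatrixXf> Bt(B, N, K);  // column-major view = B^T
  Eigen::Map<Eigen::MatrixXf> Ct(C, N, M);        // column-major view = C^T
  Ct.noalias() = Bt * At;                         // C^T = B^T * A^T, i.e. C = A * B
}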
+// A (M*K) * B(K*N) = C(M*N) +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float* alpha, const float* A, + const float* B, const float* beta, float* C, CPUContext* context) { + auto C_mat = EigenMatrixMap(C, N, M); + if (*beta == 0) { + C_mat.setZero(); + } else { + C_mat *= (*beta); + } + switch (TransA) { + case CblasNoTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += (*alpha) * ( + ConstEigenMatrixMap(B, N, K) * + ConstEigenMatrixMap(A, K, M)); + return; + case CblasTrans: + C_mat.noalias() += (*alpha) * ( + ConstEigenMatrixMap(B, K, N).transpose() * + ConstEigenMatrixMap(A, K, M)); + return; + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + } + } + case CblasTrans: { + switch (TransB) { + case CblasNoTrans: + C_mat.noalias() += (*alpha) * ( + ConstEigenMatrixMap(B, N, K) * + ConstEigenMatrixMap(A, M, K).transpose()); + return; + case CblasTrans: + C_mat.noalias() += (*alpha) * ( + ConstEigenMatrixMap(B, K, N).transpose() * + ConstEigenMatrixMap(A, M, K).transpose()); + return; + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransB"; + } + } + default: + LOG(FATAL) << "Unexpected CBLAS_TRANSPOSE for TransA"; + } +} + +template <> +void Gemv( + const CBLAS_TRANSPOSE TransA, const int M, const int N, const float* alpha, + const float* A, const float* x, const float* beta, float* y, + CPUContext* context) { + EigenVectorMap y_vec(y, TransA == CblasNoTrans ? M : N); + if (*beta == 0) { + // In Caffe2 we often do a lazy initialization, which may contain NaNs in + // the float values. As a result, if beta is 0, we explicitly do a setzero. + y_vec.setZero(); + } else { + y_vec *= (*beta); + } + switch (TransA) { + case CblasNoTrans: { + y_vec.noalias() += (*alpha) * ( + ConstEigenMatrixMap(A, N, M).transpose() * + ConstEigenVectorMap(x, N)); + return; + } + case CblasTrans: { + y_vec.noalias() += (*alpha) * ( + ConstEigenMatrixMap(A, N, M) * + ConstEigenVectorMap(x, M)); + return; + } + default: + LOG(FATAL) << "Gemv float found an unexpected CBLAS_TRANSPOSE input."; + } +} + +#define CAFFE2_SPECIALIZED_SET(dtype) \ +template <> \ +void Set(const int N, const dtype alpha, dtype *Y, \ + CPUContext* context) { \ + EigenVectorMap(Y, N).setConstant(alpha); \ +} + +CAFFE2_SPECIALIZED_SET(float); +CAFFE2_SPECIALIZED_SET(double); +CAFFE2_SPECIALIZED_SET(int); +#undef CAFFE2_SPECIALIZED_SET + +template <> +void RandUniform( + const int n, const float a, const float b, float* r, + CPUContext* context) { + std::uniform_real_distribution distribution(a, b); + for (int i = 0; i < n; ++i) { + r[i] = distribution(context->RandGenerator()); + } +} + +template <> +void RandGaussian( + const int n, const float mean, const float std, float* r, + CPUContext* context) { + std::normal_distribution distribution(mean, std); + for (int i = 0; i < n; ++i) { + r[i] = distribution(context->RandGenerator()); + } +} + +template<> +void Dot( + const int N, const float* a, const float* b, float* y, + CPUContext* context) { + *y = ConstEigenVectorMap(a, N).dot(ConstEigenVectorMap(b, N)); +} + +template<> +void Dot( + const int N, const double* a, const double* b, double* y, + CPUContext* context) { + *y = ConstEigenVectorMap(a, N).dot(ConstEigenVectorMap(b, N)); +} + +template<> +void Sum( + const int N, const float* x, float* y, + CPUContext* context) { + *y = ConstEigenVectorMap(x, N).sum(); +} + +template<> +void Sum( + const int N, const double* x, double* y, + CPUContext* context) 
{ + *y = ConstEigenVectorMap(x, N).sum(); +} + +template <> +void Select( + const int N, const int D, const float* x, const int* idx, float* y, + CPUContext* context) { + for (int i = 0; i < N; ++i) { + DCHECK_LT(idx[i], D); + y[i] = x[i * D + idx[i]]; + } +} + +template <> +void Scale( + const int n, const float* alpha, const float *x, float* y, + CPUContext* context) { + EigenVectorMap(y, n) = ConstEigenVectorMap(x, n) * (*alpha); +} + +template <> +void Scale( + const int n, const double* alpha, const double *x, double* y, + CPUContext* context) { + EigenVectorMap(y, n) = ConstEigenVectorMap(x, n) * (*alpha); +} + +template <> +void Axpy(const int N, const float* alpha, const float* x, + float* Y, CPUContext* context) { + EigenVectorMap(Y, N) += ConstEigenVectorMap(x, N) * (*alpha); +} + +template <> +void Axpy(const int N, const double* alpha, const double* x, + double* Y, CPUContext* context) { + EigenVectorMap(Y, N) += ConstEigenVectorMap(x, N) * (*alpha); +} + +template <> +void Axpby(const int N, const float* alpha, const float* x, + const float* beta, float* y, + CPUContext* context) { + EigenVectorMap y_vec(y, N); + y_vec = y_vec * (*beta) + ConstEigenVectorMap(x, N) * (*alpha); +} + +template <> +void Axpby(const int N, const double* alpha, + const double* x, const double* beta, double* y, + CPUContext* context) { + EigenVectorMap y_vec(y, N); + y_vec = y_vec * (*beta) + ConstEigenVectorMap(x, N) * (*alpha); +} + +template <> +void Im2col( + const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_col, CPUContext* context) { + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int channels_col = channels * kernel_h * kernel_w; + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % kernel_w; + int h_offset = (c / kernel_w) % kernel_h; + int c_im = c / kernel_h / kernel_w; + for (int h = 0; h < height_col; ++h) { + for (int w = 0; w < width_col; ++w) { + int h_pad = h * stride_h - pad_t + h_offset; + int w_pad = w * stride_w - pad_l + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_col[(c * height_col + h) * width_col + w] = + data_im[(c_im * height + h_pad) * width + w_pad]; + else + data_col[(c * height_col + h) * width_col + w] = 0; + } + } + } +} + +template <> +void Im2col( + const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, const int stride_w, float* data_col, + CPUContext* context) { + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + for (int ih = h_pad; ih < h_pad + kernel_h; ++ih) { + for (int iw = w_pad; iw < w_pad + kernel_w; ++iw) { + if (ih >= 0 && ih < height && iw >= 0 && iw < width) { + memcpy(data_col, data_im + (ih * width + iw) * channels, + sizeof(float) * channels); + } else { + // This should be simply padded with zero. 
+ memset(data_col, 0, sizeof(float) * channels); + } + data_col += channels; + } + } + w_pad += stride_w; + } + h_pad += stride_h; + } +} + +template <> +void Col2im( + const float* data_col, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_im, CPUContext* context) { + Set(height * width * channels, 0, data_im, context); + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int channels_col = channels * kernel_h * kernel_w; + for (int c = 0; c < channels_col; ++c) { + int w_offset = c % kernel_w; + int h_offset = (c / kernel_w) % kernel_h; + int c_im = c / kernel_h / kernel_w; + for (int h = 0; h < height_col; ++h) { + for (int w = 0; w < width_col; ++w) { + int h_pad = h * stride_h - pad_t + h_offset; + int w_pad = w * stride_w - pad_l + w_offset; + if (h_pad >= 0 && h_pad < height && w_pad >= 0 && w_pad < width) + data_im[(c_im * height + h_pad) * width + w_pad] += + data_col[(c * height_col + h) * width_col + w]; + } + } + } +} + +template <> +void Col2im( + const float* data_col, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_im, CPUContext* context) { + Set(height * width * channels, 0, data_im, context); + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int h_pad = -pad_t; + for (int h = 0; h < height_col; ++h) { + int w_pad = -pad_l; + for (int w = 0; w < width_col; ++w) { + float* data_im_patch = data_im + (h_pad * width + w_pad) * channels; + for (int ih = h_pad; ih < h_pad + kernel_h; ++ih) { + for (int iw = w_pad; iw < w_pad + kernel_w; ++iw) { + if (ih >= 0 && ih < height && iw >= 0 && iw < width) { + Add( + channels, data_im_patch, data_col, data_im_patch, context); + } + data_im_patch += channels; + data_col += channels; + } + // Jump over remaining number of channels + data_im_patch += channels * (width - kernel_w); + } + w_pad += stride_w; + } + h_pad += stride_h; + } +} + +template <> +void CopyMatrix( + const int M, const int N, const float* A, const int lda, float* B, + const int ldb, CPUContext* context) { + for (int i = 0; i < M; ++i) { + memcpy(B + ldb * i, A + lda * i, sizeof(float) * N); + } +} + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math_gpu.cu b/caffe2/utils/math_gpu.cu new file mode 100644 index 00000000000..82c5f5281ab --- /dev/null +++ b/caffe2/utils/math_gpu.cu @@ -0,0 +1,576 @@ +// Implementes the math functions for CPU. +#include "caffe2/utils/math.h" +#include "caffe2/core/context_gpu.h" + +namespace caffe2 { +namespace math { + +// TODO(Yangqing): Yuck again. Maybe change it to templated functors? 
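Roughly what DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) below expands to (a hand-expanded sketch, not part of the commit). The launch configuration between <<< and >>> does not survive intact in this dump, so the grid/block arguments here, CAFFE_GET_BLOCKS and CAFFE_CUDA_NUM_THREADS, are assumed to come from caffe2/core/context_gpu.h, and the <dtype, CUDAContext> template arguments are likewise assumed:

#include "caffe2/core/context_gpu.h"
#include "caffe2/utils/math.h"

namespace caffe2 {
namespace math {

__global__ void _Kernel_float_Exp(const int N, const float* x, float* y) {
  CUDA_1D_KERNEL_LOOP(i, N) {
    y[i] = expf(x[i]);
  }
}

template <>
void Exp<float, CUDAContext>(const int N, const float* x, float* y,
                             CUDAContext* context) {
  // One thread per element, strided over the grid by CUDA_1D_KERNEL_LOOP.
  _Kernel_float_Exp<<<CAFFE_GET_BLOCKS(N), CAFFE_CUDA_NUM_THREADS, 0,
                      context->cuda_stream()>>>(N, x, y);
}

}  // namespace math
}  // namespace caffe2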
+#define DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(dtype, Funcname, function) \ +__global__ \ +void _Kernel_##dtype##_##Funcname(const int N, const dtype* x, dtype* y) { \ + CUDA_1D_KERNEL_LOOP(i, N) { \ + y[i] = function(x[i]); \ + } \ +} \ +template <> \ +void Funcname( \ + const int N, const dtype* x, dtype* y, \ + CUDAContext* context) { \ + _Kernel_##dtype##_##Funcname<<cuda_stream()>>>( \ + N, x, y); \ +} + +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Exp, expf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Exp, exp) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Log, logf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Log, log) + +__device__ float cuda_sqrf(const float x) { return x * x; } +__device__ double cuda_sqr(const double x) { return x * x; } + +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(float, Sqr, cuda_sqrf) +DELEGATE_SIMPLE_CUDA_UNARY_FUNCTION(double, Sqr, cuda_sqr) + +#define DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(dtype, Funcname, function) \ +__global__ \ +void _Kernel_##dtype##_##Funcname( \ + const int N, const dtype* a, const dtype* b, dtype* y) { \ + CUDA_1D_KERNEL_LOOP(i, N) { \ + y[i] = function(a[i], b[i]); \ + } \ +} \ +template <> \ +void Funcname( \ + const int N, const dtype* a, const dtype* b, dtype* y, \ + CUDAContext* context) { \ + _Kernel_##dtype##_##Funcname<<cuda_stream()>>>( \ + N, a, b, y); \ +} + + +#define CAFFE_MATH_CUDA_ADD(x, y) (x + y) +#define CAFFE_MATH_CUDA_SUB(x, y) (x - y) +#define CAFFE_MATH_CUDA_MUL(x, y) (x * y) +#define CAFFE_MATH_CUDA_DIV(x, y) (x / y) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Add, CAFFE_MATH_CUDA_ADD) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Add, CAFFE_MATH_CUDA_ADD) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Sub, CAFFE_MATH_CUDA_SUB) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Sub, CAFFE_MATH_CUDA_SUB) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Mul, CAFFE_MATH_CUDA_MUL) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Mul, CAFFE_MATH_CUDA_MUL) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(float, Div, CAFFE_MATH_CUDA_DIV) +DELEGATE_SIMPLE_CUDA_BINARY_FUNCTION(double, Div, CAFFE_MATH_CUDA_DIV) + + +/* +#define CAFFE2_SPECIALIZED_ROWWISEMAX(T) \ +template <> \ +void RowwiseMax( \ + const int N, const int D, const T* x, T* y, CPUContext* context) { \ + for (int i = 0; i < N; ++i) { \ + y[i] = x[i*D]; \ + for (int j = 1; j < D; ++j) { \ + y[i] = std::max(y[i], x[i * D + j]); \ + } \ + } \ +} +CAFFE2_SPECIALIZED_ROWWISEMAX(float) + +#define CAFFE2_SPECIALIZED_COLWISEMAX(T) \ +template <> \ +void ColwiseMax( \ + const int N, const int D, const T* x, T* y, CPUContext* context) { \ + memcpy(y, x, sizeof(T) * D); \ + for (int i = 1; i < N; ++i) { \ + for (int j = 0; j < D; ++j) { \ + y[j] = std::max(y[j], x[i * D + j]); \ + } \ + } \ +} +CAFFE2_SPECIALIZED_COLWISEMAX(float) +*/ + +namespace { +template +__global__ void AddToRowKernel(const int M, const int N, const dtype* x, + dtype* y) { + CUDA_1D_KERNEL_LOOP(i, M * N) { + y[i] += x[i % N]; + } +} +template +__global__ void AddToColKernel(const int M, const int N, const dtype* x, + dtype* y) { + CUDA_1D_KERNEL_LOOP(i, M * N) { + y[i] += x[i % M]; + } +} +} // namespace + +template <> +void AddToRow( + const int M, const int N, const float* x, float* y, CUDAContext* context) { + AddToRowKernel<<cuda_stream()>>>(M, N, x, y); +} +template <> +void AddToCol( + const int M, const int N, const float* x, float* y, CUDAContext* context) { + AddToColKernel<<cuda_stream()>>>(M, N, x, y); +} + +// Caffe2 gemm provides a simpler interface to the gemm functions, with the +// limitation that the data has 
to be contiguous in memory. +template <> +void Gemm( + const CBLAS_TRANSPOSE TransA, const CBLAS_TRANSPOSE TransB, + const int M, const int N, const int K, const float* alpha, const float* A, + const float* B, const float* beta, float* C, CUDAContext* context) { + // Note that cublas follows fortran order, so the order is different from + // the cblas convention. + int lda = (TransA == CblasNoTrans) ? K : M; + int ldb = (TransB == CblasNoTrans) ? N : K; + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + cublasOperation_t cuTransB = + (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T; + CUBLAS_CHECK(cublasSgemm(context->cublas_handle(), cuTransB, cuTransA, + N, M, K, alpha, B, ldb, A, lda, beta, C, N)); +} + + +template <> +void Gemv( + const CBLAS_TRANSPOSE TransA, const int M, const int N, const float* alpha, + const float* A, const float* x, const float* beta, float* y, + CUDAContext* context) { + cublasOperation_t cuTransA = + (TransA == CblasNoTrans) ? CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_CHECK(cublasSgemv(context->cublas_handle(), cuTransA, N, M, alpha, + A, N, x, 1, beta, y, 1)); +} + + +namespace { +template +__global__ void SetKernel(const int N, const dtype alpha, dtype* Y) { + CUDA_1D_KERNEL_LOOP(i, N) { + Y[i] = alpha; + } +} +} // namespace + +#define CAFFE2_SPECIALIZED_CUDA_SET(dtype) \ + template <> \ + void Set(const int N, const dtype alpha, dtype *Y, \ + CUDAContext* context) { \ + SetKernel<<cuda_stream()>>>(N, alpha, Y); \ + } + +CAFFE2_SPECIALIZED_CUDA_SET(float); +CAFFE2_SPECIALIZED_CUDA_SET(double); +CAFFE2_SPECIALIZED_CUDA_SET(int); +#undef CAFFE2_SPECIALIZED_CUDA_SET + +namespace { +template +__global__ void UniformShift(const int N, const dtype min, const dtype max, + dtype* x) { + dtype scale = max - min; + CUDA_1D_KERNEL_LOOP(i, N) { + x[i] = x[i] * scale + min; + } +} +} // namespace + +template <> +void RandUniform( + const int n, const float min, const float max, float* r, + CUDAContext* context) { + CURAND_CHECK(curandGenerateUniform(context->curand_generator(), r, n)); + UniformShift<<cuda_stream()>>>(n, min, max, r); +} + +template <> +void RandUniform( + const int n, const double min, const double max, double* r, + CUDAContext* context) { + CURAND_CHECK(curandGenerateUniformDouble(context->curand_generator(), r, n)); + UniformShift<<cuda_stream()>>>(n, min, max, r); +} + +template <> +void RandGaussian( + const int n, const float mean, const float std, float* r, + CUDAContext* context) { + CURAND_CHECK(curandGenerateNormal( + context->curand_generator(), r, n, mean, std)); +} + +template <> +void RandGaussian( + const int n, const double mean, const double std, double* r, + CUDAContext* context) { + CURAND_CHECK(curandGenerateNormalDouble( + context->curand_generator(), r, n, mean, std)); +} + + +template<> +void Dot( + const int n, const float* a, const float* b, float* y, + CUDAContext* context) { + CUBLAS_CHECK(cublasSdot(context->cublas_handle(), n, a, 1, b, 1, y)); +} + +template<> +void Dot( + const int n, const double* a, const double* b, double* y, + CUDAContext* context) { + CUBLAS_CHECK(cublasDdot(context->cublas_handle(), n, a, 1, b, 1, y)); +} + +/* +template<> +void Sum( + const int N, const float* x, float* y, + CPUContext* context) { + *y = 0; + for (int i = 0; i < N; ++i) *y += x[i]; +} + +template<> +void Sum( + const int N, const double* x, double* y, + CPUContext* context) { + *y = 0; + for (int i = 0; i < N; ++i) *y += x[i]; +} +*/ + +namespace { +template +__global__ void SelectKernel( + const 
int N, const int D, const float* x, const int* idx, float* y) { + CUDA_1D_KERNEL_LOOP(i, N) { + y[i] = x[i * D + idx[i]]; + } +} +} // namespace + +template <> +void Select( + const int N, const int D, const float* x, const int* idx, float* y, + CUDAContext* context) { + SelectKernel<<cuda_stream()>>>(N, D, x, idx, y); +} + +namespace { +template +__global__ void ScaleKernel( + const int n, const dtype* alpha, const dtype* x, dtype* y) { + CUDA_1D_KERNEL_LOOP(i, n) { + y[i] = x[i] * (*alpha); + } +} +} // namespace + +template <> +void Scale( + const int n, const float* alpha, const float *x, float* y, + CUDAContext* context) { + ScaleKernel<<cuda_stream()>>>(n, alpha, x, y); +} + +template <> +void Scale( + const int n, const double* alpha, const double *x, double* y, + CUDAContext* context) { + ScaleKernel<<cuda_stream()>>>(n, alpha, x, y); +} + +template <> +void Axpy(const int N, const float* alpha, const float* X, + float* Y, CUDAContext* context) { + CUBLAS_CHECK(cublasSaxpy(context->cublas_handle(), N, alpha, X, 1, Y, 1)); +} + +template <> +void Axpy( + const int N, const double* alpha, const double* X, + double* Y, CUDAContext* context) { + CUBLAS_CHECK(cublasDaxpy(context->cublas_handle(), N, alpha, X, 1, Y, 1)); +} + +namespace { +template +__global__ void AxpbyKernel(const int n, const dtype* a, const dtype* x, + const dtype* b, dtype* y) { + CUDA_1D_KERNEL_LOOP(index, n) { + y[index] = x[index] * (*a) + y[index] * (*b); + } +} +} // namespace + +template <> +void Axpby( + const int n, const float* a, const float* x, const float* b, float* y, + CUDAContext* context) { + AxpbyKernel<<cuda_stream()>>>(n, a, x, b, y); +} + +template <> +void Axpby( + const int n, const double* a, const double* x, const double* b, double* y, + CUDAContext* context) { + AxpbyKernel<<cuda_stream()>>>(n, a, x, b, y); +} + +namespace { + +template +__global__ void im2col_gpu_kernel_nchw(const int n, const dtype* data_im, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + dtype* data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + int w_out = index % width_col; + int h_index = index / width_col; + int h_out = h_index % height_col; + int channel_in = h_index / height_col; + int channel_out = channel_in * kernel_h * kernel_w; + int h_in = h_out * stride_h - pad_t; + int w_in = w_out * stride_w - pad_l; + dtype* data_col_ptr = data_col; + data_col_ptr += (channel_out * height_col + h_out) * width_col + w_out; + const dtype* data_im_ptr = data_im; + data_im_ptr += (channel_in * height + h_in) * width + w_in; + for (int i = 0; i < kernel_h; ++i) { + for (int j = 0; j < kernel_w; ++j) { + int h = h_in + i; + int w = w_in + j; + *data_col_ptr = (h >= 0 && w >= 0 && h < height && w < width) ? 
+ data_im_ptr[i * width + j] : 0; + data_col_ptr += height_col * width_col; + } + } + } +} + +template +__global__ void im2col_gpu_kernel_nhwc(const int n, const dtype* data_im, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, + const int stride_h, const int stride_w, + const int width_col, const int channels, + dtype* data_col) { + CUDA_1D_KERNEL_LOOP(index, n) { + int channel_in = index % channels; + int w_out = index / channels % width_col; + int h_out = index / channels / width_col; + int h_in = h_out * stride_h - pad_t; + int w_in = w_out * stride_w - pad_l; + dtype* local_data_col = data_col + + ((h_out * width_col) + w_out) * channels * kernel_h * kernel_w + + channel_in; + for (int i = 0; i < kernel_h; ++i) { + int h = h_in + i; + for (int j = 0; j < kernel_w; ++j) { + int w = w_in + j; + *local_data_col = (h >= 0 && w >= 0 && h < height && w < width) ? + data_im[(h * width + w) * channels + channel_in] : 0; + local_data_col += channels; + } + } + } +} + +template +__global__ void col2im_gpu_kernel_nchw(const int n, const dtype* data_col, + const int height, const int width, + const int patch_h, const int patch_w, + const int pad_t, const int pad_l, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + dtype* data_im) { + CUDA_1D_KERNEL_LOOP(index, n) { + dtype val = 0; + int w = index % width + pad_l; + int h = (index / width) % height + pad_t; + int c = index / (width * height); + // compute the start and end of the output + int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1; + int w_col_end = min(w / stride_w + 1, width_col); + int h_col_start = (h < patch_h) ? 0 : (h - patch_h) / stride_h + 1; + int h_col_end = min(h / stride_h + 1, height_col); + int offset = + (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col; + int coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col; + int coeff_w_col = (1 - stride_w * height_col * width_col); + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + data_im[index] = val; + } +} + +template +__global__ void col2im_gpu_kernel_nhwc(const int n, const dtype* data_col, + const int width, const int channels, + const int patch_h, const int patch_w, + const int pad_t, const int pad_l, + const int stride_h, const int stride_w, + const int height_col, const int width_col, + dtype* data_im) { + CUDA_1D_KERNEL_LOOP(index, n) { + dtype val = 0; + int c = index % channels; + int w = index / channels % width + pad_l; + int h = index / channels / width + pad_t; + // compute the start and end of the output + int w_col_start = (w < patch_w) ? 0 : (w - patch_w) / stride_w + 1; + int w_col_end = min(w / stride_w + 1, width_col); + int h_col_start = (h < patch_h) ? 
0 : (h - patch_h) / stride_h + 1; + int h_col_end = min(h / stride_h + 1, height_col); + int channels_col = patch_h * patch_w * channels; + /* + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + int c_col = ((h - h_col * stride_h) * patch_w + w - w_col * stride_w) * channels + c; + val += data_col[(h_col * width_col + w_col) * channels_col + c_col]; + } + } + */ + // Equivalent of above + int offset = (h * patch_w + w) * channels + c; + int coeff_h_col = width_col * channels_col - stride_h * patch_w * channels; + int coeff_w_col = channels_col - stride_w * channels; + for (int h_col = h_col_start; h_col < h_col_end; ++h_col) { + for (int w_col = w_col_start; w_col < w_col_end; ++w_col) { + val += data_col[offset + h_col * coeff_h_col + w_col * coeff_w_col]; + } + } + data_im[index] = val; + } +} + +} // namespace + +template <> +void Im2col( + const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_col, CUDAContext* context) { + // We are going to launch channels * height_col * width_col kernels, each + // kernel responsible for copying a single-channel grid. + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int num_kernels = channels * height_col * width_col; + // NOLINT_NEXT_LINE(whitespace/operators) + im2col_gpu_kernel_nchw<<cuda_stream()>>>( + num_kernels, data_im, height, width, kernel_h, kernel_w, pad_t, + pad_l, stride_h, stride_w, height_col, width_col, data_col); +} + +template <> +void Im2col( + const float* data_im, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_col, CUDAContext* context) { + // We are going to launch height_col * width_col * channels kernels, each + // kernel responsible for copying a single-channel grid. + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int num_kernels = height_col * width_col * channels; + // NOLINT_NEXT_LINE(whitespace/operators) + im2col_gpu_kernel_nhwc<<cuda_stream()>>>( + num_kernels, data_im, height, width, kernel_h, kernel_w, pad_t, + pad_l, stride_h, stride_w, width_col, channels, data_col); +} + + +template <> +void Col2im( + const float* data_col, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_im, CUDAContext* context) { + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int num_kernels = channels * height * width; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. 
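An aside on the offset / coeff_h_col / coeff_w_col precomputation in the col2im kernels above (a worked expansion of the code, not additional behavior): for the NCHW layout, the column-buffer element written by patch position (h_col, w_col) for padded input coordinate (c, h, w) lives at

  ((c * patch_h * patch_w + (h - h_col * stride_h) * patch_w + (w - w_col * stride_w)) * height_col + h_col) * width_col + w_col.

Expanding and grouping the terms that depend on h_col and w_col gives

  offset      = (c * patch_h * patch_w + h * patch_w + w) * height_col * width_col
  coeff_h_col = (1 - stride_h * patch_w * height_col) * width_col
  coeff_w_col = 1 - stride_w * height_col * width_col

so the accumulation loop needs only one multiply-add per contributing column instead of re-deriving the full index each time; the NHWC kernel keeps the naive indexing in its commented-out block as a reference for the same rewrite.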
+ col2im_gpu_kernel_nchw<<cuda_stream()>>>( + num_kernels, data_col, height, width, kernel_h, kernel_w, + pad_t, pad_l, stride_h, stride_w, + height_col, width_col, data_im); +} + +template <> +void Col2im( + const float* data_col, const int channels, + const int height, const int width, const int kernel_h, const int kernel_w, + const int pad_t, const int pad_l, const int pad_b, const int pad_r, + const int stride_h, + const int stride_w, float* data_im, CUDAContext* context) { + int height_col = (height + pad_t + pad_b - kernel_h) / stride_h + 1; + int width_col = (width + pad_l + pad_r - kernel_w) / stride_w + 1; + int num_kernels = height * width * channels; + // To avoid involving atomic operations, we will launch one kernel per + // bottom dimension, and then in the kernel add up the top dimensions. + col2im_gpu_kernel_nhwc<<cuda_stream()>>>( + num_kernels, data_col, width, channels, kernel_h, kernel_w, + pad_t, pad_l, stride_h, stride_w, height_col, width_col, data_im); +} + +namespace { +template +__global__ void CopyMatrixKernel( + const int M, const int N, const T* A, const int lda, T* B, const int ldb) { + CUDA_1D_KERNEL_LOOP(i, M * N) { + int r = i / N; + int c = i % N; + B[r * ldb + c] = A[r * lda + c]; + } +} +} // namespace + +template <> +void CopyMatrix( + const int M, const int N, const float* A, const int lda, float* B, + const int ldb, CUDAContext* context) { + CopyMatrixKernel<<cuda_stream()>>>(M, N, A, lda, B, ldb); +} + +} // namespace math +} // namespace caffe2 diff --git a/caffe2/utils/math_test.cc b/caffe2/utils/math_test.cc new file mode 100644 index 00000000000..0056df9b406 --- /dev/null +++ b/caffe2/utils/math_test.cc @@ -0,0 +1,186 @@ + +#include "caffe2/core/blob.h" +#include "caffe2/utils/math.h" +#include "caffe2/core/context.h" +#include "caffe2/proto/caffe2.pb.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +TEST(MathTest, GemmNoTransNoTrans) { + DeviceOption option; + CPUContext cpu_context(option); + Tensor X(std::vector{5, 10}); + Tensor W(std::vector{10, 6}); + Tensor Y(std::vector{5, 6}); + EXPECT_EQ(X.size(), 50); + EXPECT_EQ(W.size(), 60); + math::Set(X.size(), 1, X.mutable_data(), &cpu_context); + math::Set(W.size(), 1, W.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < X.size(); ++i) { + CHECK_EQ(X.data()[i], 1); + } + for (int i = 0; i < W.size(); ++i) { + CHECK_EQ(W.data()[i], 1); + } + + const float kOne = 1.0; + const float kPointFive = 0.5; + const float kZero = 0.0; + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, &kOne, + X.data(), W.data(), &kZero, Y.mutable_data(), + &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 10) << i; + } + // Test Accumulate + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, &kOne, + X.data(), W.data(), &kPointFive, + Y.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 15) << i; + } + // Test Accumulate + math::Gemm(CblasNoTrans, CblasNoTrans, 5, 6, 10, + &kPointFive, + X.data(), W.data(), &kOne, Y.mutable_data(), + &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 20) << i; + } +} + +TEST(MathTest, GemmNoTransTrans) { + DeviceOption option; + CPUContext cpu_context(option); + Tensor X(std::vector{5, 10}); + Tensor W(std::vector{6, 10}); + Tensor Y(std::vector{5, 6}); + EXPECT_EQ(X.size(), 50); + EXPECT_EQ(W.size(), 60); + math::Set(X.size(), 1, X.mutable_data(), &cpu_context); + 
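
The CopyMatrix kernel just shown treats lda and ldb as row strides, so it can copy an M x N block out of a larger matrix into a buffer with a different row pitch. A plain CPU version of the same loop makes the convention concrete; the shapes and names below are only illustrative:

    #include <cassert>
    #include <vector>

    // CPU reference for the CopyMatrix semantics above: copy an M x N block
    // out of a matrix whose rows are `lda` elements apart into one whose rows
    // are `ldb` elements apart.
    template <typename T>
    void CopyMatrixCPU(int M, int N, const T* A, int lda, T* B, int ldb) {
      for (int r = 0; r < M; ++r) {
        for (int c = 0; c < N; ++c) {
          B[r * ldb + c] = A[r * lda + c];
        }
      }
    }

    int main() {
      // Copy the top-left 2 x 3 corner of a 4 x 5 matrix into a 2 x 3 matrix.
      std::vector<float> A(4 * 5);
      for (int i = 0; i < 20; ++i) A[i] = static_cast<float>(i);
      std::vector<float> B(2 * 3, 0.f);
      CopyMatrixCPU(2, 3, A.data(), /*lda=*/5, B.data(), /*ldb=*/3);
      assert(B[0] == 0.f && B[2] == 2.f && B[3] == 5.f && B[5] == 7.f);
      return 0;
    }
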
math::Set(W.size(), 1, W.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < X.size(); ++i) { + CHECK_EQ(X.data()[i], 1); + } + for (int i = 0; i < W.size(); ++i) { + CHECK_EQ(W.data()[i], 1); + } + + const float kOne = 1.0; + const float kPointFive = 0.5; + const float kZero = 0.0; + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, &kOne, + X.data(), W.data(), &kZero, Y.mutable_data(), + &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 10) << i; + } + // Test Accumulate + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, &kOne, + X.data(), W.data(), &kPointFive, + Y.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 15) << i; + } + math::Gemm(CblasNoTrans, CblasTrans, 5, 6, 10, &kPointFive, + X.data(), W.data(), &kOne, Y.mutable_data(), + &cpu_context); + EXPECT_EQ(Y.size(), 30); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 20) << i; + } +} + +TEST(MathTest, GemvNoTrans) { + DeviceOption option; + CPUContext cpu_context(option); + Tensor A(std::vector{5, 10}); + Tensor X(std::vector{10}); + Tensor Y(std::vector{5}); + EXPECT_EQ(A.size(), 50); + EXPECT_EQ(X.size(), 10); + math::Set(A.size(), 1, A.mutable_data(), &cpu_context); + math::Set(X.size(), 1, X.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 5); + for (int i = 0; i < A.size(); ++i) { + CHECK_EQ(A.data()[i], 1); + } + for (int i = 0; i < X.size(); ++i) { + CHECK_EQ(X.data()[i], 1); + } + + const float kOne = 1.0; + const float kPointFive = 0.5; + const float kZero = 0.0; + math::Gemv(CblasNoTrans, 5, 10, &kOne, A.data(), X.data(), + &kZero, Y.mutable_data(), &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 10) << i; + } + // Test Accumulate + math::Gemv(CblasNoTrans, 5, 10, &kOne, A.data(), X.data(), + &kPointFive, Y.mutable_data(), &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 15) << i; + } + // Test Accumulate + math::Gemv(CblasNoTrans, 5, 10, &kPointFive, A.data(), + X.data(), &kOne, Y.mutable_data(), + &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 20) << i; + } +} + +TEST(MathTest, GemvTrans) { + DeviceOption option; + CPUContext cpu_context(option); + Tensor A(std::vector{6, 10}); + Tensor X(std::vector{6}); + Tensor Y(std::vector{10}); + EXPECT_EQ(A.size(), 60); + EXPECT_EQ(X.size(), 6); + math::Set(A.size(), 1, A.mutable_data(), &cpu_context); + math::Set(X.size(), 1, X.mutable_data(), &cpu_context); + EXPECT_EQ(Y.size(), 10); + for (int i = 0; i < A.size(); ++i) { + CHECK_EQ(A.data()[i], 1); + } + for (int i = 0; i < X.size(); ++i) { + CHECK_EQ(X.data()[i], 1); + } + + const float kOne = 1.0; + const float kPointFive = 0.5; + const float kZero = 0.0; + math::Gemv(CblasTrans, 6, 10, &kOne, A.data(), X.data(), + &kZero, Y.mutable_data(), &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 6) << i; + } + // Test Accumulate + math::Gemv(CblasTrans, 6, 10, &kOne, A.data(), X.data(), + &kPointFive, Y.mutable_data(), &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 9) << i; + } + // Test Accumulate + math::Gemv(CblasTrans, 6, 10, &kPointFive, A.data(), + X.data(), &kOne, Y.mutable_data(), + &cpu_context); + for (int i = 0; i < Y.size(); ++i) { + CHECK_EQ(Y.data()[i], 12) << i; + } +} + +} // namespace caffe2 + + diff --git a/caffe2/utils/mkl_alternate.h b/caffe2/utils/mkl_alternate.h new file mode 100644 index 
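
The constants the Gemm and Gemv tests expect follow from the BLAS convention C = alpha * A * B + beta * C (and y = alpha * A * x + beta * y): with all-ones inputs, the 5x10 by 10x6 product is 10 everywhere, accumulating with beta = 0.5 gives 1 * 10 + 0.5 * 10 = 15, and the final call gives 0.5 * 10 + 1 * 15 = 20; the transposed Gemv case sums over 6 ones, hence 6, 9 and 12. A plain-loop check of the first sequence, independent of the caffe2 math library:

    #include <cassert>
    #include <vector>

    // Plain-loop check of the constants used in the Gemm tests above:
    // C = alpha * A * B + beta * C with all-ones A (5x10) and B (10x6).
    void Gemm(int M, int N, int K, float alpha, const std::vector<float>& A,
              const std::vector<float>& B, float beta, std::vector<float>* C) {
      for (int m = 0; m < M; ++m) {
        for (int n = 0; n < N; ++n) {
          float acc = 0.f;
          for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
          (*C)[m * N + n] = alpha * acc + beta * (*C)[m * N + n];
        }
      }
    }

    int main() {
      std::vector<float> A(5 * 10, 1.f), B(10 * 6, 1.f), C(5 * 6, 0.f);
      Gemm(5, 6, 10, 1.0f, A, B, 0.0f, &C);  // every entry: 1 * 10 + 0 * 0   = 10
      assert(C[0] == 10.f);
      Gemm(5, 6, 10, 1.0f, A, B, 0.5f, &C);  // every entry: 1 * 10 + 0.5 * 10 = 15
      assert(C[0] == 15.f);
      Gemm(5, 6, 10, 0.5f, A, B, 1.0f, &C);  // every entry: 0.5 * 10 + 1 * 15 = 20
      assert(C[0] == 20.f);
      return 0;
    }
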
00000000000..340b0791b7a --- /dev/null +++ b/caffe2/utils/mkl_alternate.h @@ -0,0 +1,83 @@ +// This file implements a set of mkl functions when MKL is not available. +#ifndef CAFFE2_UTILS_MKL_ALTERNATE_H_ +#define CAFFE2_UTILS_MKL_ALTERNATE_H_ + +#ifdef USE_MKL + +#include + +#else // If use MKL, simply include the MKL header + +#include +extern "C" { +#include "caffe2/utils/cblas.h" +} +#include "glog/logging.h" + +// Functions that caffe uses but are not present if MKL is not linked. + +// A simple way to define the vsl unary functions. The operation should +// be in the form e.g. y[i] = sqrt(a[i]) +#define DEFINE_VSL_UNARY_FUNC(name, operation) \ + template \ + inline void v##name(const int n, const Dtype* a, Dtype* y) { \ + DCHECK_GT(n, 0); DCHECK(a); DCHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, float* y) { \ + v##name(n, a, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, double* y) { \ + v##name(n, a, y); \ + } + +DEFINE_VSL_UNARY_FUNC(Sqr, y[i] = a[i] * a[i]); +DEFINE_VSL_UNARY_FUNC(Exp, y[i] = exp(a[i])); +DEFINE_VSL_UNARY_FUNC(Ln, y[i] = std::log(a[i])); +DEFINE_VSL_UNARY_FUNC(Abs, y[i] = fabs(a[i])); + +// A simple way to define the vsl unary functions with singular parameter b. +// The operation should be in the form e.g. y[i] = pow(a[i], b) +#define DEFINE_VSL_UNARY_FUNC_WITH_PARAM(name, operation) \ + template \ + inline void v##name(const int n, const Dtype* a, const Dtype b, Dtype* y) { \ + DCHECK_GT(n, 0); DCHECK(a); DCHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, const float b, float* y) { \ + v##name(n, a, b, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, const float b, double* y) { \ + v##name(n, a, b, y); \ + } + +DEFINE_VSL_UNARY_FUNC_WITH_PARAM(Powx, y[i] = pow(a[i], b)); + +// A simple way to define the vsl binary functions. The operation should +// be in the form e.g. 
y[i] = a[i] + b[i] +#define DEFINE_VSL_BINARY_FUNC(name, operation) \ + template \ + inline void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) { \ + DCHECK_GT(n, 0); DCHECK(a); DCHECK(b); DCHECK(y); \ + for (int i = 0; i < n; ++i) { operation; } \ + } \ + inline void vs##name( \ + const int n, const float* a, const float* b, float* y) { \ + v##name(n, a, b, y); \ + } \ + inline void vd##name( \ + const int n, const double* a, const double* b, double* y) { \ + v##name(n, a, b, y); \ + } + +DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i]); +DEFINE_VSL_BINARY_FUNC(Sub, y[i] = a[i] - b[i]); +DEFINE_VSL_BINARY_FUNC(Mul, y[i] = a[i] * b[i]); +DEFINE_VSL_BINARY_FUNC(Div, y[i] = a[i] / b[i]); + +#endif // USE_MKL +#endif // CAFFE2_UTILS_MKL_ALTERNATE_H_ diff --git a/caffe2/utils/proto_utils.cc b/caffe2/utils/proto_utils.cc new file mode 100644 index 00000000000..1db95fceed1 --- /dev/null +++ b/caffe2/utils/proto_utils.cc @@ -0,0 +1,72 @@ +#include +#include +#include +#include + +#include "caffe2/utils/proto_utils.h" +#include "glog/logging.h" +#include "google/protobuf/io/coded_stream.h" +#include "google/protobuf/io/zero_copy_stream_impl.h" +#include "google/protobuf/text_format.h" + +namespace caffe2 { + +using google::protobuf::io::FileInputStream; +using google::protobuf::io::FileOutputStream; +using google::protobuf::io::ZeroCopyInputStream; +using google::protobuf::io::CodedInputStream; +using google::protobuf::io::ZeroCopyOutputStream; +using google::protobuf::io::CodedOutputStream; +using google::protobuf::Message; +using google::protobuf::MessageLite; + +using std::fstream; +using std::ios; + +bool ReadProtoFromTextFile(const char* filename, Message* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + FileInputStream* input = new FileInputStream(fd); + bool success = google::protobuf::TextFormat::Parse(input, proto); + delete input; + close(fd); + return success; +} + +void WriteProtoToTextFile(const Message& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); + FileOutputStream* output = new FileOutputStream(fd); + CHECK(google::protobuf::TextFormat::Print(proto, output)); + delete output; + close(fd); +} + +bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto) { + int fd = open(filename, O_RDONLY); + CHECK_NE(fd, -1) << "File not found: " << filename; + ZeroCopyInputStream* raw_input = new FileInputStream(fd); + CodedInputStream* coded_input = new CodedInputStream(raw_input); + // A hack to manually allow using very large protocol buffers. 
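
The DEFINE_VSL_*_FUNC macros in mkl_alternate.h above stamp out a templated loop plus vs/vd wrappers (vsAdd, vdAdd, vsPowx, and so on) so that call sites look the same whether or not MKL is linked. A cut-down, self-contained sketch of that pattern; the real definitions are inline and also DCHECK their arguments:

    #include <cassert>

    // Cut-down version of the DEFINE_VSL_BINARY_FUNC pattern from
    // mkl_alternate.h above; shown only to illustrate what the macro expands to.
    #define DEFINE_VSL_BINARY_FUNC(name, operation)                            \
      template <typename Dtype>                                                \
      void v##name(const int n, const Dtype* a, const Dtype* b, Dtype* y) {    \
        for (int i = 0; i < n; ++i) { operation; }                             \
      }                                                                        \
      void vs##name(const int n, const float* a, const float* b, float* y) {   \
        v##name<float>(n, a, b, y);                                            \
      }

    DEFINE_VSL_BINARY_FUNC(Add, y[i] = a[i] + b[i])

    int main() {
      float a[3] = {1, 2, 3}, b[3] = {4, 5, 6}, y[3];
      vsAdd(3, a, b, y);  // same call shape as MKL's vsAdd
      assert(y[0] == 5 && y[1] == 7 && y[2] == 9);
      return 0;
    }
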
+ coded_input->SetTotalBytesLimit(1073741824, 536870912); + + bool success = proto->ParseFromCodedStream(coded_input); + + delete coded_input; + delete raw_input; + close(fd); + return success; +} + +void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename) { + int fd = open(filename, O_WRONLY | O_CREAT | O_TRUNC, 0644); + CHECK_NE(fd, -1) << "File cannot be created: " << filename + << " error number: " << errno; + ZeroCopyOutputStream* raw_output = new FileOutputStream(fd); + CodedOutputStream* coded_output = new CodedOutputStream(raw_output); + CHECK(proto.SerializeToCodedStream(coded_output)); + delete coded_output; + delete raw_output; + close(fd); +} + +} // namespace caffe2 diff --git a/caffe2/utils/proto_utils.h b/caffe2/utils/proto_utils.h new file mode 100644 index 00000000000..a847dff7bfb --- /dev/null +++ b/caffe2/utils/proto_utils.h @@ -0,0 +1,91 @@ +#ifndef CAFFE2_UTILS_PROTO_UTILS_H_ +#define CAFFE2_UTILS_PROTO_UTILS_H_ + +#include "caffe2/proto/caffe2.pb.h" +#include "google/protobuf/message.h" +#include "glog/logging.h" + +namespace caffe2 { + +using std::string; +using ::google::protobuf::Message; +using ::google::protobuf::MessageLite; +using std::string; + +bool ReadProtoFromTextFile(const char* filename, Message* proto); +inline bool ReadProtoFromTextFile(const string filename, Message* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +void WriteProtoToTextFile(const Message& proto, const char* filename); +inline void WriteProtoToTextFile(const Message& proto, const string& filename) { + return WriteProtoToTextFile(proto, filename.c_str()); +} + +// Text format MessageLite wrappers: these functions do nothing but just +// allowing things to compile. It will produce a runtime error if you are using +// MessageLite but still want text support. +inline bool ReadProtoFromTextFile(const char* filename, MessageLite* proto) { + LOG(FATAL) << "If you are running lite version, you should not be " + << "calling any text-format protobuffers."; + return false; // Just to suppress compiler warning. +} +inline bool ReadProtoFromTextFile(const string filename, MessageLite* proto) { + return ReadProtoFromTextFile(filename.c_str(), proto); +} + +inline void WriteProtoToTextFile(const MessageLite& proto, + const char* filename) { + LOG(FATAL) << "If you are running lite version, you should not be " + << "calling any text-format protobuffers."; +} +inline void WriteProtoToTextFile(const MessageLite& proto, + const string& filename) { + return WriteProtoToTextFile(proto, filename.c_str()); +} + +bool ReadProtoFromBinaryFile(const char* filename, MessageLite* proto); +inline bool ReadProtoFromBinaryFile(const string filename, MessageLite* proto) { + return ReadProtoFromBinaryFile(filename.c_str(), proto); +} + +void WriteProtoToBinaryFile(const MessageLite& proto, const char* filename); +inline void WriteProtoToBinaryFile(const MessageLite& proto, + const string& filename) { + return WriteProtoToBinaryFile(proto, filename.c_str()); +} + +// Read Proto from a file, letting the code figure out if it is text or binary. 
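
Taken together, the functions above give a symmetric read/write API for text and binary protobufs, and ReadProtoFromFile, declared just below, simply tries the binary parser first and falls back to the text parser. A usage sketch of the binary round trip; caffe2::NetDef and its name field are assumed here purely for illustration:

    #include "caffe2/proto/caffe2.pb.h"
    #include "caffe2/utils/proto_utils.h"

    int main() {
      caffe2::NetDef net;  // assumed generated message type, for illustration
      net.set_name("example_net");
      // Binary round trip through the helpers above.
      caffe2::WriteProtoToBinaryFile(net, "/tmp/example_net.pb");
      caffe2::NetDef loaded;
      CHECK(caffe2::ReadProtoFromBinaryFile("/tmp/example_net.pb", &loaded));
      // ReadProtoFromFile (declared just below) would also accept this file:
      // it tries the binary parser first and falls back to the text parser.
      return 0;
    }
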
+inline bool ReadProtoFromFile(const char* filename, Message* proto) { + return (ReadProtoFromBinaryFile(filename, proto) || + ReadProtoFromTextFile(filename, proto)); +} +inline bool ReadProtoFromFile(const string& filename, Message* proto) { + return ReadProtoFromFile(filename.c_str(), proto); +} + +inline bool ReadProtoFromFile(const char* filename, MessageLite* proto) { + return (ReadProtoFromBinaryFile(filename, proto) || + ReadProtoFromTextFile(filename, proto)); +} +inline bool ReadProtoFromFile(const string& filename, MessageLite* proto) { + return ReadProtoFromFile(filename.c_str(), proto); +} + +// A coarse support for the Any message in proto3. I am a bit afraid of going +// directly to proto3 yet, so let's do this first... +class Any { + template + static MessageType Parse(const Argument& arg) { + CHECK_EQ(arg.strings_size(), 1) + << "An Any object should parse from a single string."; + MessageType message; + CHECK(message.ParseFromString(arg.strings(0))) + << "Faild to parse from the string"; + return message; + } +}; + +} // namespace caffe2 + +#endif // CAFFE2_UTILS_PROTO_UTILS_H_ diff --git a/caffe2/utils/simple_queue.h b/caffe2/utils/simple_queue.h new file mode 100644 index 00000000000..7bb883f559d --- /dev/null +++ b/caffe2/utils/simple_queue.h @@ -0,0 +1,73 @@ +#ifndef CAFFE2_UTILS_SIMPLE_QUEUE_H_ +#define CAFFE2_UTILS_SIMPLE_QUEUE_H_ + +#include // NOLINT +#include // NOLINT +#include + +#include "glog/logging.h" + +namespace caffe2 { + +// This is a very simple queue that Yangqing wrote when nursing the baby, so +// don't take it seriously. What it does is a minimal thread-safe queue that +// allows me to run network as a DAG. +// +// A usual work pattern looks like this: one or multiple producers push jobs +// into this queue, and one or multiple workers pops jobs from this queue. If +// nothing is in the queue but NoMoreJobs() is not called yet, the pop calls +// will wait. If NoMoreJobs() has been called, pop calls will return false, +// which serves as a message to the workers that they should exit. +template +class SimpleQueue { + public: + SimpleQueue() : no_more_jobs_(false) {} + + // Pops a value and writes it to the value pointer. If there is nothing in the + // queue, this will wait till a value is inserted to the queue. If there are + // no more jobs to pop, the function returns false. Otherwise, it returns + // true. + bool Pop(T* value) { + std::unique_lock mutex_lock(mutex_); + while (queue_.size() == 0 && !no_more_jobs_) cv_.wait(mutex_lock); + if (queue_.size() == 0 && no_more_jobs_) return false; + *value = queue_.front(); + queue_.pop(); + return true; + } + + // Push pushes a value to the queue. + void Push(const T& value) { + std::unique_lock mutex_lock(mutex_); + CHECK(!no_more_jobs_) + << "Cannot push to a closed queue."; + queue_.push(value); + mutex_lock.unlock(); + cv_.notify_one(); + } + + // NoMoreJobs() marks the close of this queue. It also notifies all waiting + // Pop() calls so that they either check out remaining jobs, or return false. + // After NoMoreJobs() is called, this queue is considered closed - no more + // Push() functions are allowed, and once existing items are all checked out + // by the Pop() functions, any more Pop() function will immediately return + // false with nothing set to the value. 
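
Before the NoMoreJobs() definition that follows, here is a minimal sketch of the contract described in the comment above: Pop() blocks while the queue is open and empty, and returns false only once the queue has been closed and drained. The element type and the job values are arbitrary:

    #include <iostream>
    #include <thread>
    #include "caffe2/utils/simple_queue.h"

    // Minimal producer/consumer illustration of the Pop/Push/NoMoreJobs contract.
    int main() {
      caffe2::SimpleQueue<int> queue;
      std::thread worker([&queue]() {
        int job;
        while (queue.Pop(&job)) {
          std::cout << "got job " << job << std::endl;
        }
        // Reaching here means the queue was closed and fully drained.
      });
      for (int i = 0; i < 4; ++i) queue.Push(i);
      queue.NoMoreJobs();  // close the queue; the worker exits after draining
      worker.join();
      return 0;
    }
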
+ void NoMoreJobs() { + std::unique_lock mutex_lock(mutex_); + no_more_jobs_ = true; + mutex_lock.unlock(); + cv_.notify_all(); + } + + private: + std::mutex mutex_; + std::condition_variable cv_; + std::queue queue_; + bool no_more_jobs_; + // We do not allow copy constructors. + SimpleQueue(const SimpleQueue& src) {} +}; + +} // namespace caffe2 + +#endif // CAFFE2_UTILS_SIMPLE_QUEUE_H_ diff --git a/caffe2/utils/simple_queue_test.cc b/caffe2/utils/simple_queue_test.cc new file mode 100644 index 00000000000..37e55274666 --- /dev/null +++ b/caffe2/utils/simple_queue_test.cc @@ -0,0 +1,73 @@ +#include // NOLINT + +#include "caffe2/utils/simple_queue.h" +#include "gtest/gtest.h" + +namespace caffe2 { + +static std::unique_ptr > gQueue; + +static void ConsumerFunction(int thread_idx) { + int value; + while (true) { + if (!gQueue->Pop(&value)) return; + LOG(INFO) << "Emitting " << value << " from thread " << thread_idx; + } +} + +static void ProducerFunction(int thread_idx, int start, int count) { + for (int i = 0; i < count; ++i) { + LOG(INFO) << "Pushing " << i + start << " from thread " << thread_idx; + gQueue->Push(i + start); + } +} + + +TEST(SimpleQueueTest, SingleProducerSingleConsumer) { + gQueue.reset(new SimpleQueue()); + std::thread consumer(ConsumerFunction, 0); + for (int i = 0; i < 10; ++i) { + gQueue->Push(i); + } + gQueue->NoMoreJobs(); + consumer.join(); +} + +TEST(SimpleQueueTest, SingleProducerDoubleConsumer) { + gQueue.reset(new SimpleQueue()); + std::thread consumer0(ConsumerFunction, 0); + std::thread consumer1(ConsumerFunction, 1); + for (int i = 0; i < 10; ++i) { + gQueue->Push(i); + } + gQueue->NoMoreJobs(); + consumer0.join(); + consumer1.join(); +} + + +TEST(SimpleQueueTest, DoubleProducerDoubleConsumer) { + gQueue.reset(new SimpleQueue()); + std::thread producer0(ProducerFunction, 0, 0, 10); + std::thread producer1(ProducerFunction, 0, 10, 10); + std::thread consumer0(ConsumerFunction, 2); + std::thread consumer1(ConsumerFunction, 3); + producer0.join(); + producer1.join(); + gQueue->NoMoreJobs(); + consumer0.join(); + consumer1.join(); +} + +TEST(SimpleQueueDeathTest, CannotAddAfterQueueFinished) { + gQueue.reset(new SimpleQueue()); + gQueue->Push(0); + gQueue->NoMoreJobs(); + EXPECT_DEATH(gQueue->Push(0), + "Check failed: !no_more_jobs_ Cannot push to a closed queue."); +} + + +} // namespace caffe2 + + diff --git a/cpplint.py b/cpplint.py new file mode 100644 index 00000000000..5a914575b7f --- /dev/null +++ b/cpplint.py @@ -0,0 +1,6309 @@ +#!/usr/bin/env python +# +# Copyright (c) 2009 Google Inc. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Does google-lint on c++ files. + +The goal of this script is to identify places in the code that *may* +be in non-compliance with google style. It does not attempt to fix +up these problems -- the point is to educate. It does also not +attempt to find all problems, or to ensure that everything it does +find is legitimately a problem. + +In particular, we can get very confused by /* and // inside strings! +We do a small hack, which is to ignore //'s with "'s after them on the +same line, but it is far from perfect (in either direction). +""" + +import codecs +import copy +import getopt +import math # for log +import os +import re +import sre_compile +import string +import sys +import unicodedata + + +_USAGE = """ +Syntax: cpplint.py [--verbose=#] [--output=vs7] [--filter=-x,+y,...] + [--counting=total|toplevel|detailed] [--root=subdir] + [--linelength=digits] + [file] ... + + The style guidelines this tries to follow are those in + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml + + Every problem is given a confidence score from 1-5, with 5 meaning we are + certain of the problem, and 1 meaning it could be a legitimate construct. + This will miss some errors, and is not a substitute for a code review. + + To suppress false-positive errors of a certain category, add a + 'NOLINT(category)' comment to the line. NOLINT or NOLINT(*) + suppresses errors of all categories on that line. + + The files passed in will be linted; at least one file must be provided. + Default linted extensions are .cc, .cpp, .cu, .cuh and .h. Change the + extensions with the --extensions flag. + + Flags: + + output=vs7 + By default, the output is formatted to ease emacs parsing. Visual Studio + compatible output (vs7) may also be used. Other formats are unsupported. + + verbose=# + Specify a number 0-5 to restrict errors to certain verbosity levels. + + filter=-x,+y,... + Specify a comma-separated list of category-filters to apply: only + error messages whose category names pass the filters will be printed. + (Category names are printed with the message and look like + "[whitespace/indent]".) Filters are evaluated left to right. + "-FOO" and "FOO" means "do not print categories that start with FOO". + "+FOO" means "do print categories that start with FOO". + + Examples: --filter=-whitespace,+whitespace/braces + --filter=whitespace,runtime/printf,+runtime/printf_format + --filter=-,+build/include_what_you_use + + To see a list of all the categories used in cpplint, pass no arg: + --filter= + + counting=total|toplevel|detailed + The total number of errors found is always printed. If + 'toplevel' is provided, then the count of errors in each of + the top-level categories like 'build' and 'whitespace' will + also be printed. If 'detailed' is provided, then a count + is provided for each category like 'build/class'. + + root=subdir + The root directory used for deriving header guard CPP variable. 
+ By default, the header guard CPP variable is calculated as the relative + path to the directory that contains .git, .hg, or .svn. When this flag + is specified, the relative path is calculated from the specified + directory. If the specified directory does not exist, this flag is + ignored. + + Examples: + Assuming that src/.git exists, the header guard CPP variables for + src/chrome/browser/ui/browser.h are: + + No flag => CHROME_BROWSER_UI_BROWSER_H_ + --root=chrome => BROWSER_UI_BROWSER_H_ + --root=chrome/browser => UI_BROWSER_H_ + + linelength=digits + This is the allowed line length for the project. The default value is + 80 characters. + + Examples: + --linelength=120 + + extensions=extension,extension,... + The allowed file extensions that cpplint will check + + Examples: + --extensions=hpp,cpp + + cpplint.py supports per-directory configurations specified in CPPLINT.cfg + files. CPPLINT.cfg file can contain a number of key=value pairs. + Currently the following options are supported: + + set noparent + filter=+filter1,-filter2,... + exclude_files=regex + linelength=80 + + "set noparent" option prevents cpplint from traversing directory tree + upwards looking for more .cfg files in parent directories. This option + is usually placed in the top-level project directory. + + The "filter" option is similar in function to --filter flag. It specifies + message filters in addition to the |_DEFAULT_FILTERS| and those specified + through --filter command-line flag. + + "exclude_files" allows to specify a regular expression to be matched against + a file name. If the expression matches, the file is skipped and not run + through liner. + + "linelength" allows to specify the allowed line length for the project. + + CPPLINT.cfg has an effect on files in the same directory and all + sub-directories, unless overridden by a nested configuration file. + + Example file: + filter=-build/include_order,+build/include_alpha + exclude_files=.*\.cc + + The above example disables build/include_order warning and enables + build/include_alpha as well as excludes all .cc from being + processed by linter, in the current directory (where the .cfg + file is located) and all sub-directories. +""" + +# We categorize each error message we print. Here are the categories. +# We want an explicit list so we can list them all in cpplint --filter=. +# If you add a new error message with a new category, add it to the list +# here! cpplint_unittest.py should tell you if you forget to do this. 
+_ERROR_CATEGORIES = [ + 'build/class', + 'build/c++11', + 'build/deprecated', + 'build/endif_comment', + 'build/explicit_make_pair', + 'build/forward_decl', + 'build/header_guard', + 'build/include', + 'build/include_alpha', + 'build/include_order', + 'build/include_what_you_use', + 'build/namespaces', + 'build/printf_format', + 'build/storage_class', + 'legal/copyright', + 'readability/alt_tokens', + 'readability/braces', + 'readability/casting', + 'readability/check', + 'readability/constructors', + 'readability/fn_size', + 'readability/function', + 'readability/inheritance', + 'readability/multiline_comment', + 'readability/multiline_string', + 'readability/namespace', + 'readability/nolint', + 'readability/nul', + 'readability/strings', + 'readability/todo', + 'readability/utf8', + 'runtime/arrays', + 'runtime/casting', + 'runtime/explicit', + 'runtime/int', + 'runtime/init', + 'runtime/invalid_increment', + 'runtime/member_string_references', + 'runtime/memset', + 'runtime/indentation_namespace', + 'runtime/operator', + 'runtime/printf', + 'runtime/printf_format', + 'runtime/references', + 'runtime/string', + 'runtime/threadsafe_fn', + 'runtime/vlog', + 'whitespace/blank_line', + 'whitespace/braces', + 'whitespace/comma', + 'whitespace/comments', + 'whitespace/empty_conditional_body', + 'whitespace/empty_loop_body', + 'whitespace/end_of_line', + 'whitespace/ending_newline', + 'whitespace/forcolon', + 'whitespace/indent', + 'whitespace/line_length', + 'whitespace/newline', + 'whitespace/operators', + 'whitespace/parens', + 'whitespace/semicolon', + 'whitespace/tab', + 'whitespace/todo', + ] + +# These error categories are no longer enforced by cpplint, but for backwards- +# compatibility they may still appear in NOLINT comments. +_LEGACY_ERROR_CATEGORIES = [ + 'readability/streams', + ] + +# The default state of the category filter. This is overridden by the --filter= +# flag. By default all errors are on, so only add here categories that should be +# off by default (i.e., categories that must be enabled by the --filter= flags). +# All entries here should start with a '-' or '+', as in the --filter= flag. +_DEFAULT_FILTERS = ['-build/include_alpha'] + +# We used to check for high-bit characters, but after much discussion we +# decided those were OK, as long as they were in UTF-8 and didn't represent +# hard-coded international strings, which belong in a separate i18n file. 
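
The category names above are exactly what goes inside a NOLINT comment: per the usage text earlier, a bare NOLINT or NOLINT(*) silences every category on that line, while NOLINT(category) silences just one. A small C++ illustration; the declarations themselves are made up:

    // Hypothetical declarations, shown only to illustrate NOLINT usage; the
    // categories in parentheses come from the _ERROR_CATEGORIES list above.
    long file_offset = 0;          // NOLINT(runtime/int)  silences only runtime/int
    long buffer_length = 0;        // NOLINT               silences every category here
    typedef long legacy_offset_t;  // NOLINT(*)            same as a bare NOLINT
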
+ +# C++ headers +_CPP_HEADERS = frozenset([ + # Legacy + 'algobase.h', + 'algo.h', + 'alloc.h', + 'builtinbuf.h', + 'bvector.h', + 'complex.h', + 'defalloc.h', + 'deque.h', + 'editbuf.h', + 'fstream.h', + 'function.h', + 'hash_map', + 'hash_map.h', + 'hash_set', + 'hash_set.h', + 'hashtable.h', + 'heap.h', + 'indstream.h', + 'iomanip.h', + 'iostream.h', + 'istream.h', + 'iterator.h', + 'list.h', + 'map.h', + 'multimap.h', + 'multiset.h', + 'ostream.h', + 'pair.h', + 'parsestream.h', + 'pfstream.h', + 'procbuf.h', + 'pthread_alloc', + 'pthread_alloc.h', + 'rope', + 'rope.h', + 'ropeimpl.h', + 'set.h', + 'slist', + 'slist.h', + 'stack.h', + 'stdiostream.h', + 'stl_alloc.h', + 'stl_relops.h', + 'streambuf.h', + 'stream.h', + 'strfile.h', + 'strstream.h', + 'tempbuf.h', + 'tree.h', + 'type_traits.h', + 'vector.h', + # 17.6.1.2 C++ library headers + 'algorithm', + 'array', + 'atomic', + 'bitset', + 'chrono', + 'codecvt', + 'complex', + 'condition_variable', + 'deque', + 'exception', + 'forward_list', + 'fstream', + 'functional', + 'future', + 'initializer_list', + 'iomanip', + 'ios', + 'iosfwd', + 'iostream', + 'istream', + 'iterator', + 'limits', + 'list', + 'locale', + 'map', + 'memory', + 'mutex', + 'new', + 'numeric', + 'ostream', + 'queue', + 'random', + 'ratio', + 'regex', + 'set', + 'sstream', + 'stack', + 'stdexcept', + 'streambuf', + 'string', + 'strstream', + 'system_error', + 'thread', + 'tuple', + 'typeindex', + 'typeinfo', + 'type_traits', + 'unordered_map', + 'unordered_set', + 'utility', + 'valarray', + 'vector', + # 17.6.1.2 C++ headers for C library facilities + 'cassert', + 'ccomplex', + 'cctype', + 'cerrno', + 'cfenv', + 'cfloat', + 'cinttypes', + 'ciso646', + 'climits', + 'clocale', + 'cmath', + 'csetjmp', + 'csignal', + 'cstdalign', + 'cstdarg', + 'cstdbool', + 'cstddef', + 'cstdint', + 'cstdio', + 'cstdlib', + 'cstring', + 'ctgmath', + 'ctime', + 'cuchar', + 'cwchar', + 'cwctype', + ]) + + +# These headers are excluded from [build/include] and [build/include_order] +# checks: +# - Anything not following google file name conventions (containing an +# uppercase character, such as Python.h or nsStringAPI.h, for example). +# - Lua headers. +_THIRD_PARTY_HEADERS_PATTERN = re.compile( + r'^(?:[^/]*[A-Z][^/]*\.h|lua\.h|lauxlib\.h|lualib\.h)$') + + +# Assertion macros. These are defined in base/logging.h and +# testing/base/gunit.h. Note that the _M versions need to come first +# for substring matching to work. 
+_CHECK_MACROS = [ + 'DCHECK', 'CHECK', + 'EXPECT_TRUE_M', 'EXPECT_TRUE', + 'ASSERT_TRUE_M', 'ASSERT_TRUE', + 'EXPECT_FALSE_M', 'EXPECT_FALSE', + 'ASSERT_FALSE_M', 'ASSERT_FALSE', + ] + +# Replacement macros for CHECK/DCHECK/EXPECT_TRUE/EXPECT_FALSE +_CHECK_REPLACEMENT = dict([(m, {}) for m in _CHECK_MACROS]) + +for op, replacement in [('==', 'EQ'), ('!=', 'NE'), + ('>=', 'GE'), ('>', 'GT'), + ('<=', 'LE'), ('<', 'LT')]: + _CHECK_REPLACEMENT['DCHECK'][op] = 'DCHECK_%s' % replacement + _CHECK_REPLACEMENT['CHECK'][op] = 'CHECK_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE'][op] = 'EXPECT_%s' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE'][op] = 'ASSERT_%s' % replacement + _CHECK_REPLACEMENT['EXPECT_TRUE_M'][op] = 'EXPECT_%s_M' % replacement + _CHECK_REPLACEMENT['ASSERT_TRUE_M'][op] = 'ASSERT_%s_M' % replacement + +for op, inv_replacement in [('==', 'NE'), ('!=', 'EQ'), + ('>=', 'LT'), ('>', 'LE'), + ('<=', 'GT'), ('<', 'GE')]: + _CHECK_REPLACEMENT['EXPECT_FALSE'][op] = 'EXPECT_%s' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE'][op] = 'ASSERT_%s' % inv_replacement + _CHECK_REPLACEMENT['EXPECT_FALSE_M'][op] = 'EXPECT_%s_M' % inv_replacement + _CHECK_REPLACEMENT['ASSERT_FALSE_M'][op] = 'ASSERT_%s_M' % inv_replacement + +# Alternative tokens and their replacements. For full list, see section 2.5 +# Alternative tokens [lex.digraph] in the C++ standard. +# +# Digraphs (such as '%:') are not included here since it's a mess to +# match those on a word boundary. +_ALT_TOKEN_REPLACEMENT = { + 'and': '&&', + 'bitor': '|', + 'or': '||', + 'xor': '^', + 'compl': '~', + 'bitand': '&', + 'and_eq': '&=', + 'or_eq': '|=', + 'xor_eq': '^=', + 'not': '!', + 'not_eq': '!=' + } + +# Compile regular expression that matches all the above keywords. The "[ =()]" +# bit is meant to avoid matching these keywords outside of boolean expressions. +# +# False positives include C-style multi-line comments and multi-line strings +# but those have always been troublesome for cpplint. +_ALT_TOKEN_REPLACEMENT_PATTERN = re.compile( + r'[ =()](' + ('|'.join(_ALT_TOKEN_REPLACEMENT.keys())) + r')(?=[ (]|$)') + + +# These constants define types of headers for use with +# _IncludeState.CheckNextIncludeOrder(). +_C_SYS_HEADER = 1 +_CPP_SYS_HEADER = 2 +_LIKELY_MY_HEADER = 3 +_POSSIBLE_MY_HEADER = 4 +_OTHER_HEADER = 5 + +# These constants define the current inline assembly state +_NO_ASM = 0 # Outside of inline assembly block +_INSIDE_ASM = 1 # Inside inline assembly block +_END_ASM = 2 # Last line of inline assembly block +_BLOCK_ASM = 3 # The whole block is an inline assembly block + +# Match start of assembly blocks +_MATCH_ASM = re.compile(r'^\s*(?:asm|_asm|__asm|__asm__)' + r'(?:\s+(volatile|__volatile__))?' + r'\s*[{(]') + + +_regexp_compile_cache = {} + +# {str, set(int)}: a map from error categories to sets of linenumbers +# on which those errors are expected and should be suppressed. +_error_suppressions = {} + +# The root directory used for deriving header guard CPP variable. +# This is set by --root flag. +_root = None + +# The allowed line length of files. +# This is set by --linelength flag. +_line_length = 80 + +# The allowed extensions for file names +# This is set by --extensions flag. +_valid_extensions = set(['cc', 'h', 'cpp', 'cu', 'cuh']) + +def ParseNolintSuppressions(filename, raw_line, linenum, error): + """Updates the global list of error-suppressions. + + Parses any NOLINT comments on the current line, updating the global + error_suppressions store. 
Reports an error if the NOLINT comment + was malformed. + + Args: + filename: str, the name of the input file. + raw_line: str, the line of input text, with comments. + linenum: int, the number of the current line. + error: function, an error handler. + """ + matched = Search(r'\bNOLINT(NEXTLINE)?\b(\([^)]+\))?', raw_line) + if matched: + if matched.group(1): + suppressed_line = linenum + 1 + else: + suppressed_line = linenum + category = matched.group(2) + if category in (None, '(*)'): # => "suppress all" + _error_suppressions.setdefault(None, set()).add(suppressed_line) + else: + if category.startswith('(') and category.endswith(')'): + category = category[1:-1] + if category in _ERROR_CATEGORIES: + _error_suppressions.setdefault(category, set()).add(suppressed_line) + elif category not in _LEGACY_ERROR_CATEGORIES: + error(filename, linenum, 'readability/nolint', 5, + 'Unknown NOLINT error category: %s' % category) + + +def ResetNolintSuppressions(): + """Resets the set of NOLINT suppressions to empty.""" + _error_suppressions.clear() + + +def IsErrorSuppressedByNolint(category, linenum): + """Returns true if the specified error category is suppressed on this line. + + Consults the global error_suppressions map populated by + ParseNolintSuppressions/ResetNolintSuppressions. + + Args: + category: str, the category of the error. + linenum: int, the current line number. + Returns: + bool, True iff the error should be suppressed due to a NOLINT comment. + """ + return (linenum in _error_suppressions.get(category, set()) or + linenum in _error_suppressions.get(None, set())) + + +def Match(pattern, s): + """Matches the string with the pattern, caching the compiled regexp.""" + # The regexp compilation caching is inlined in both Match and Search for + # performance reasons; factoring it out into a separate function turns out + # to be noticeably expensive. + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].match(s) + + +def ReplaceAll(pattern, rep, s): + """Replaces instances of pattern in a string with a replacement. + + The compiled regex is kept in a cache shared by Match and Search. + + Args: + pattern: regex pattern + rep: replacement text + s: search string + + Returns: + string with replacements made (or original string if no replacements) + """ + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].sub(rep, s) + + +def Search(pattern, s): + """Searches the string for the pattern, caching the compiled regexp.""" + if pattern not in _regexp_compile_cache: + _regexp_compile_cache[pattern] = sre_compile.compile(pattern) + return _regexp_compile_cache[pattern].search(s) + + +class _IncludeState(object): + """Tracks line numbers for includes, and the order in which includes appear. + + include_list contains list of lists of (header, line number) pairs. + It's a lists of lists rather than just one flat list to make it + easier to update across preprocessor boundaries. + + Call CheckNextIncludeOrder() once for each header in the file, passing + in the type constants defined above. Calls in an illegal order will + raise an _IncludeError with an appropriate error message. + + """ + # self._section will move monotonically through this set. If it ever + # needs to move backwards, CheckNextIncludeOrder will raise an error. 
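
The _CHECK_REPLACEMENT table above is what drives the readability/check suggestions: when a bare comparison is wrapped in CHECK(), EXPECT_TRUE() and friends, cpplint proposes the dedicated two-argument macro, which also prints both operands on failure. A hypothetical fragment showing the flagged and preferred forms side by side:

    #include "glog/logging.h"
    #include "gtest/gtest.h"

    // Hypothetical test; the variable names are filler.
    TEST(LintExampleTest, PreferTwoArgumentCheckMacros) {
      int batch_size = 32;
      float loss = 0.5f, prev_loss = 1.0f;
      CHECK(batch_size == 32);         // flagged: use CHECK_EQ(batch_size, 32)
      CHECK_EQ(batch_size, 32);        // preferred
      EXPECT_TRUE(loss <= prev_loss);  // flagged: use EXPECT_LE(loss, prev_loss)
      EXPECT_LE(loss, prev_loss);      // preferred
    }
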
+ _INITIAL_SECTION = 0 + _MY_H_SECTION = 1 + _C_SECTION = 2 + _CPP_SECTION = 3 + _OTHER_H_SECTION = 4 + + _TYPE_NAMES = { + _C_SYS_HEADER: 'C system header', + _CPP_SYS_HEADER: 'C++ system header', + _LIKELY_MY_HEADER: 'header this file implements', + _POSSIBLE_MY_HEADER: 'header this file may implement', + _OTHER_HEADER: 'other header', + } + _SECTION_NAMES = { + _INITIAL_SECTION: "... nothing. (This can't be an error.)", + _MY_H_SECTION: 'a header this file implements', + _C_SECTION: 'C system header', + _CPP_SECTION: 'C++ system header', + _OTHER_H_SECTION: 'other header', + } + + def __init__(self): + self.include_list = [[]] + self.ResetSection('') + + def FindHeader(self, header): + """Check if a header has already been included. + + Args: + header: header to check. + Returns: + Line number of previous occurrence, or -1 if the header has not + been seen before. + """ + for section_list in self.include_list: + for f in section_list: + if f[0] == header: + return f[1] + return -1 + + def ResetSection(self, directive): + """Reset section checking for preprocessor directive. + + Args: + directive: preprocessor directive (e.g. "if", "else"). + """ + # The name of the current section. + self._section = self._INITIAL_SECTION + # The path of last found header. + self._last_header = '' + + # Update list of includes. Note that we never pop from the + # include list. + if directive in ('if', 'ifdef', 'ifndef'): + self.include_list.append([]) + elif directive in ('else', 'elif'): + self.include_list[-1] = [] + + def SetLastHeader(self, header_path): + self._last_header = header_path + + def CanonicalizeAlphabeticalOrder(self, header_path): + """Returns a path canonicalized for alphabetical comparison. + + - replaces "-" with "_" so they both cmp the same. + - removes '-inl' since we don't require them to be after the main header. + - lowercase everything, just in case. + + Args: + header_path: Path to be canonicalized. + + Returns: + Canonicalized path. + """ + return header_path.replace('-inl.h', '.h').replace('-', '_').lower() + + def IsInAlphabeticalOrder(self, clean_lines, linenum, header_path): + """Check if a header is in alphabetical order with the previous header. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + header_path: Canonicalized header to be checked. + + Returns: + Returns true if the header is in alphabetical order. + """ + # If previous section is different from current section, _last_header will + # be reset to empty string, so it's always less than current header. + # + # If previous line was a blank line, assume that the headers are + # intentionally sorted the way they are. + if (self._last_header > header_path and + Match(r'^\s*#\s*include\b', clean_lines.elided[linenum - 1])): + return False + return True + + def CheckNextIncludeOrder(self, header_type): + """Returns a non-empty error message if the next header is out of order. + + This function also updates the internal state to be ready to check + the next include. + + Args: + header_type: One of the _XXX_HEADER constants defined above. + + Returns: + The empty string if the header is in the right order, or an + error message describing what's wrong. 
+ + """ + error_message = ('Found %s after %s' % + (self._TYPE_NAMES[header_type], + self._SECTION_NAMES[self._section])) + + last_section = self._section + + if header_type == _C_SYS_HEADER: + if self._section <= self._C_SECTION: + self._section = self._C_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _CPP_SYS_HEADER: + if self._section <= self._CPP_SECTION: + self._section = self._CPP_SECTION + else: + self._last_header = '' + return error_message + elif header_type == _LIKELY_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + self._section = self._OTHER_H_SECTION + elif header_type == _POSSIBLE_MY_HEADER: + if self._section <= self._MY_H_SECTION: + self._section = self._MY_H_SECTION + else: + # This will always be the fallback because we're not sure + # enough that the header is associated with this file. + self._section = self._OTHER_H_SECTION + else: + assert header_type == _OTHER_HEADER + self._section = self._OTHER_H_SECTION + + if last_section != self._section: + self._last_header = '' + + return '' + + +class _CppLintState(object): + """Maintains module-wide state..""" + + def __init__(self): + self.verbose_level = 1 # global setting. + self.error_count = 0 # global count of reported errors + # filters to apply when emitting error messages + self.filters = _DEFAULT_FILTERS[:] + # backup of filter list. Used to restore the state after each file. + self._filters_backup = self.filters[:] + self.counting = 'total' # In what way are we counting errors? + self.errors_by_category = {} # string to int dict storing error counts + + # output format: + # "emacs" - format that emacs can parse (default) + # "vs7" - format that Microsoft Visual Studio 7 can parse + self.output_format = 'emacs' + + def SetOutputFormat(self, output_format): + """Sets the output format for errors.""" + self.output_format = output_format + + def SetVerboseLevel(self, level): + """Sets the module's verbosity, and returns the previous setting.""" + last_verbose_level = self.verbose_level + self.verbose_level = level + return last_verbose_level + + def SetCountingStyle(self, counting_style): + """Sets the module's counting options.""" + self.counting = counting_style + + def SetFilters(self, filters): + """Sets the error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "+whitespace/indent"). + Each filter should start with + or -; else we die. + + Raises: + ValueError: The comma-separated filters did not all start with '+' or '-'. + E.g. "-,+whitespace,-whitespace/indent,whitespace/badfilter" + """ + # Default filters always have less priority than the flag ones. + self.filters = _DEFAULT_FILTERS[:] + self.AddFilters(filters) + + def AddFilters(self, filters): + """ Adds more filters to the existing list of error-message filters. 
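
The section constants and CheckNextIncludeOrder above encode the expected ordering for build/include_order: the header a .cc file implements, then C system headers, then C++ system headers, then everything else, with each section kept alphabetized. An ordering this state machine accepts, using a hypothetical caffe2/utils/example.cc (only example.h is invented; the other headers are standard or appear elsewhere in this commit):

    // caffe2/utils/example.cc (hypothetical), showing an include order that
    // CheckNextIncludeOrder() accepts.
    #include "caffe2/utils/example.h"  // 1. the header this file implements

    #include <fcntl.h>                 // 2. C system headers, alphabetized
    #include <unistd.h>

    #include <string>                  // 3. C++ system headers, alphabetized
    #include <vector>

    #include "caffe2/core/blob.h"      // 4. other project and library headers
    #include "glog/logging.h"
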
""" + for filt in filters.split(','): + clean_filt = filt.strip() + if clean_filt: + self.filters.append(clean_filt) + for filt in self.filters: + if not (filt.startswith('+') or filt.startswith('-')): + raise ValueError('Every filter in --filters must start with + or -' + ' (%s does not)' % filt) + + def BackupFilters(self): + """ Saves the current filter list to backup storage.""" + self._filters_backup = self.filters[:] + + def RestoreFilters(self): + """ Restores filters previously backed up.""" + self.filters = self._filters_backup[:] + + def ResetErrorCounts(self): + """Sets the module's error statistic back to zero.""" + self.error_count = 0 + self.errors_by_category = {} + + def IncrementErrorCount(self, category): + """Bumps the module's error statistic.""" + self.error_count += 1 + if self.counting in ('toplevel', 'detailed'): + if self.counting != 'detailed': + category = category.split('/')[0] + if category not in self.errors_by_category: + self.errors_by_category[category] = 0 + self.errors_by_category[category] += 1 + + def PrintErrorCounts(self): + """Print a summary of errors by category, and the total.""" + for category, count in self.errors_by_category.iteritems(): + sys.stderr.write('Category \'%s\' errors found: %d\n' % + (category, count)) + sys.stderr.write('Total errors found: %d\n' % self.error_count) + +_cpplint_state = _CppLintState() + + +def _OutputFormat(): + """Gets the module's output format.""" + return _cpplint_state.output_format + + +def _SetOutputFormat(output_format): + """Sets the module's output format.""" + _cpplint_state.SetOutputFormat(output_format) + + +def _VerboseLevel(): + """Returns the module's verbosity setting.""" + return _cpplint_state.verbose_level + + +def _SetVerboseLevel(level): + """Sets the module's verbosity, and returns the previous setting.""" + return _cpplint_state.SetVerboseLevel(level) + + +def _SetCountingStyle(level): + """Sets the module's counting options.""" + _cpplint_state.SetCountingStyle(level) + + +def _Filters(): + """Returns the module's list of output filters, as a list.""" + return _cpplint_state.filters + + +def _SetFilters(filters): + """Sets the module's error-message filters. + + These filters are applied when deciding whether to emit a given + error message. + + Args: + filters: A string of comma-separated filters (eg "whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.SetFilters(filters) + +def _AddFilters(filters): + """Adds more filter overrides. + + Unlike _SetFilters, this function does not reset the current list of filters + available. + + Args: + filters: A string of comma-separated filters (eg "whitespace/indent"). + Each filter should start with + or -; else we die. + """ + _cpplint_state.AddFilters(filters) + +def _BackupFilters(): + """ Saves the current filter list to backup storage.""" + _cpplint_state.BackupFilters() + +def _RestoreFilters(): + """ Restores filters previously backed up.""" + _cpplint_state.RestoreFilters() + +class _FunctionState(object): + """Tracks current function name and the number of lines in its body.""" + + _NORMAL_TRIGGER = 250 # for --v=0, 500 for --v=1, etc. + _TEST_TRIGGER = 400 # about 50% more than _NORMAL_TRIGGER. + + def __init__(self): + self.in_a_function = False + self.lines_in_function = 0 + self.current_function = '' + + def Begin(self, function_name): + """Start analyzing function body. + + Args: + function_name: The name of the function being tracked. 
+ """ + self.in_a_function = True + self.lines_in_function = 0 + self.current_function = function_name + + def Count(self): + """Count line in current function body.""" + if self.in_a_function: + self.lines_in_function += 1 + + def Check(self, error, filename, linenum): + """Report if too many lines in function body. + + Args: + error: The function to call with any errors found. + filename: The name of the current file. + linenum: The number of the line to check. + """ + if Match(r'T(EST|est)', self.current_function): + base_trigger = self._TEST_TRIGGER + else: + base_trigger = self._NORMAL_TRIGGER + trigger = base_trigger * 2**_VerboseLevel() + + if self.lines_in_function > trigger: + error_level = int(math.log(self.lines_in_function / base_trigger, 2)) + # 50 => 0, 100 => 1, 200 => 2, 400 => 3, 800 => 4, 1600 => 5, ... + if error_level > 5: + error_level = 5 + error(filename, linenum, 'readability/fn_size', error_level, + 'Small and focused functions are preferred:' + ' %s has %d non-comment lines' + ' (error triggered by exceeding %d lines).' % ( + self.current_function, self.lines_in_function, trigger)) + + def End(self): + """Stop analyzing function body.""" + self.in_a_function = False + + +class _IncludeError(Exception): + """Indicates a problem with the include order in a file.""" + pass + + +class FileInfo(object): + """Provides utility functions for filenames. + + FileInfo provides easy access to the components of a file's path + relative to the project root. + """ + + def __init__(self, filename): + self._filename = filename + + def FullName(self): + """Make Windows paths like Unix.""" + return os.path.abspath(self._filename).replace('\\', '/') + + def RepositoryName(self): + """FullName after removing the local path to the repository. + + If we have a real absolute path name here we can try to do something smart: + detecting the root of the checkout and truncating /path/to/checkout from + the name so that we get header guards that don't include things like + "C:\Documents and Settings\..." or "/home/username/..." in them and thus + people on different computers who have checked the source out to different + locations won't see bogus errors. + """ + fullname = self.FullName() + + if os.path.exists(fullname): + project_dir = os.path.dirname(fullname) + + if os.path.exists(os.path.join(project_dir, ".svn")): + # If there's a .svn file in the current directory, we recursively look + # up the directory tree for the top of the SVN checkout + root_dir = project_dir + one_up_dir = os.path.dirname(root_dir) + while os.path.exists(os.path.join(one_up_dir, ".svn")): + root_dir = os.path.dirname(root_dir) + one_up_dir = os.path.dirname(one_up_dir) + + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Not SVN <= 1.6? Try to find a git, hg, or svn top level directory by + # searching up from the current path. + root_dir = os.path.dirname(fullname) + while (root_dir != os.path.dirname(root_dir) and + not os.path.exists(os.path.join(root_dir, ".git")) and + not os.path.exists(os.path.join(root_dir, ".hg")) and + not os.path.exists(os.path.join(root_dir, ".svn"))): + root_dir = os.path.dirname(root_dir) + + if (os.path.exists(os.path.join(root_dir, ".git")) or + os.path.exists(os.path.join(root_dir, ".hg")) or + os.path.exists(os.path.join(root_dir, ".svn"))): + prefix = os.path.commonprefix([root_dir, project_dir]) + return fullname[len(prefix) + 1:] + + # Don't know what to do; header guard warnings may be wrong... 
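
The readability/fn_size check above scales with verbosity: the trigger is 250 non-comment lines for ordinary functions and 400 for tests, doubled for every --v level, and the reported error level is the base-2 logarithm of lines over the base trigger, capped at 5. A standalone calculation that mirrors that arithmetic (not cpplint itself; the integer division matches the Python 2 behaviour of the original):

    #include <algorithm>
    #include <cmath>
    #include <cstdio>

    // Mirrors the arithmetic in _FunctionState.Check() above, for illustration.
    int FnSizeErrorLevel(int lines_in_function, bool is_test, int verbose_level) {
      const int base_trigger = is_test ? 400 : 250;
      const int trigger = base_trigger << verbose_level;  // base_trigger * 2**v
      if (lines_in_function <= trigger) return -1;        // no error reported
      const int ratio = lines_in_function / base_trigger;
      const int level = static_cast<int>(std::log2(static_cast<double>(ratio)));
      return std::min(level, 5);
    }

    int main() {
      std::printf("%d\n", FnSizeErrorLevel(600, false, 0));  // trigger 250 -> level 1
      std::printf("%d\n", FnSizeErrorLevel(600, false, 1));  // trigger 500 -> level 1
      std::printf("%d\n", FnSizeErrorLevel(600, false, 2));  // trigger 1000 -> no error (-1)
      return 0;
    }
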
+ return fullname + + def Split(self): + """Splits the file into the directory, basename, and extension. + + For 'chrome/browser/browser.cc', Split() would + return ('chrome/browser', 'browser', '.cc') + + Returns: + A tuple of (directory, basename, extension). + """ + + googlename = self.RepositoryName() + project, rest = os.path.split(googlename) + return (project,) + os.path.splitext(rest) + + def BaseName(self): + """File base name - text after the final slash, before the final period.""" + return self.Split()[1] + + def Extension(self): + """File extension - text following the final period.""" + return self.Split()[2] + + def NoExtension(self): + """File has no source file extension.""" + return '/'.join(self.Split()[0:2]) + + def IsSource(self): + """File has a source file extension.""" + return self.Extension()[1:] in ('c', 'cc', 'cpp', 'cxx') + + +def _ShouldPrintError(category, confidence, linenum): + """If confidence >= verbose, category passes filter and is not suppressed.""" + + # There are three ways we might decide not to print an error message: + # a "NOLINT(category)" comment appears in the source, + # the verbosity level isn't high enough, or the filters filter it out. + if IsErrorSuppressedByNolint(category, linenum): + return False + + if confidence < _cpplint_state.verbose_level: + return False + + is_filtered = False + for one_filter in _Filters(): + if one_filter.startswith('-'): + if category.startswith(one_filter[1:]): + is_filtered = True + elif one_filter.startswith('+'): + if category.startswith(one_filter[1:]): + is_filtered = False + else: + assert False # should have been checked for in SetFilter. + if is_filtered: + return False + + return True + + +def Error(filename, linenum, category, confidence, message): + """Logs the fact we've found a lint error. + + We log where the error was found, and also our confidence in the error, + that is, how certain we are this is a legitimate style regression, and + not a misidentification or a use that's sometimes justified. + + False positives can be suppressed by the use of + "cpplint(category)" comments on the offending line. These are + parsed into _error_suppressions. + + Args: + filename: The name of the file containing the error. + linenum: The number of the line containing the error. + category: A string used to describe the "category" this bug + falls under: "whitespace", say, or "runtime". Categories + may have a hierarchy separated by slashes: "whitespace/indent". + confidence: A number from 1-5 representing a confidence score for + the error, with 5 meaning that we are certain of the problem, + and 1 meaning that it could be a legitimate construct. + message: The error message. + """ + if _ShouldPrintError(category, confidence, linenum): + _cpplint_state.IncrementErrorCount(category) + if _cpplint_state.output_format == 'vs7': + sys.stderr.write('%s(%s): %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + elif _cpplint_state.output_format == 'eclipse': + sys.stderr.write('%s:%s: warning: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + else: + sys.stderr.write('%s:%s: %s [%s] [%d]\n' % ( + filename, linenum, message, category, confidence)) + + +# Matches standard C++ escape sequences per 2.13.2.3 of the C++ standard. +_RE_PATTERN_CLEANSE_LINE_ESCAPES = re.compile( + r'\\([abfnrtv?"\\\']|\d+|x[0-9a-fA-F]+)') +# Match a single C style comment on the same line. +_RE_PATTERN_C_COMMENTS = r'/\*(?:[^*]|\*(?!/))*\*/' +# Matches multi-line C style comments. 
+# This RE is a little bit more complicated than one might expect, because we +# have to take care of space removals tools so we can handle comments inside +# statements better. +# The current rule is: We only clear spaces from both sides when we're at the +# end of the line. Otherwise, we try to remove spaces from the right side, +# if this doesn't work we try on left side but only if there's a non-character +# on the right. +_RE_PATTERN_CLEANSE_LINE_C_COMMENTS = re.compile( + r'(\s*' + _RE_PATTERN_C_COMMENTS + r'\s*$|' + + _RE_PATTERN_C_COMMENTS + r'\s+|' + + r'\s+' + _RE_PATTERN_C_COMMENTS + r'(?=\W)|' + + _RE_PATTERN_C_COMMENTS + r')') + + +def IsCppString(line): + """Does line terminate so, that the next symbol is in string constant. + + This function does not consider single-line nor multi-line comments. + + Args: + line: is a partial line of code starting from the 0..n. + + Returns: + True, if next character appended to 'line' is inside a + string constant. + """ + + line = line.replace(r'\\', 'XX') # after this, \\" does not match to \" + return ((line.count('"') - line.count(r'\"') - line.count("'\"'")) & 1) == 1 + + +def CleanseRawStrings(raw_lines): + """Removes C++11 raw strings from lines. + + Before: + static const char kData[] = R"( + multi-line string + )"; + + After: + static const char kData[] = "" + (replaced by blank line) + ""; + + Args: + raw_lines: list of raw lines. + + Returns: + list of lines with C++11 raw strings replaced by empty strings. + """ + + delimiter = None + lines_without_raw_strings = [] + for line in raw_lines: + if delimiter: + # Inside a raw string, look for the end + end = line.find(delimiter) + if end >= 0: + # Found the end of the string, match leading space for this + # line and resume copying the original lines, and also insert + # a "" on the last line. + leading_space = Match(r'^(\s*)\S', line) + line = leading_space.group(1) + '""' + line[end + len(delimiter):] + delimiter = None + else: + # Haven't found the end yet, append a blank line. + line = '""' + + # Look for beginning of a raw string, and replace them with + # empty strings. This is done in a loop to handle multiple raw + # strings on the same line. + while delimiter is None: + # Look for beginning of a raw string. + # See 2.14.15 [lex.string] for syntax. + matched = Match(r'^(.*)\b(?:R|u8R|uR|UR|LR)"([^\s\\()]*)\((.*)$', line) + if matched: + delimiter = ')' + matched.group(2) + '"' + + end = matched.group(3).find(delimiter) + if end >= 0: + # Raw string ended on same line + line = (matched.group(1) + '""' + + matched.group(3)[end + len(delimiter):]) + delimiter = None + else: + # Start of a multi-line raw string + line = matched.group(1) + '""' + else: + break + + lines_without_raw_strings.append(line) + + # TODO(unknown): if delimiter is not None here, we might want to + # emit a warning for unterminated string. 
+ return lines_without_raw_strings + + +def FindNextMultiLineCommentStart(lines, lineix): + """Find the beginning marker for a multiline comment.""" + while lineix < len(lines): + if lines[lineix].strip().startswith('/*'): + # Only return this marker if the comment goes beyond this line + if lines[lineix].strip().find('*/', 2) < 0: + return lineix + lineix += 1 + return len(lines) + + +def FindNextMultiLineCommentEnd(lines, lineix): + """We are inside a comment, find the end marker.""" + while lineix < len(lines): + if lines[lineix].strip().endswith('*/'): + return lineix + lineix += 1 + return len(lines) + + +def RemoveMultiLineCommentsFromRange(lines, begin, end): + """Clears a range of lines for multi-line comments.""" + # Having // dummy comments makes the lines non-empty, so we will not get + # unnecessary blank line warnings later in the code. + for i in range(begin, end): + lines[i] = '/**/' + + +def RemoveMultiLineComments(filename, lines, error): + """Removes multiline (c-style) comments from lines.""" + lineix = 0 + while lineix < len(lines): + lineix_begin = FindNextMultiLineCommentStart(lines, lineix) + if lineix_begin >= len(lines): + return + lineix_end = FindNextMultiLineCommentEnd(lines, lineix_begin) + if lineix_end >= len(lines): + error(filename, lineix_begin + 1, 'readability/multiline_comment', 5, + 'Could not find end of multi-line comment') + return + RemoveMultiLineCommentsFromRange(lines, lineix_begin, lineix_end + 1) + lineix = lineix_end + 1 + + +def CleanseComments(line): + """Removes //-comments and single-line C-style /* */ comments. + + Args: + line: A line of C++ source. + + Returns: + The line with single-line comments removed. + """ + commentpos = line.find('//') + if commentpos != -1 and not IsCppString(line[:commentpos]): + line = line[:commentpos].rstrip() + # get rid of /* ... */ + return _RE_PATTERN_CLEANSE_LINE_C_COMMENTS.sub('', line) + + +class CleansedLines(object): + """Holds 4 copies of all lines with different preprocessing applied to them. + + 1) elided member contains lines without strings and comments. + 2) lines member contains lines without comments. + 3) raw_lines member contains all the lines without processing. + 4) lines_without_raw_strings member is same as raw_lines, but with C++11 raw + strings removed. + All these members are of , and of the same length. + """ + + def __init__(self, lines): + self.elided = [] + self.lines = [] + self.raw_lines = lines + self.num_lines = len(lines) + self.lines_without_raw_strings = CleanseRawStrings(lines) + for linenum in range(len(self.lines_without_raw_strings)): + self.lines.append(CleanseComments( + self.lines_without_raw_strings[linenum])) + elided = self._CollapseStrings(self.lines_without_raw_strings[linenum]) + self.elided.append(CleanseComments(elided)) + + def NumLines(self): + """Returns the number of lines represented.""" + return self.num_lines + + @staticmethod + def _CollapseStrings(elided): + """Collapses strings and chars on a line to simple "" or '' blocks. + + We nix strings first so we're not fooled by text like '"http://"' + + Args: + elided: The line being processed. + + Returns: + The line with collapsed strings. + """ + if _RE_PATTERN_INCLUDE.match(elided): + return elided + + # Remove escaped characters first to make quote/single quote collapsing + # basic. Things that look like escaped characters shouldn't occur + # outside of strings and chars. + elided = _RE_PATTERN_CLEANSE_LINE_ESCAPES.sub('', elided) + + # Replace quoted strings and digit separators. 
Both single quotes + # and double quotes are processed in the same loop, otherwise + # nested quotes wouldn't work. + collapsed = '' + while True: + # Find the first quote character + match = Match(r'^([^\'"]*)([\'"])(.*)$', elided) + if not match: + collapsed += elided + break + head, quote, tail = match.groups() + + if quote == '"': + # Collapse double quoted strings + second_quote = tail.find('"') + if second_quote >= 0: + collapsed += head + '""' + elided = tail[second_quote + 1:] + else: + # Unmatched double quote, don't bother processing the rest + # of the line since this is probably a multiline string. + collapsed += elided + break + else: + # Found single quote, check nearby text to eliminate digit separators. + # + # There is no special handling for floating point here, because + # the integer/fractional/exponent parts would all be parsed + # correctly as long as there are digits on both sides of the + # separator. So we are fine as long as we don't see something + # like "0.'3" (gcc 4.9.0 will not allow this literal). + if Search(r'\b(?:0[bBxX]?|[1-9])[0-9a-fA-F]*$', head): + match_literal = Match(r'^((?:\'?[0-9a-zA-Z_])*)(.*)$', "'" + tail) + collapsed += head + match_literal.group(1).replace("'", '') + elided = match_literal.group(2) + else: + second_quote = tail.find('\'') + if second_quote >= 0: + collapsed += head + "''" + elided = tail[second_quote + 1:] + else: + # Unmatched single quote + collapsed += elided + break + + return collapsed + + +def FindEndOfExpressionInLine(line, startpos, stack): + """Find the position just after the end of current parenthesized expression. + + Args: + line: a CleansedLines line. + startpos: start searching at this position. + stack: nesting stack at startpos. + + Returns: + On finding matching end: (index just after matching end, None) + On finding an unclosed expression: (-1, None) + Otherwise: (-1, new stack at end of this line) + """ + for i in xrange(startpos, len(line)): + char = line[i] + if char in '([{': + # Found start of parenthesized expression, push to expression stack + stack.append(char) + elif char == '<': + # Found potential start of template argument list + if i > 0 and line[i - 1] == '<': + # Left shift operator + if stack and stack[-1] == '<': + stack.pop() + if not stack: + return (-1, None) + elif i > 0 and Search(r'\boperator\s*$', line[0:i]): + # operator<, don't add to stack + continue + else: + # Tentative start of template argument list + stack.append('<') + elif char in ')]}': + # Found end of parenthesized expression. + # + # If we are currently expecting a matching '>', the pending '<' + # must have been an operator. Remove them from expression stack. + while stack and stack[-1] == '<': + stack.pop() + if not stack: + return (-1, None) + if ((stack[-1] == '(' and char == ')') or + (stack[-1] == '[' and char == ']') or + (stack[-1] == '{' and char == '}')): + stack.pop() + if not stack: + return (i + 1, None) + else: + # Mismatched parentheses + return (-1, None) + elif char == '>': + # Found potential end of template argument list. + + # Ignore "->" and operator functions + if (i > 0 and + (line[i - 1] == '-' or Search(r'\boperator\s*$', line[0:i - 1]))): + continue + + # Pop the stack if there is a matching '<'. Otherwise, ignore + # this '>' since it must be an operator. + if stack: + if stack[-1] == '<': + stack.pop() + if not stack: + return (i + 1, None) + elif char == ';': + # Found something that look like end of statements. 
If we are currently + # expecting a '>', the matching '<' must have been an operator, since + # template argument list should not contain statements. + while stack and stack[-1] == '<': + stack.pop() + if not stack: + return (-1, None) + + # Did not find end of expression or unbalanced parentheses on this line + return (-1, stack) + + +def CloseExpression(clean_lines, linenum, pos): + """If input points to ( or { or [ or <, finds the position that closes it. + + If lines[linenum][pos] points to a '(' or '{' or '[' or '<', finds the + linenum/pos that correspond to the closing of the expression. + + TODO(unknown): cpplint spends a fair bit of time matching parentheses. + Ideally we would want to index all opening and closing parentheses once + and have CloseExpression be just a simple lookup, but due to preprocessor + tricks, this is not so easy. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: A position on the line. + + Returns: + A tuple (line, linenum, pos) pointer *past* the closing brace, or + (line, len(lines), -1) if we never find a close. Note we ignore + strings and comments when matching; and the line we return is the + 'cleansed' line at linenum. + """ + + line = clean_lines.elided[linenum] + if (line[pos] not in '({[<') or Match(r'<[<=]', line[pos:]): + return (line, clean_lines.NumLines(), -1) + + # Check first line + (end_pos, stack) = FindEndOfExpressionInLine(line, pos, []) + if end_pos > -1: + return (line, linenum, end_pos) + + # Continue scanning forward + while stack and linenum < clean_lines.NumLines() - 1: + linenum += 1 + line = clean_lines.elided[linenum] + (end_pos, stack) = FindEndOfExpressionInLine(line, 0, stack) + if end_pos > -1: + return (line, linenum, end_pos) + + # Did not find end of expression before end of file, give up + return (line, clean_lines.NumLines(), -1) + + +def FindStartOfExpressionInLine(line, endpos, stack): + """Find position at the matching start of current expression. + + This is almost the reverse of FindEndOfExpressionInLine, but note + that the input position and returned position differs by 1. + + Args: + line: a CleansedLines line. + endpos: start searching at this position. + stack: nesting stack at endpos. + + Returns: + On finding matching start: (index at matching start, None) + On finding an unclosed expression: (-1, None) + Otherwise: (-1, new stack at beginning of this line) + """ + i = endpos + while i >= 0: + char = line[i] + if char in ')]}': + # Found end of expression, push to expression stack + stack.append(char) + elif char == '>': + # Found potential end of template argument list. + # + # Ignore it if it's a "->" or ">=" or "operator>" + if (i > 0 and + (line[i - 1] == '-' or + Match(r'\s>=\s', line[i - 1:]) or + Search(r'\boperator\s*$', line[0:i]))): + i -= 1 + else: + stack.append('>') + elif char == '<': + # Found potential start of template argument list + if i > 0 and line[i - 1] == '<': + # Left shift operator + i -= 1 + else: + # If there is a matching '>', we can pop the expression stack. + # Otherwise, ignore this '<' since it must be an operator. + if stack and stack[-1] == '>': + stack.pop() + if not stack: + return (i, None) + elif char in '([{': + # Found start of expression. + # + # If there are any unmatched '>' on the stack, they must be + # operators. Remove those. 
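For reference, a small sketch of CloseExpression on a call spanning two lines, again assuming the file imports as the module cpplint; the two C++ lines are made up. The returned position points just past the matching ')':

import cpplint

cl = cpplint.CleansedLines(['Foo(bar,', '    baz);'])
line, linenum, pos = cpplint.CloseExpression(cl, 0, 3)  # the '(' sits at column 3
print((linenum, pos))  # (1, 8): just past the ')' on the second line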
+ while stack and stack[-1] == '>': + stack.pop() + if not stack: + return (-1, None) + if ((char == '(' and stack[-1] == ')') or + (char == '[' and stack[-1] == ']') or + (char == '{' and stack[-1] == '}')): + stack.pop() + if not stack: + return (i, None) + else: + # Mismatched parentheses + return (-1, None) + elif char == ';': + # Found something that look like end of statements. If we are currently + # expecting a '<', the matching '>' must have been an operator, since + # template argument list should not contain statements. + while stack and stack[-1] == '>': + stack.pop() + if not stack: + return (-1, None) + + i -= 1 + + return (-1, stack) + + +def ReverseCloseExpression(clean_lines, linenum, pos): + """If input points to ) or } or ] or >, finds the position that opens it. + + If lines[linenum][pos] points to a ')' or '}' or ']' or '>', finds the + linenum/pos that correspond to the opening of the expression. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: A position on the line. + + Returns: + A tuple (line, linenum, pos) pointer *at* the opening brace, or + (line, 0, -1) if we never find the matching opening brace. Note + we ignore strings and comments when matching; and the line we + return is the 'cleansed' line at linenum. + """ + line = clean_lines.elided[linenum] + if line[pos] not in ')}]>': + return (line, 0, -1) + + # Check last line + (start_pos, stack) = FindStartOfExpressionInLine(line, pos, []) + if start_pos > -1: + return (line, linenum, start_pos) + + # Continue scanning backward + while stack and linenum > 0: + linenum -= 1 + line = clean_lines.elided[linenum] + (start_pos, stack) = FindStartOfExpressionInLine(line, len(line) - 1, stack) + if start_pos > -1: + return (line, linenum, start_pos) + + # Did not find start of expression before beginning of file, give up + return (line, 0, -1) + + +def GetIndentLevel(line): + """Return the number of leading spaces in line. + + Args: + line: A string to check. + + Returns: + An integer count of leading spaces, possibly zero. + """ + indent = Match(r'^( *)\S', line) + if indent: + return len(indent.group(1)) + else: + return 0 + + +def GetHeaderGuardCPPVariable(filename): + """Returns the CPP variable that should be used as a header guard. + + Args: + filename: The name of a C++ header file. + + Returns: + The CPP variable that should be used as a header guard in the + named file. + + """ + + # Restores original filename in case that cpplint is invoked from Emacs's + # flymake. + filename = re.sub(r'_flymake\.h$', '.h', filename) + filename = re.sub(r'/\.flymake/([^/]*)$', r'/\1', filename) + # Replace 'c++' with 'cpp'. + filename = filename.replace('C++', 'cpp').replace('c++', 'cpp') + + fileinfo = FileInfo(filename) + file_path_from_root = fileinfo.RepositoryName() + if _root: + file_path_from_root = re.sub('^' + _root + os.sep, '', file_path_from_root) + return re.sub(r'[^a-zA-Z0-9]', '_', file_path_from_root).upper() + '_' + + +def CheckForHeaderGuard(filename, clean_lines, error): + """Checks that the file contains a header guard. + + Logs an error if no #ifndef header guard is present. For other + headers, checks that the full pathname is used. + + Args: + filename: The name of the C++ header file. + clean_lines: A CleansedLines instance containing the file. + error: The function to call with any errors found. + """ + + # Don't check for header guards if there are error suppression + # comments somewhere in this file. 
+ # + # Because this is silencing a warning for a nonexistent line, we + # only support the very specific NOLINT(build/header_guard) syntax, + # and not the general NOLINT or NOLINT(*) syntax. + raw_lines = clean_lines.lines_without_raw_strings + for i in raw_lines: + if Search(r'//\s*NOLINT\(build/header_guard\)', i): + return + + cppvar = GetHeaderGuardCPPVariable(filename) + + ifndef = '' + ifndef_linenum = 0 + define = '' + endif = '' + endif_linenum = 0 + for linenum, line in enumerate(raw_lines): + linesplit = line.split() + if len(linesplit) >= 2: + # find the first occurrence of #ifndef and #define, save arg + if not ifndef and linesplit[0] == '#ifndef': + # set ifndef to the header guard presented on the #ifndef line. + ifndef = linesplit[1] + ifndef_linenum = linenum + if not define and linesplit[0] == '#define': + define = linesplit[1] + # find the last occurrence of #endif, save entire line + if line.startswith('#endif'): + endif = line + endif_linenum = linenum + + if not ifndef or not define or ifndef != define: + error(filename, 0, 'build/header_guard', 5, + 'No #ifndef header guard found, suggested CPP variable is: %s' % + cppvar) + return + + # The guard should be PATH_FILE_H_, but we also allow PATH_FILE_H__ + # for backward compatibility. + if ifndef != cppvar: + error_level = 0 + if ifndef != cppvar + '_': + error_level = 5 + + ParseNolintSuppressions(filename, raw_lines[ifndef_linenum], ifndef_linenum, + error) + error(filename, ifndef_linenum, 'build/header_guard', error_level, + '#ifndef header guard has wrong style, please use: %s' % cppvar) + + # Check for "//" comments on endif line. + ParseNolintSuppressions(filename, raw_lines[endif_linenum], endif_linenum, + error) + match = Match(r'#endif\s*//\s*' + cppvar + r'(_)?\b', endif) + if match: + if match.group(1) == '_': + # Issue low severity warning for deprecated double trailing underscore + error(filename, endif_linenum, 'build/header_guard', 0, + '#endif line should be "#endif // %s"' % cppvar) + return + + # Didn't find the corresponding "//" comment. If this file does not + # contain any "//" comments at all, it could be that the compiler + # only wants "/**/" comments, look for those instead. 
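To see the guard check end to end, a hedged sketch, assuming the file imports as the module cpplint; the header path and its contents are made up, and the suggested guard name depends on how FileInfo.RepositoryName() resolves the path on the machine running the check:

import cpplint

issues = []
def error(filename, linenum, category, confidence, message):
    issues.append(message)

cl = cpplint.CleansedLines(['#include <vector>', 'void Foo();'])
cpplint.CheckForHeaderGuard('caffe2/core/foo.h', cl, error)
print(issues)
# Expect a single 'No #ifndef header guard found, suggested CPP variable is: ...'
# message; the suggested name comes from GetHeaderGuardCPPVariable above
# (path relative to the detected repository root, non-alphanumerics mapped to
# '_', uppercased, plus a trailing '_').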
+ no_single_line_comments = True + for i in xrange(1, len(raw_lines) - 1): + line = raw_lines[i] + if Match(r'^(?:(?:\'(?:\.|[^\'])*\')|(?:"(?:\.|[^"])*")|[^\'"])*//', line): + no_single_line_comments = False + break + + if no_single_line_comments: + match = Match(r'#endif\s*/\*\s*' + cppvar + r'(_)?\s*\*/', endif) + if match: + if match.group(1) == '_': + # Low severity warning for double trailing underscore + error(filename, endif_linenum, 'build/header_guard', 0, + '#endif line should be "#endif /* %s */"' % cppvar) + return + + # Didn't find anything + error(filename, endif_linenum, 'build/header_guard', 5, + '#endif line should be "#endif // %s"' % cppvar) + + +def CheckHeaderFileIncluded(filename, include_state, error): + """Logs an error if a .cc file does not include its header.""" + + # Do not check test files + if filename.endswith('_test.cc') or filename.endswith('_unittest.cc'): + return + + fileinfo = FileInfo(filename) + headerfile = filename[0:len(filename) - 2] + 'h' + if not os.path.exists(headerfile): + return + headername = FileInfo(headerfile).RepositoryName() + first_include = 0 + for section_list in include_state.include_list: + for f in section_list: + if headername in f[0] or f[0] in headername: + return + if not first_include: + first_include = f[1] + + error(filename, first_include, 'build/include', 5, + '%s should include its header file %s' % (fileinfo.RepositoryName(), + headername)) + + +def CheckForBadCharacters(filename, lines, error): + """Logs an error for each line containing bad characters. + + Two kinds of bad characters: + + 1. Unicode replacement characters: These indicate that either the file + contained invalid UTF-8 (likely) or Unicode replacement characters (which + it shouldn't). Note that it's possible for this to throw off line + numbering if the invalid UTF-8 occurred adjacent to a newline. + + 2. NUL bytes. These are problematic for some tools. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + for linenum, line in enumerate(lines): + if u'\ufffd' in line: + error(filename, linenum, 'readability/utf8', 5, + 'Line contains invalid UTF-8 (or Unicode replacement character).') + if '\0' in line: + error(filename, linenum, 'readability/nul', 5, 'Line contains NUL byte.') + + +def CheckForNewlineAtEOF(filename, lines, error): + """Logs an error if there is no newline char at the end of the file. + + Args: + filename: The name of the current file. + lines: An array of strings, each representing a line of the file. + error: The function to call with any errors found. + """ + + # The array lines() was created by adding two newlines to the + # original file (go figure), then splitting on \n. + # To verify that the file ends in \n, we just have to make sure the + # last-but-two element of lines() exists and is empty. + if len(lines) < 3 or lines[-2]: + error(filename, len(lines) - 2, 'whitespace/ending_newline', 5, + 'Could not find a newline character at the end of the file.') + + +def CheckForMultilineCommentsAndStrings(filename, clean_lines, linenum, error): + """Logs an error if we see /* ... */ or "..." that extend past one line. + + /* ... */ comments are legit inside macros, for one line. + Otherwise, we prefer // comments, so it's ok to warn about the + other. Likewise, it's ok for strings to extend across multiple + lines, as long as a line continuation character (backslash) + terminates each line. 
Although not currently prohibited by the C++ + style guide, it's ugly and unnecessary. We don't do well with either + in this lint program, so we warn about both. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Remove all \\ (escaped backslashes) from the line. They are OK, and the + # second (escaped) slash may trigger later \" detection erroneously. + line = line.replace('\\\\', '') + + if line.count('/*') > line.count('*/'): + error(filename, linenum, 'readability/multiline_comment', 5, + 'Complex multi-line /*...*/-style comment found. ' + 'Lint may give bogus warnings. ' + 'Consider replacing these with //-style comments, ' + 'with #if 0...#endif, ' + 'or with more clearly structured multi-line comments.') + + if (line.count('"') - line.count('\\"')) % 2: + error(filename, linenum, 'readability/multiline_string', 5, + 'Multi-line string ("...") found. This lint script doesn\'t ' + 'do well with such strings, and may give bogus warnings. ' + 'Use C++11 raw strings or concatenation instead.') + + +# (non-threadsafe name, thread-safe alternative, validation pattern) +# +# The validation pattern is used to eliminate false positives such as: +# _rand(); // false positive due to substring match. +# ->rand(); // some member function rand(). +# ACMRandom rand(seed); // some variable named rand. +# ISAACRandom rand(); // another variable named rand. +# +# Basically we require the return value of these functions to be used +# in some expression context on the same line by matching on some +# operator before the function name. This eliminates constructors and +# member function calls. +_UNSAFE_FUNC_PREFIX = r'(?:[-+*/=%^&|(<]\s*|>\s+)' +_THREADING_LIST = ( + ('asctime(', 'asctime_r(', _UNSAFE_FUNC_PREFIX + r'asctime\([^)]+\)'), + ('ctime(', 'ctime_r(', _UNSAFE_FUNC_PREFIX + r'ctime\([^)]+\)'), + ('getgrgid(', 'getgrgid_r(', _UNSAFE_FUNC_PREFIX + r'getgrgid\([^)]+\)'), + ('getgrnam(', 'getgrnam_r(', _UNSAFE_FUNC_PREFIX + r'getgrnam\([^)]+\)'), + ('getlogin(', 'getlogin_r(', _UNSAFE_FUNC_PREFIX + r'getlogin\(\)'), + ('getpwnam(', 'getpwnam_r(', _UNSAFE_FUNC_PREFIX + r'getpwnam\([^)]+\)'), + ('getpwuid(', 'getpwuid_r(', _UNSAFE_FUNC_PREFIX + r'getpwuid\([^)]+\)'), + ('gmtime(', 'gmtime_r(', _UNSAFE_FUNC_PREFIX + r'gmtime\([^)]+\)'), + ('localtime(', 'localtime_r(', _UNSAFE_FUNC_PREFIX + r'localtime\([^)]+\)'), + ('rand(', 'rand_r(', _UNSAFE_FUNC_PREFIX + r'rand\(\)'), + ('strtok(', 'strtok_r(', + _UNSAFE_FUNC_PREFIX + r'strtok\([^)]+\)'), + ('ttyname(', 'ttyname_r(', _UNSAFE_FUNC_PREFIX + r'ttyname\([^)]+\)'), + ) + + +def CheckPosixThreading(filename, clean_lines, linenum, error): + """Checks for calls to thread-unsafe functions. + + Much code has been originally written without consideration of + multi-threading. Also, engineers are relying on their old experience; + they have learned posix before threading extensions were added. These + tests guide the engineers to use thread-safe functions (when using + posix directly). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + for single_thread_func, multithread_safe_func, pattern in _THREADING_LIST: + # Additional pattern matching check to confirm that this is the + # function we are looking for + if Search(pattern, line): + error(filename, linenum, 'runtime/threadsafe_fn', 2, + 'Consider using ' + multithread_safe_func + + '...) instead of ' + single_thread_func + + '...) for improved thread safety.') + + +def CheckVlogArguments(filename, clean_lines, linenum, error): + """Checks that VLOG() is only used for defining a logging level. + + For example, VLOG(2) is correct. VLOG(INFO), VLOG(WARNING), VLOG(ERROR), and + VLOG(FATAL) are not. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + if Search(r'\bVLOG\((INFO|ERROR|WARNING|DFATAL|FATAL)\)', line): + error(filename, linenum, 'runtime/vlog', 5, + 'VLOG() should be used with numeric verbosity level. ' + 'Use LOG() if you want symbolic severity levels.') + +# Matches invalid increment: *count++, which moves pointer instead of +# incrementing a value. +_RE_PATTERN_INVALID_INCREMENT = re.compile( + r'^\s*\*\w+(\+\+|--);') + + +def CheckInvalidIncrement(filename, clean_lines, linenum, error): + """Checks for invalid increment *count++. + + For example following function: + void increment_counter(int* count) { + *count++; + } + is invalid, because it effectively does count++, moving pointer, and should + be replaced with ++*count, (*count)++ or *count += 1. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + if _RE_PATTERN_INVALID_INCREMENT.match(line): + error(filename, linenum, 'runtime/invalid_increment', 5, + 'Changing pointer instead of value (or unused value of operator*).') + + +def IsMacroDefinition(clean_lines, linenum): + if Search(r'^#define', clean_lines[linenum]): + return True + + if linenum > 0 and Search(r'\\$', clean_lines[linenum - 1]): + return True + + return False + + +def IsForwardClassDeclaration(clean_lines, linenum): + return Match(r'^\s*(\btemplate\b)*.*class\s+\w+;\s*$', clean_lines[linenum]) + + +class _BlockInfo(object): + """Stores information about a generic block of code.""" + + def __init__(self, seen_open_brace): + self.seen_open_brace = seen_open_brace + self.open_parentheses = 0 + self.inline_asm = _NO_ASM + self.check_namespace_indentation = False + + def CheckBegin(self, filename, clean_lines, linenum, error): + """Run checks that applies to text up to the opening brace. + + This is mostly for checking the text after the class identifier + and the "{", usually where the base class is specified. For other + blocks, there isn't much to check, so we always pass. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + pass + + def CheckEnd(self, filename, clean_lines, linenum, error): + """Run checks that applies to text after the closing brace. + + This is mostly used for checking end of namespace comments. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. 
+ linenum: The number of the line to check. + error: The function to call with any errors found. + """ + pass + + def IsBlockInfo(self): + """Returns true if this block is a _BlockInfo. + + This is convenient for verifying that an object is an instance of + a _BlockInfo, but not an instance of any of the derived classes. + + Returns: + True for this class, False for derived classes. + """ + return self.__class__ == _BlockInfo + + +class _ExternCInfo(_BlockInfo): + """Stores information about an 'extern "C"' block.""" + + def __init__(self): + _BlockInfo.__init__(self, True) + + +class _ClassInfo(_BlockInfo): + """Stores information about a class.""" + + def __init__(self, name, class_or_struct, clean_lines, linenum): + _BlockInfo.__init__(self, False) + self.name = name + self.starting_linenum = linenum + self.is_derived = False + self.check_namespace_indentation = True + if class_or_struct == 'struct': + self.access = 'public' + self.is_struct = True + else: + self.access = 'private' + self.is_struct = False + + # Remember initial indentation level for this class. Using raw_lines here + # instead of elided to account for leading comments. + self.class_indent = GetIndentLevel(clean_lines.raw_lines[linenum]) + + # Try to find the end of the class. This will be confused by things like: + # class A { + # } *x = { ... + # + # But it's still good enough for CheckSectionSpacing. + self.last_line = 0 + depth = 0 + for i in range(linenum, clean_lines.NumLines()): + line = clean_lines.elided[i] + depth += line.count('{') - line.count('}') + if not depth: + self.last_line = i + break + + def CheckBegin(self, filename, clean_lines, linenum, error): + # Look for a bare ':' + if Search('(^|[^:]):($|[^:])', clean_lines.elided[linenum]): + self.is_derived = True + + def CheckEnd(self, filename, clean_lines, linenum, error): + # If there is a DISALLOW macro, it should appear near the end of + # the class. + seen_last_thing_in_class = False + for i in xrange(linenum - 1, self.starting_linenum, -1): + match = Search( + r'\b(DISALLOW_COPY_AND_ASSIGN|DISALLOW_IMPLICIT_CONSTRUCTORS)\(' + + self.name + r'\)', + clean_lines.elided[i]) + if match: + if seen_last_thing_in_class: + error(filename, i, 'readability/constructors', 3, + match.group(1) + ' should be the last thing in the class') + break + + if not Match(r'^\s*$', clean_lines.elided[i]): + seen_last_thing_in_class = True + + # Check that closing brace is aligned with beginning of the class. + # Only do this if the closing brace is indented by only whitespaces. + # This means we will not check single-line class definitions. + indent = Match(r'^( *)\}', clean_lines.elided[linenum]) + if indent and len(indent.group(1)) != self.class_indent: + if self.is_struct: + parent = 'struct ' + self.name + else: + parent = 'class ' + self.name + error(filename, linenum, 'whitespace/indent', 3, + 'Closing brace should be aligned with beginning of %s' % parent) + + +class _NamespaceInfo(_BlockInfo): + """Stores information about a namespace.""" + + def __init__(self, name, linenum): + _BlockInfo.__init__(self, False) + self.name = name or '' + self.starting_linenum = linenum + self.check_namespace_indentation = True + + def CheckEnd(self, filename, clean_lines, linenum, error): + """Check end of namespace comments.""" + line = clean_lines.raw_lines[linenum] + + # Check how many lines is enclosed in this namespace. Don't issue + # warning for missing namespace comments if there aren't enough + # lines. 
However, do apply checks if there is already an end of + # namespace comment and it's incorrect. + # + # TODO(unknown): We always want to check end of namespace comments + # if a namespace is large, but sometimes we also want to apply the + # check if a short namespace contained nontrivial things (something + # other than forward declarations). There is currently no logic on + # deciding what these nontrivial things are, so this check is + # triggered by namespace size only, which works most of the time. + if (linenum - self.starting_linenum < 10 + and not Match(r'};*\s*(//|/\*).*\bnamespace\b', line)): + return + + # Look for matching comment at end of namespace. + # + # Note that we accept C style "/* */" comments for terminating + # namespaces, so that code that terminate namespaces inside + # preprocessor macros can be cpplint clean. + # + # We also accept stuff like "// end of namespace ." with the + # period at the end. + # + # Besides these, we don't accept anything else, otherwise we might + # get false negatives when existing comment is a substring of the + # expected namespace. + if self.name: + # Named namespace + if not Match((r'};*\s*(//|/\*).*\bnamespace\s+' + re.escape(self.name) + + r'[\*/\.\\\s]*$'), + line): + error(filename, linenum, 'readability/namespace', 5, + 'Namespace should be terminated with "// namespace %s"' % + self.name) + else: + # Anonymous namespace + if not Match(r'};*\s*(//|/\*).*\bnamespace[\*/\.\\\s]*$', line): + # If "// namespace anonymous" or "// anonymous namespace (more text)", + # mention "// anonymous namespace" as an acceptable form + if Match(r'}.*\b(namespace anonymous|anonymous namespace)\b', line): + error(filename, linenum, 'readability/namespace', 5, + 'Anonymous namespace should be terminated with "// namespace"' + ' or "// anonymous namespace"') + else: + error(filename, linenum, 'readability/namespace', 5, + 'Anonymous namespace should be terminated with "// namespace"') + + +class _PreprocessorInfo(object): + """Stores checkpoints of nesting stacks when #if/#else is seen.""" + + def __init__(self, stack_before_if): + # The entire nesting stack before #if + self.stack_before_if = stack_before_if + + # The entire nesting stack up to #else + self.stack_before_else = [] + + # Whether we have already seen #else or #elif + self.seen_else = False + + +class NestingState(object): + """Holds states related to parsing braces.""" + + def __init__(self): + # Stack for tracking all braces. An object is pushed whenever we + # see a "{", and popped when we see a "}". Only 3 types of + # objects are possible: + # - _ClassInfo: a class or struct. + # - _NamespaceInfo: a namespace. + # - _BlockInfo: some other type of block. + self.stack = [] + + # Top of the previous stack before each Update(). + # + # Because the nesting_stack is updated at the end of each line, we + # had to do some convoluted checks to find out what is the current + # scope at the beginning of the line. This check is simplified by + # saving the previous top of nesting stack. + # + # We could save the full stack, but we only need the top. Copying + # the full nesting stack would slow down cpplint by ~10%. + self.previous_stack_top = [] + + # Stack of _PreprocessorInfo objects. + self.pp_stack = [] + + def SeenOpenBrace(self): + """Check if we have seen the opening brace for the innermost block. + + Returns: + True if we have seen the opening brace, False if the innermost + block is still expecting an opening brace. 
+ """ + return (not self.stack) or self.stack[-1].seen_open_brace + + def InNamespaceBody(self): + """Check if we are currently one level inside a namespace body. + + Returns: + True if top of the stack is a namespace block, False otherwise. + """ + return self.stack and isinstance(self.stack[-1], _NamespaceInfo) + + def InExternC(self): + """Check if we are currently one level inside an 'extern "C"' block. + + Returns: + True if top of the stack is an extern block, False otherwise. + """ + return self.stack and isinstance(self.stack[-1], _ExternCInfo) + + def InClassDeclaration(self): + """Check if we are currently one level inside a class or struct declaration. + + Returns: + True if top of the stack is a class/struct, False otherwise. + """ + return self.stack and isinstance(self.stack[-1], _ClassInfo) + + def InAsmBlock(self): + """Check if we are currently one level inside an inline ASM block. + + Returns: + True if the top of the stack is a block containing inline ASM. + """ + return self.stack and self.stack[-1].inline_asm != _NO_ASM + + def InTemplateArgumentList(self, clean_lines, linenum, pos): + """Check if current position is inside template argument list. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + pos: position just after the suspected template argument. + Returns: + True if (linenum, pos) is inside template arguments. + """ + while linenum < clean_lines.NumLines(): + # Find the earliest character that might indicate a template argument + line = clean_lines.elided[linenum] + match = Match(r'^[^{};=\[\]\.<>]*(.)', line[pos:]) + if not match: + linenum += 1 + pos = 0 + continue + token = match.group(1) + pos += len(match.group(0)) + + # These things do not look like template argument list: + # class Suspect { + # class Suspect x; } + if token in ('{', '}', ';'): return False + + # These things look like template argument list: + # template + # template + # template + # template + if token in ('>', '=', '[', ']', '.'): return True + + # Check if token is an unmatched '<'. + # If not, move on to the next character. + if token != '<': + pos += 1 + if pos >= len(line): + linenum += 1 + pos = 0 + continue + + # We can't be sure if we just find a single '<', and need to + # find the matching '>'. + (_, end_line, end_pos) = CloseExpression(clean_lines, linenum, pos - 1) + if end_pos < 0: + # Not sure if template argument list or syntax error in file + return False + linenum = end_line + pos = end_pos + return False + + def UpdatePreprocessor(self, line): + """Update preprocessor stack. + + We need to handle preprocessors due to classes like this: + #ifdef SWIG + struct ResultDetailsPageElementExtensionPoint { + #else + struct ResultDetailsPageElementExtensionPoint : public Extension { + #endif + + We make the following assumptions (good enough for most files): + - Preprocessor condition evaluates to true from #if up to first + #else/#elif/#endif. + + - Preprocessor condition evaluates to false from #else/#elif up + to #endif. We still perform lint checks on these lines, but + these do not affect nesting stack. + + Args: + line: current line to check. + """ + if Match(r'^\s*#\s*(if|ifdef|ifndef)\b', line): + # Beginning of #if block, save the nesting stack here. The saved + # stack will allow us to restore the parsing state in the #else case. 
+ self.pp_stack.append(_PreprocessorInfo(copy.deepcopy(self.stack))) + elif Match(r'^\s*#\s*(else|elif)\b', line): + # Beginning of #else block + if self.pp_stack: + if not self.pp_stack[-1].seen_else: + # This is the first #else or #elif block. Remember the + # whole nesting stack up to this point. This is what we + # keep after the #endif. + self.pp_stack[-1].seen_else = True + self.pp_stack[-1].stack_before_else = copy.deepcopy(self.stack) + + # Restore the stack to how it was before the #if + self.stack = copy.deepcopy(self.pp_stack[-1].stack_before_if) + else: + # TODO(unknown): unexpected #else, issue warning? + pass + elif Match(r'^\s*#\s*endif\b', line): + # End of #if or #else blocks. + if self.pp_stack: + # If we saw an #else, we will need to restore the nesting + # stack to its former state before the #else, otherwise we + # will just continue from where we left off. + if self.pp_stack[-1].seen_else: + # Here we can just use a shallow copy since we are the last + # reference to it. + self.stack = self.pp_stack[-1].stack_before_else + # Drop the corresponding #if + self.pp_stack.pop() + else: + # TODO(unknown): unexpected #endif, issue warning? + pass + + # TODO(unknown): Update() is too long, but we will refactor later. + def Update(self, filename, clean_lines, linenum, error): + """Update nesting state with current line. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Remember top of the previous nesting stack. + # + # The stack is always pushed/popped and not modified in place, so + # we can just do a shallow copy instead of copy.deepcopy. Using + # deepcopy would slow down cpplint by ~28%. + if self.stack: + self.previous_stack_top = self.stack[-1] + else: + self.previous_stack_top = None + + # Update pp_stack + self.UpdatePreprocessor(line) + + # Count parentheses. This is to avoid adding struct arguments to + # the nesting stack. + if self.stack: + inner_block = self.stack[-1] + depth_change = line.count('(') - line.count(')') + inner_block.open_parentheses += depth_change + + # Also check if we are starting or ending an inline assembly block. + if inner_block.inline_asm in (_NO_ASM, _END_ASM): + if (depth_change != 0 and + inner_block.open_parentheses == 1 and + _MATCH_ASM.match(line)): + # Enter assembly block + inner_block.inline_asm = _INSIDE_ASM + else: + # Not entering assembly block. If previous line was _END_ASM, + # we will now shift to _NO_ASM state. + inner_block.inline_asm = _NO_ASM + elif (inner_block.inline_asm == _INSIDE_ASM and + inner_block.open_parentheses == 0): + # Exit assembly block + inner_block.inline_asm = _END_ASM + + # Consume namespace declaration at the beginning of the line. Do + # this in a loop so that we catch same line declarations like this: + # namespace proto2 { namespace bridge { class MessageSet; } } + while True: + # Match start of namespace. The "\b\s*" below catches namespace + # declarations even if it weren't followed by a whitespace, this + # is so that we don't confuse our namespace checker. The + # missing spaces will be flagged by CheckSpacing. 
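The #if/#else bookkeeping above can be exercised by hand; a small sketch assuming the file imports as the module cpplint (it pokes at _BlockInfo and the stack directly, which ordinary callers would not do):

import cpplint

ns = cpplint.NestingState()
ns.stack.append(cpplint._BlockInfo(True))  # pretend one block is already open
ns.UpdatePreprocessor('#ifdef SWIG')       # snapshot saved on pp_stack
ns.stack.append(cpplint._BlockInfo(True))  # a block opened in the #if branch
ns.UpdatePreprocessor('#else')             # rewind to the snapshot for the #else branch
assert len(ns.stack) == 1
ns.UpdatePreprocessor('#endif')            # keep what the #if branch had built
assert len(ns.stack) == 2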
+ namespace_decl_match = Match(r'^\s*namespace\b\s*([:\w]+)?(.*)$', line) + if not namespace_decl_match: + break + + new_namespace = _NamespaceInfo(namespace_decl_match.group(1), linenum) + self.stack.append(new_namespace) + + line = namespace_decl_match.group(2) + if line.find('{') != -1: + new_namespace.seen_open_brace = True + line = line[line.find('{') + 1:] + + # Look for a class declaration in whatever is left of the line + # after parsing namespaces. The regexp accounts for decorated classes + # such as in: + # class LOCKABLE API Object { + # }; + class_decl_match = Match( + r'^(\s*(?:template\s*<[\w\s<>,:]*>\s*)?' + r'(class|struct)\s+(?:[A-Z_]+\s+)*(\w+(?:::\w+)*))' + r'(.*)$', line) + if (class_decl_match and + (not self.stack or self.stack[-1].open_parentheses == 0)): + # We do not want to accept classes that are actually template arguments: + # template , + # template class Ignore3> + # void Function() {}; + # + # To avoid template argument cases, we scan forward and look for + # an unmatched '>'. If we see one, assume we are inside a + # template argument list. + end_declaration = len(class_decl_match.group(1)) + if not self.InTemplateArgumentList(clean_lines, linenum, end_declaration): + self.stack.append(_ClassInfo( + class_decl_match.group(3), class_decl_match.group(2), + clean_lines, linenum)) + line = class_decl_match.group(4) + + # If we have not yet seen the opening brace for the innermost block, + # run checks here. + if not self.SeenOpenBrace(): + self.stack[-1].CheckBegin(filename, clean_lines, linenum, error) + + # Update access control if we are inside a class/struct + if self.stack and isinstance(self.stack[-1], _ClassInfo): + classinfo = self.stack[-1] + access_match = Match( + r'^(.*)\b(public|private|protected|signals)(\s+(?:slots\s*)?)?' + r':(?:[^:]|$)', + line) + if access_match: + classinfo.access = access_match.group(2) + + # Check that access keywords are indented +1 space. Skip this + # check if the keywords are not preceded by whitespaces. + indent = access_match.group(1) + if (len(indent) != classinfo.class_indent + 1 and + Match(r'^\s*$', indent)): + if classinfo.is_struct: + parent = 'struct ' + classinfo.name + else: + parent = 'class ' + classinfo.name + slots = '' + if access_match.group(3): + slots = access_match.group(3) + error(filename, linenum, 'whitespace/indent', 3, + '%s%s: should be indented +1 space inside %s' % ( + access_match.group(2), slots, parent)) + + # Consume braces or semicolons from what's left of the line + while True: + # Match first brace, semicolon, or closed parenthesis. + matched = Match(r'^[^{;)}]*([{;)}])(.*)$', line) + if not matched: + break + + token = matched.group(1) + if token == '{': + # If namespace or class hasn't seen a opening brace yet, mark + # namespace/class head as complete. Push a new block onto the + # stack otherwise. + if not self.SeenOpenBrace(): + self.stack[-1].seen_open_brace = True + elif Match(r'^extern\s*"[^"]*"\s*\{', line): + self.stack.append(_ExternCInfo()) + else: + self.stack.append(_BlockInfo(True)) + if _MATCH_ASM.match(line): + self.stack[-1].inline_asm = _BLOCK_ASM + + elif token == ';' or token == ')': + # If we haven't seen an opening brace yet, but we already saw + # a semicolon, this is probably a forward declaration. Pop + # the stack for these. + # + # Similarly, if we haven't seen an opening brace yet, but we + # already saw a closing parenthesis, then these are probably + # function arguments with extra "class" or "struct" keywords. + # Also pop these stack for these. 
+ if not self.SeenOpenBrace(): + self.stack.pop() + else: # token == '}' + # Perform end of block checks and pop the stack. + if self.stack: + self.stack[-1].CheckEnd(filename, clean_lines, linenum, error) + self.stack.pop() + line = matched.group(2) + + def InnermostClass(self): + """Get class info on the top of the stack. + + Returns: + A _ClassInfo object if we are inside a class, or None otherwise. + """ + for i in range(len(self.stack), 0, -1): + classinfo = self.stack[i - 1] + if isinstance(classinfo, _ClassInfo): + return classinfo + return None + + def CheckCompletedBlocks(self, filename, error): + """Checks that all classes and namespaces have been completely parsed. + + Call this when all lines in a file have been processed. + Args: + filename: The name of the current file. + error: The function to call with any errors found. + """ + # Note: This test can result in false positives if #ifdef constructs + # get in the way of brace matching. See the testBuildClass test in + # cpplint_unittest.py for an example of this. + for obj in self.stack: + if isinstance(obj, _ClassInfo): + error(filename, obj.starting_linenum, 'build/class', 5, + 'Failed to find complete declaration of class %s' % + obj.name) + elif isinstance(obj, _NamespaceInfo): + error(filename, obj.starting_linenum, 'build/namespaces', 5, + 'Failed to find complete declaration of namespace %s' % + obj.name) + + +def CheckForNonStandardConstructs(filename, clean_lines, linenum, + nesting_state, error): + r"""Logs an error if we see certain non-ANSI constructs ignored by gcc-2. + + Complain about several constructs which gcc-2 accepts, but which are + not standard C++. Warning about these in lint is one way to ease the + transition to new compilers. + - put storage class first (e.g. "static const" instead of "const static"). + - "%lld" instead of %qd" in printf-type functions. + - "%1$d" is non-standard in printf-type functions. + - "\%" is an undefined character escape sequence. + - text after #endif is not allowed. + - invalid inner-style forward declaration. + - >? and ?= and )\?=?\s*(\w+|[+-]?\d+)(\.\d*)?', + line): + error(filename, linenum, 'build/deprecated', 3, + '>? and ))?' + # r'\s*const\s*' + type_name + '\s*&\s*\w+\s*;' + error(filename, linenum, 'runtime/member_string_references', 2, + 'const string& members are dangerous. It is much better to use ' + 'alternatives, such as pointers or simple constants.') + + # Everything else in this function operates on class declarations. + # Return early if the top of the nesting stack is not a class, or if + # the class head is not completed yet. + classinfo = nesting_state.InnermostClass() + if not classinfo or not classinfo.seen_open_brace: + return + + # The class may have been declared with namespace or classname qualifiers. + # The constructor and destructor will not have those qualifiers. + base_classname = classinfo.name.split('::')[-1] + + # Look for single-argument constructors that aren't marked explicit. + # Technically a valid construct, but against style. Also look for + # non-single-argument constructors which are also technically valid, but + # strongly suggest something is wrong. 
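Putting Update() together with the queries defined above, a usage sketch; it assumes the file imports as the module cpplint, and the five header lines are made up:

import cpplint

error = lambda *args: None  # ignore any findings for this sketch
cl = cpplint.CleansedLines(['namespace caffe2 {',
                            'class Blob {',
                            ' public:',
                            '};',
                            '}  // namespace caffe2'])
ns = cpplint.NestingState()
ns.Update('blob.h', cl, 0, error)
print(ns.InNamespaceBody())        # True
ns.Update('blob.h', cl, 1, error)
print(ns.InnermostClass().name)    # Blob
ns.Update('blob.h', cl, 2, error)
print(ns.InnermostClass().access)  # public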
+ explicit_constructor_match = Match( + r'\s+(?:inline\s+)?(explicit\s+)?(?:inline\s+)?%s\s*' + r'\(((?:[^()]|\([^()]*\))*)\)' + % re.escape(base_classname), + line) + + if explicit_constructor_match: + is_marked_explicit = explicit_constructor_match.group(1) + + if not explicit_constructor_match.group(2): + constructor_args = [] + else: + constructor_args = explicit_constructor_match.group(2).split(',') + + # collapse arguments so that commas in template parameter lists and function + # argument parameter lists don't split arguments in two + i = 0 + while i < len(constructor_args): + constructor_arg = constructor_args[i] + while (constructor_arg.count('<') > constructor_arg.count('>') or + constructor_arg.count('(') > constructor_arg.count(')')): + constructor_arg += ',' + constructor_args[i + 1] + del constructor_args[i + 1] + constructor_args[i] = constructor_arg + i += 1 + + defaulted_args = [arg for arg in constructor_args if '=' in arg] + noarg_constructor = (not constructor_args or # empty arg list + # 'void' arg specifier + (len(constructor_args) == 1 and + constructor_args[0].strip() == 'void')) + onearg_constructor = ((len(constructor_args) == 1 and # exactly one arg + not noarg_constructor) or + # all but at most one arg defaulted + (len(constructor_args) >= 1 and + not noarg_constructor and + len(defaulted_args) >= len(constructor_args) - 1)) + initializer_list_constructor = bool( + onearg_constructor and + Search(r'\bstd\s*::\s*initializer_list\b', constructor_args[0])) + copy_constructor = bool( + onearg_constructor and + Match(r'(const\s+)?%s(\s*<[^>]*>)?(\s+const)?\s*(?:<\w+>\s*)?&' + % re.escape(base_classname), constructor_args[0].strip())) + + if (not is_marked_explicit and + onearg_constructor and + not initializer_list_constructor and + not copy_constructor): + if defaulted_args: + error(filename, linenum, 'runtime/explicit', 5, + 'Constructors callable with one argument ' + 'should be marked explicit.') + else: + error(filename, linenum, 'runtime/explicit', 5, + 'Single-parameter constructors should be marked explicit.') + elif is_marked_explicit and not onearg_constructor: + if noarg_constructor: + error(filename, linenum, 'runtime/explicit', 5, + 'Zero-parameter constructors should not be marked explicit.') + else: + error(filename, linenum, 'runtime/explicit', 0, + 'Constructors that require multiple arguments ' + 'should not be marked explicit.') + + +def CheckSpacingForFunctionCall(filename, clean_lines, linenum, error): + """Checks for the correctness of various spacing around function calls. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Since function calls often occur inside if/for/while/switch + # expressions - which have their own, more liberal conventions - we + # first see if we should be looking inside such an expression for a + # function call, to which we can apply more strict standards. + fncall = line # if there's no control flow construct, look at whole line + for pattern in (r'\bif\s*\((.*)\)\s*{', + r'\bfor\s*\((.*)\)\s*{', + r'\bwhile\s*\((.*)\)\s*[{;]', + r'\bswitch\s*\((.*)\)\s*{'): + match = Search(pattern, line) + if match: + fncall = match.group(1) # look inside the parens for function calls + break + + # Except in if/for/while/switch, there should never be space + # immediately inside parens (eg "f( 3, 4 )"). 
We make an exception + # for nested parens ( (a+b) + c ). Likewise, there should never be + # a space before a ( when it's a function argument. I assume it's a + # function argument when the char before the whitespace is legal in + # a function name (alnum + _) and we're not starting a macro. Also ignore + # pointers and references to arrays and functions coz they're too tricky: + # we use a very simple way to recognize these: + # " (something)(maybe-something)" or + # " (something)(maybe-something," or + # " (something)[something]" + # Note that we assume the contents of [] to be short enough that + # they'll never need to wrap. + if ( # Ignore control structures. + not Search(r'\b(if|for|while|switch|return|new|delete|catch|sizeof)\b', + fncall) and + # Ignore pointers/references to functions. + not Search(r' \([^)]+\)\([^)]*(\)|,$)', fncall) and + # Ignore pointers/references to arrays. + not Search(r' \([^)]+\)\[[^\]]+\]', fncall)): + if Search(r'\w\s*\(\s(?!\s*\\$)', fncall): # a ( used for a fn call + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space after ( in function call') + elif Search(r'\(\s+(?!(\s*\\)|\()', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space after (') + if (Search(r'\w\s+\(', fncall) and + not Search(r'#\s*define|typedef|using\s+\w+\s*=', fncall) and + not Search(r'\w\s+\((\w+::)*\*\w+\)\(', fncall) and + not Search(r'\bcase\s+\(', fncall)): + # TODO(unknown): Space after an operator function seem to be a common + # error, silence those for now by restricting them to highest verbosity. + if Search(r'\boperator_*\b', line): + error(filename, linenum, 'whitespace/parens', 0, + 'Extra space before ( in function call') + else: + error(filename, linenum, 'whitespace/parens', 4, + 'Extra space before ( in function call') + # If the ) is followed only by a newline or a { + newline, assume it's + # part of a control statement (if/while/etc), and don't complain + if Search(r'[^)]\s+\)\s*[^{\s]', fncall): + # If the closing parenthesis is preceded by only whitespaces, + # try to give a more descriptive error message. + if Search(r'^\s+\)', fncall): + error(filename, linenum, 'whitespace/parens', 2, + 'Closing ) should be moved to the previous line') + else: + error(filename, linenum, 'whitespace/parens', 2, + 'Extra space before )') + + +def IsBlankLine(line): + """Returns true if the given line is blank. + + We consider a line to be blank if the line is empty or consists of + only white spaces. + + Args: + line: A line of a string. + + Returns: + True, if the given line is blank. + """ + return not line or line.isspace() + + +def CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line, + error): + is_namespace_indent_item = ( + len(nesting_state.stack) > 1 and + nesting_state.stack[-1].check_namespace_indentation and + isinstance(nesting_state.previous_stack_top, _NamespaceInfo) and + nesting_state.previous_stack_top == nesting_state.stack[-2]) + + if ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item, + clean_lines.elided, line): + CheckItemIndentationInNamespace(filename, clean_lines.elided, + line, error) + + +def CheckForFunctionLengths(filename, clean_lines, linenum, + function_state, error): + """Reports for long function bodies. + + For an overview why this is done, see: + http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Write_Short_Functions + + Uses a simplistic algorithm assuming other style guidelines + (especially spacing) are followed. 
+ Only checks unindented functions, so class members are unchecked. + Trivial bodies are unchecked, so constructors with huge initializer lists + may be missed. + Blank/comment lines are not counted so as to avoid encouraging the removal + of vertical space and comments just to get through a lint check. + NOLINT *on the last line of a function* disables this check. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + function_state: Current function name and lines in body so far. + error: The function to call with any errors found. + """ + lines = clean_lines.lines + line = lines[linenum] + joined_line = '' + + starting_func = False + regexp = r'(\w(\w|::|\*|\&|\s)*)\(' # decls * & space::name( ... + match_result = Match(regexp, line) + if match_result: + # If the name is all caps and underscores, figure it's a macro and + # ignore it, unless it's TEST or TEST_F. + function_name = match_result.group(1).split()[-1] + if function_name == 'TEST' or function_name == 'TEST_F' or ( + not Match(r'[A-Z_]+$', function_name)): + starting_func = True + + if starting_func: + body_found = False + for start_linenum in xrange(linenum, clean_lines.NumLines()): + start_line = lines[start_linenum] + joined_line += ' ' + start_line.lstrip() + if Search(r'(;|})', start_line): # Declarations and trivial functions + body_found = True + break # ... ignore + elif Search(r'{', start_line): + body_found = True + function = Search(r'((\w|:)*)\(', line).group(1) + if Match(r'TEST', function): # Handle TEST... macros + parameter_regexp = Search(r'(\(.*\))', joined_line) + if parameter_regexp: # Ignore bad syntax + function += parameter_regexp.group(1) + else: + function += '()' + function_state.Begin(function) + break + if not body_found: + # No body for the function (or evidence of a non-function) was found. + error(filename, linenum, 'readability/fn_size', 5, + 'Lint failed to find start of function body.') + elif Match(r'^\}\s*$', line): # function end + function_state.Check(error, filename, linenum) + function_state.End() + elif not Match(r'^\s*$', line): + function_state.Count() # Count non-blank/non-comment lines. + + +_RE_PATTERN_TODO = re.compile(r'^//(\s*)TODO(\(.+?\))?:?(\s|$)?') + + +def CheckComment(line, filename, linenum, next_line_start, error): + """Checks for common mistakes in comments. + + Args: + line: The line in question. + filename: The name of the current file. + linenum: The number of the line to check. + next_line_start: The first non-whitespace column of the next line. + error: The function to call with any errors found. + """ + commentpos = line.find('//') + if commentpos != -1: + # Check if the // may be in quotes. If so, ignore it + # Comparisons made explicit for clarity -- pylint: disable=g-explicit-bool-comparison + if (line.count('"', 0, commentpos) - + line.count('\\"', 0, commentpos)) % 2 == 0: # not in quotes + # Allow one space for new scopes, two spaces otherwise: + if (not (Match(r'^.*{ *//', line) and next_line_start == commentpos) and + ((commentpos >= 1 and + line[commentpos-1] not in string.whitespace) or + (commentpos >= 2 and + line[commentpos-2] not in string.whitespace))): + error(filename, linenum, 'whitespace/comments', 2, + 'At least two spaces is best between code and comments') + + # Checks for common mistakes in TODO comments. 
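The two-spaces-before-a-comment rule above is easy to see in isolation; a sketch assuming the file imports as the module cpplint. The C++ line and file name are made up, and next_line_start is passed as -1 because the next line's indentation does not matter here:

import cpplint

issues = []
def error(filename, linenum, category, confidence, message):
    issues.append(message)

cpplint.CheckComment('int x = 0; // one space before this comment',
                     'dummy.cc', 0, -1, error)
print(issues)  # ['At least two spaces is best between code and comments']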
+ comment = line[commentpos:] + match = _RE_PATTERN_TODO.match(comment) + if match: + # One whitespace is correct; zero whitespace is handled elsewhere. + leading_whitespace = match.group(1) + if len(leading_whitespace) > 1: + error(filename, linenum, 'whitespace/todo', 2, + 'Too many spaces before TODO') + + username = match.group(2) + if not username: + error(filename, linenum, 'readability/todo', 2, + 'Missing username in TODO; it should look like ' + '"// TODO(my_username): Stuff."') + + middle_whitespace = match.group(3) + # Comparisons made explicit for correctness -- pylint: disable=g-explicit-bool-comparison + if middle_whitespace != ' ' and middle_whitespace != '': + error(filename, linenum, 'whitespace/todo', 2, + 'TODO(my_username) should be followed by a space') + + # If the comment contains an alphanumeric character, there + # should be a space somewhere between it and the // unless + # it's a /// or //! Doxygen comment. + if (Match(r'//[^ ]*\w', comment) and + not Match(r'(///|//\!)(\s+|$)', comment)): + error(filename, linenum, 'whitespace/comments', 4, + 'Should have a space between // and comment') + + +def CheckAccess(filename, clean_lines, linenum, nesting_state, error): + """Checks for improper use of DISALLOW* macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] # get rid of comments and strings + + matched = Match((r'\s*(DISALLOW_COPY_AND_ASSIGN|' + r'DISALLOW_IMPLICIT_CONSTRUCTORS)'), line) + if not matched: + return + if nesting_state.stack and isinstance(nesting_state.stack[-1], _ClassInfo): + if nesting_state.stack[-1].access != 'private': + error(filename, linenum, 'readability/constructors', 3, + '%s must be in the private: section' % matched.group(1)) + + else: + # Found DISALLOW* macro outside a class declaration, or perhaps it + # was used inside a function when it should have been part of the + # class declaration. We could issue a warning here, but it + # probably resulted in a compiler error already. + pass + + +def CheckSpacing(filename, clean_lines, linenum, nesting_state, error): + """Checks for the correctness of various spacing issues in the code. + + Things we check for: spaces around operators, spaces after + if/for/while/switch, no spaces around parens in function calls, two + spaces between code and comment, don't start a block with a blank + line, don't end a function with a blank line, don't add a blank line + after public/protected/private, don't have too many blank lines in a row. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw = clean_lines.lines_without_raw_strings + line = raw[linenum] + + # Before nixing comments, check if the line is blank for no good + # reason. 
This includes the first line after a block is opened, and + # blank lines at the end of a function (ie, right before a line like '}' + # + # Skip all the blank line checks if we are immediately inside a + # namespace body. In other words, don't issue blank line warnings + # for this block: + # namespace { + # + # } + # + # A warning about missing end of namespace comments will be issued instead. + # + # Also skip blank line checks for 'extern "C"' blocks, which are formatted + # like namespaces. + if (IsBlankLine(line) and + not nesting_state.InNamespaceBody() and + not nesting_state.InExternC()): + elided = clean_lines.elided + prev_line = elided[linenum - 1] + prevbrace = prev_line.rfind('{') + # TODO(unknown): Don't complain if line before blank line, and line after, + # both start with alnums and are indented the same amount. + # This ignores whitespace at the start of a namespace block + # because those are not usually indented. + if prevbrace != -1 and prev_line[prevbrace:].find('}') == -1: + # OK, we have a blank line at the start of a code block. Before we + # complain, we check if it is an exception to the rule: The previous + # non-empty line has the parameters of a function header that are indented + # 4 spaces (because they did not fit in a 80 column line when placed on + # the same line as the function name). We also check for the case where + # the previous line is indented 6 spaces, which may happen when the + # initializers of a constructor do not fit into a 80 column line. + exception = False + if Match(r' {6}\w', prev_line): # Initializer list? + # We are looking for the opening column of initializer list, which + # should be indented 4 spaces to cause 6 space indentation afterwards. + search_position = linenum-2 + while (search_position >= 0 + and Match(r' {6}\w', elided[search_position])): + search_position -= 1 + exception = (search_position >= 0 + and elided[search_position][:5] == ' :') + else: + # Search for the function arguments or an initializer list. We use a + # simple heuristic here: If the line is indented 4 spaces; and we have a + # closing paren, without the opening paren, followed by an opening brace + # or colon (for initializer lists) we assume that it is the last line of + # a function header. If we have a colon indented 4 spaces, it is an + # initializer list. 
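+        # For instance (illustrative), a previous line such as
+        #     int arg_one, int arg_two) {
+        # or
+        #     : member_(value) {
+        # (each indented exactly four spaces) counts as such an exception.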
+ exception = (Match(r' {4}\w[^\(]*\)\s*(const\s*)?(\{\s*$|:)', + prev_line) + or Match(r' {4}:', prev_line)) + + if not exception: + error(filename, linenum, 'whitespace/blank_line', 2, + 'Redundant blank line at the start of a code block ' + 'should be deleted.') + # Ignore blank lines at the end of a block in a long if-else + # chain, like this: + # if (condition1) { + # // Something followed by a blank line + # + # } else if (condition2) { + # // Something else + # } + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + if (next_line + and Match(r'\s*}', next_line) + and next_line.find('} else ') == -1): + error(filename, linenum, 'whitespace/blank_line', 3, + 'Redundant blank line at the end of a code block ' + 'should be deleted.') + + matched = Match(r'\s*(public|protected|private):', prev_line) + if matched: + error(filename, linenum, 'whitespace/blank_line', 3, + 'Do not leave a blank line after "%s:"' % matched.group(1)) + + # Next, check comments + next_line_start = 0 + if linenum + 1 < clean_lines.NumLines(): + next_line = raw[linenum + 1] + next_line_start = len(next_line) - len(next_line.lstrip()) + CheckComment(line, filename, linenum, next_line_start, error) + + # get rid of comments and strings + line = clean_lines.elided[linenum] + + # You shouldn't have spaces before your brackets, except maybe after + # 'delete []' or 'return []() {};' + if Search(r'\w\s+\[', line) and not Search(r'(?:delete|return)\s+\[', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Extra space before [') + + # In range-based for, we wanted spaces before and after the colon, but + # not around "::" tokens that might appear. + if (Search(r'for *\(.*[^:]:[^: ]', line) or + Search(r'for *\(.*[^: ]:[^:]', line)): + error(filename, linenum, 'whitespace/forcolon', 2, + 'Missing space around colon in range-based for loop') + + +def CheckOperatorSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing around operators. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Don't try to do spacing checks for operator methods. Do this by + # replacing the troublesome characters with something else, + # preserving column position for all other characters. + # + # The replacement is done repeatedly to avoid false positives from + # operators that call operators. + while True: + match = Match(r'^(.*\boperator\b)(\S+)(\s*\(.*)$', line) + if match: + line = match.group(1) + ('_' * len(match.group(2))) + match.group(3) + else: + break + + # We allow no-spaces around = within an if: "if ( (a=Foo()) == 0 )". + # Otherwise not. Note we only check for non-spaces on *both* sides; + # sometimes people put non-spaces on one side when aligning ='s among + # many lines (not that this is behavior that I approve of...) + if ((Search(r'[\w.]=', line) or + Search(r'=[\w.]', line)) + and not Search(r'\b(if|while|for) ', line) + # Operators taken from [lex.operators] in C++11 standard. + and not Search(r'(>=|<=|==|!=|&=|\^=|\|=|\+=|\*=|\/=|\%=)', line) + and not Search(r'operator=', line)): + error(filename, linenum, 'whitespace/operators', 4, + 'Missing spaces around =') + + # It's ok not to have spaces around binary operators like + - * /, but if + # there's too little whitespace, we get concerned. It's hard to tell, + # though, so we punt on this one for now. TODO. 
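+  # (For instance, neither "a+b" nor "a + b" is flagged by this function;
+  # only the specific operator patterns checked below are.)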
+ + # You should always have whitespace around binary operators. + # + # Check <= and >= first to avoid false positives with < and >, then + # check non-include lines for spacing around < and >. + # + # If the operator is followed by a comma, assume it's be used in a + # macro context and don't do any checks. This avoids false + # positives. + # + # Note that && is not included here. Those are checked separately + # in CheckRValueReference + match = Search(r'[^<>=!\s](==|!=|<=|>=|\|\|)[^<>=!\s,;\)]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around %s' % match.group(1)) + elif not Match(r'#.*include', line): + # Look for < that is not surrounded by spaces. This is only + # triggered if both sides are missing spaces, even though + # technically should should flag if at least one side is missing a + # space. This is done to avoid some false positives with shifts. + match = Match(r'^(.*[^\s<])<[^\s=<,]', line) + if match: + (_, _, end_pos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + if end_pos <= -1: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <') + + # Look for > that is not surrounded by spaces. Similar to the + # above, we only trigger if both sides are missing spaces to avoid + # false positives with shifts. + match = Match(r'^(.*[^-\s>])>[^\s=>,]', line) + if match: + (_, _, start_pos) = ReverseCloseExpression( + clean_lines, linenum, len(match.group(1))) + if start_pos <= -1: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >') + + # We allow no-spaces around << when used like this: 10<<20, but + # not otherwise (particularly, not when used as streams) + # + # We also allow operators following an opening parenthesis, since + # those tend to be macros that deal with operators. + match = Search(r'(operator|[^\s(<])(?:L|UL|ULL|l|ul|ull)?<<([^\s,=<])', line) + if (match and not (match.group(1).isdigit() and match.group(2).isdigit()) and + not (match.group(1) == 'operator' and match.group(2) == ';')): + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around <<') + + # We allow no-spaces around >> for almost anything. This is because + # C++11 allows ">>" to close nested templates, which accounts for + # most cases when ">>" is not followed by a space. + # + # We still warn on ">>" followed by alpha character, because that is + # likely due to ">>" being used for right shifts, e.g.: + # value >> alpha + # + # When ">>" is used to close templates, the alphanumeric letter that + # follows would be part of an identifier, and there should still be + # a space separating the template type and the identifier. + # type> alpha + match = Search(r'>>[a-zA-Z_]', line) + if match: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around >>') + + # There shouldn't be space around unary operators + match = Search(r'(!\s|~\s|[\s]--[\s;]|[\s]\+\+[\s;])', line) + if match: + error(filename, linenum, 'whitespace/operators', 4, + 'Extra space for operator %s' % match.group(1)) + + +def CheckParenthesisSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing around parentheses. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + line = clean_lines.elided[linenum] + + # No spaces after an if, while, switch, or for + match = Search(r' (if\(|for\(|while\(|switch\()', line) + if match: + error(filename, linenum, 'whitespace/parens', 5, + 'Missing space before ( in %s' % match.group(1)) + + # For if/for/while/switch, the left and right parens should be + # consistent about how many spaces are inside the parens, and + # there should either be zero or one spaces inside the parens. + # We don't want: "if ( foo)" or "if ( foo )". + # Exception: "for ( ; foo; bar)" and "for (foo; bar; )" are allowed. + match = Search(r'\b(if|for|while|switch)\s*' + r'\(([ ]*)(.).*[^ ]+([ ]*)\)\s*{\s*$', + line) + if match: + if len(match.group(2)) != len(match.group(4)): + if not (match.group(3) == ';' and + len(match.group(2)) == 1 + len(match.group(4)) or + not match.group(2) and Search(r'\bfor\s*\(.*; \)', line)): + error(filename, linenum, 'whitespace/parens', 5, + 'Mismatching spaces inside () in %s' % match.group(1)) + if len(match.group(2)) not in [0, 1]: + error(filename, linenum, 'whitespace/parens', 5, + 'Should have zero or one spaces inside ( and ) in %s' % + match.group(1)) + + +def CheckCommaSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing near commas and semicolons. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + raw = clean_lines.lines_without_raw_strings + line = clean_lines.elided[linenum] + + # You should always have a space after a comma (either as fn arg or operator) + # + # This does not apply when the non-space character following the + # comma is another comma, since the only time when that happens is + # for empty macro arguments. + # + # We run this check in two passes: first pass on elided lines to + # verify that lines contain missing whitespaces, second pass on raw + # lines to confirm that those missing whitespaces are not due to + # elided comments. + if (Search(r',[^,\s]', ReplaceAll(r'\boperator\s*,\s*\(', 'F(', line)) and + Search(r',[^,\s]', raw[linenum])): + error(filename, linenum, 'whitespace/comma', 3, + 'Missing space after ,') + + # You should always have a space after a semicolon + # except for few corner cases + # TODO(unknown): clarify if 'if (1) { return 1;}' is requires one more + # space after ; + if Search(r';[^\s};\\)/]', line): + error(filename, linenum, 'whitespace/semicolon', 3, + 'Missing space after ;') + + +def CheckBracesSpacing(filename, clean_lines, linenum, error): + """Checks for horizontal spacing near commas. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Except after an opening paren, or after another opening brace (in case of + # an initializer list, for instance), you should have spaces before your + # braces. And since you should never have braces at the beginning of a line, + # this is an easy test. + match = Match(r'^(.*[^ ({>]){', line) + if match: + # Try a bit harder to check for brace initialization. This + # happens in one of the following forms: + # Constructor() : initializer_list_{} { ... 
} + # Constructor{}.MemberFunction() + # Type variable{}; + # FunctionCall(type{}, ...); + # LastArgument(..., type{}); + # LOG(INFO) << type{} << " ..."; + # map_of_type[{...}] = ...; + # ternary = expr ? new type{} : nullptr; + # OuterTemplate{}> + # + # We check for the character following the closing brace, and + # silence the warning if it's one of those listed above, i.e. + # "{.;,)<>]:". + # + # To account for nested initializer list, we allow any number of + # closing braces up to "{;,)<". We can't simply silence the + # warning on first sight of closing brace, because that would + # cause false negatives for things that are not initializer lists. + # Silence this: But not this: + # Outer{ if (...) { + # Inner{...} if (...){ // Missing space before { + # }; } + # + # There is a false negative with this approach if people inserted + # spurious semicolons, e.g. "if (cond){};", but we will catch the + # spurious semicolon with a separate check. + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + trailing_text = '' + if endpos > -1: + trailing_text = endline[endpos:] + for offset in xrange(endlinenum + 1, + min(endlinenum + 3, clean_lines.NumLines() - 1)): + trailing_text += clean_lines.elided[offset] + if not Match(r'^[\s}]*[{.;,)<>\]:]', trailing_text): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before {') + + # Make sure '} else {' has spaces. + if Search(r'}else', line): + error(filename, linenum, 'whitespace/braces', 5, + 'Missing space before else') + + # You shouldn't have a space before a semicolon at the end of the line. + # There's a special case for "for" since the style guide allows space before + # the semicolon there. + if Search(r':\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Semicolon defining empty statement. Use {} instead.') + elif Search(r'^\s*;\s*$', line): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Line contains only semicolon. If this should be an empty statement, ' + 'use {} instead.') + elif (Search(r'\s+;\s*$', line) and + not Search(r'\bfor\b', line)): + error(filename, linenum, 'whitespace/semicolon', 5, + 'Extra space before last semicolon. If this should be an empty ' + 'statement, use {} instead.') + + +def IsDecltype(clean_lines, linenum, column): + """Check if the token ending on (linenum, column) is decltype(). + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: the number of the line to check. + column: end column of the token to check. + Returns: + True if this token is decltype() expression, False otherwise. + """ + (text, _, start_col) = ReverseCloseExpression(clean_lines, linenum, column) + if start_col < 0: + return False + if Search(r'\bdecltype\s*$', text[0:start_col]): + return True + return False + + +def IsTemplateParameterList(clean_lines, linenum, column): + """Check if the token ending on (linenum, column) is the end of template<>. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: the number of the line to check. + column: end column of the token to check. + Returns: + True if this token is end of a template parameter list, False otherwise. 
+ """ + (_, startline, startpos) = ReverseCloseExpression( + clean_lines, linenum, column) + if (startpos > -1 and + Search(r'\btemplate\s*$', clean_lines.elided[startline][0:startpos])): + return True + return False + + +def IsRValueType(typenames, clean_lines, nesting_state, linenum, column): + """Check if the token ending on (linenum, column) is a type. + + Assumes that text to the right of the column is "&&" or a function + name. + + Args: + typenames: set of type names from template-argument-list. + clean_lines: A CleansedLines instance containing the file. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + linenum: the number of the line to check. + column: end column of the token to check. + Returns: + True if this token is a type, False if we are not sure. + """ + prefix = clean_lines.elided[linenum][0:column] + + # Get one word to the left. If we failed to do so, this is most + # likely not a type, since it's unlikely that the type name and "&&" + # would be split across multiple lines. + match = Match(r'^(.*)(\b\w+|[>*)&])\s*$', prefix) + if not match: + return False + + # Check text following the token. If it's "&&>" or "&&," or "&&...", it's + # most likely a rvalue reference used inside a template. + suffix = clean_lines.elided[linenum][column:] + if Match(r'&&\s*(?:[>,]|\.\.\.)', suffix): + return True + + # Check for known types and end of templates: + # int&& variable + # vector&& variable + # + # Because this function is called recursively, we also need to + # recognize pointer and reference types: + # int* Function() + # int& Function() + if (match.group(2) in typenames or + match.group(2) in ['char', 'char16_t', 'char32_t', 'wchar_t', 'bool', + 'short', 'int', 'long', 'signed', 'unsigned', + 'float', 'double', 'void', 'auto', '>', '*', '&']): + return True + + # If we see a close parenthesis, look for decltype on the other side. + # decltype would unambiguously identify a type, anything else is + # probably a parenthesized expression and not a type. + if match.group(2) == ')': + return IsDecltype( + clean_lines, linenum, len(match.group(1)) + len(match.group(2)) - 1) + + # Check for casts and cv-qualifiers. + # match.group(1) remainder + # -------------- --------- + # const_cast< type&& + # const type&& + # type const&& + if Search(r'\b(?:const_cast\s*<|static_cast\s*<|dynamic_cast\s*<|' + r'reinterpret_cast\s*<|\w+\s)\s*$', + match.group(1)): + return True + + # Look for a preceding symbol that might help differentiate the context. + # These are the cases that would be ambiguous: + # match.group(1) remainder + # -------------- --------- + # Call ( expression && + # Declaration ( type&& + # sizeof ( type&& + # if ( expression && + # while ( expression && + # for ( type&& + # for( ; expression && + # statement ; type&& + # block { type&& + # constructor { expression && + start = linenum + line = match.group(1) + match_symbol = None + while start >= 0: + # We want to skip over identifiers and commas to get to a symbol. + # Commas are skipped so that we can find the opening parenthesis + # for function parameter lists. 
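+    # For example (illustrative): given '  Foo(bar, baz' the symbol found
+    # is '(' in group 2, with '  Foo' left in group 1.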
+    match_symbol = Match(r'^(.*)([^\w\s,])[\w\s,]*$', line)
+    if match_symbol:
+      break
+    start -= 1
+    line = clean_lines.elided[start]
+
+  if not match_symbol:
+    # Probably the first statement in the file is an rvalue reference
+    return True
+
+  if match_symbol.group(2) == '}':
+    # Found closing brace, probably an indication of this:
+    #   block{} type&&
+    return True
+
+  if match_symbol.group(2) == ';':
+    # Found semicolon, probably one of these:
+    #   for(; expression &&
+    #   statement; type&&
+
+    # Look for the previous 'for(' in the previous lines.
+    before_text = match_symbol.group(1)
+    for i in xrange(start - 1, max(start - 6, 0), -1):
+      before_text = clean_lines.elided[i] + before_text
+    if Search(r'for\s*\([^{};]*$', before_text):
+      # This is the condition inside a for-loop
+      return False
+
+    # Did not find a for-init-statement before this semicolon, so this
+    # is probably a new statement and not a condition.
+    return True
+
+  if match_symbol.group(2) == '{':
+    # Found opening brace, probably one of these:
+    #   block{ type&& = ... ; }
+    #   constructor{ expression && expression }
+
+    # Look for a closing brace or a semicolon.  If we see a semicolon
+    # first, this is probably an rvalue reference.
+    line = clean_lines.elided[start][0:len(match_symbol.group(1)) + 1]
+    end = start
+    depth = 1
+    while True:
+      for ch in line:
+        if ch == ';':
+          return True
+        elif ch == '{':
+          depth += 1
+        elif ch == '}':
+          depth -= 1
+          if depth == 0:
+            return False
+      end += 1
+      if end >= clean_lines.NumLines():
+        break
+      line = clean_lines.elided[end]
+    # Incomplete program?
+    return False
+
+  if match_symbol.group(2) == '(':
+    # Opening parenthesis.  Need to check what's to the left of the
+    # parenthesis.  Look back one extra line for additional context.
+    before_text = match_symbol.group(1)
+    if linenum > 1:
+      before_text = clean_lines.elided[linenum - 1] + before_text
+    before_text = match_symbol.group(1)
+
+    # Patterns that are likely to be types:
+    #   [](type&&
+    #   for (type&&
+    #   sizeof(type&&
+    #   operator=(type&&
+    #
+    if Search(r'(?:\]|\bfor|\bsizeof|\boperator\s*\S+\s*)\s*$', before_text):
+      return True
+
+    # Patterns that are likely to be expressions:
+    #   if (expression &&
+    #   while (expression &&
+    #   : initializer(expression &&
+    #   , initializer(expression &&
+    #   ( FunctionCall(expression &&
+    #   + FunctionCall(expression &&
+    #   + (expression &&
+    #
+    # The last '+' represents operators such as '+' and '-'.
+    if Search(r'(?:\bif|\bwhile|[-+=%^(<!?:,&*]|<<|>>|&&|\|\|)\s*$',
+              before_text):
+      return False
+
+    # Something else.  Check that tokens to the left look like
+    #   return_type function_name
+    match_func = Match(r'^(.*\S.*)\s+\w(?:\w|::)*(?:<[^<>]*>)?\s*$',
+                       match_symbol.group(1))
+    if match_func:
+      # Check for constructors, which don't have return types.
+      if Search(r'\b(?:explicit|inline)$', match_func.group(1)):
+        return True
+      implicit_constructor = Match(r'\s*(\w+)\((?:const\s+)?(\w+)', prefix)
+      if (implicit_constructor and
+          implicit_constructor.group(1) == implicit_constructor.group(2)):
+        return True
+      return IsRValueType(typenames, clean_lines, nesting_state, linenum,
+                          len(match_func.group(1)))
+
+    # Nothing before the function name.  If this is inside a block scope,
+    # this is probably a function call.
+    return not (nesting_state.previous_stack_top and
+                nesting_state.previous_stack_top.IsBlockInfo())
+
+  if match_symbol.group(2) == '>':
+    # Possibly a closing bracket, check that what's on the other side
+    # looks like the start of a template.
+    return IsTemplateParameterList(
+        clean_lines, start, len(match_symbol.group(1)))
+
+  # Some other symbol, usually something like "a=b&&c".  This is most
+  # likely not a type.
+ return False + + +def IsDeletedOrDefault(clean_lines, linenum): + """Check if current constructor or operator is deleted or default. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + Returns: + True if this is a deleted or default constructor. + """ + open_paren = clean_lines.elided[linenum].find('(') + if open_paren < 0: + return False + (close_line, _, close_paren) = CloseExpression( + clean_lines, linenum, open_paren) + if close_paren < 0: + return False + return Match(r'\s*=\s*(?:delete|default)\b', close_line[close_paren:]) + + +def IsRValueAllowed(clean_lines, linenum, typenames): + """Check if RValue reference is allowed on a particular line. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + typenames: set of type names from template-argument-list. + Returns: + True if line is within the region where RValue references are allowed. + """ + # Allow region marked by PUSH/POP macros + for i in xrange(linenum, 0, -1): + line = clean_lines.elided[i] + if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line): + if not line.endswith('PUSH'): + return False + for j in xrange(linenum, clean_lines.NumLines(), 1): + line = clean_lines.elided[j] + if Match(r'GOOGLE_ALLOW_RVALUE_REFERENCES_(?:PUSH|POP)', line): + return line.endswith('POP') + + # Allow operator= + line = clean_lines.elided[linenum] + if Search(r'\boperator\s*=\s*\(', line): + return IsDeletedOrDefault(clean_lines, linenum) + + # Allow constructors + match = Match(r'\s*(?:[\w<>]+::)*([\w<>]+)\s*::\s*([\w<>]+)\s*\(', line) + if match and match.group(1) == match.group(2): + return IsDeletedOrDefault(clean_lines, linenum) + if Search(r'\b(?:explicit|inline)\s+[\w<>]+\s*\(', line): + return IsDeletedOrDefault(clean_lines, linenum) + + if Match(r'\s*[\w<>]+\s*\(', line): + previous_line = 'ReturnType' + if linenum > 0: + previous_line = clean_lines.elided[linenum - 1] + if Match(r'^\s*$', previous_line) or Search(r'[{}:;]\s*$', previous_line): + return IsDeletedOrDefault(clean_lines, linenum) + + # Reject types not mentioned in template-argument-list + while line: + match = Match(r'^.*?(\w+)\s*&&(.*)$', line) + if not match: + break + if match.group(1) not in typenames: + return False + line = match.group(2) + + # All RValue types that were in template-argument-list should have + # been removed by now. Those were allowed, assuming that they will + # be forwarded. + # + # If there are no remaining RValue types left (i.e. types that were + # not found in template-argument-list), flag those as not allowed. + return line.find('&&') < 0 + + +def GetTemplateArgs(clean_lines, linenum): + """Find list of template arguments associated with this function declaration. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: Line number containing the start of the function declaration, + usually one line after the end of the template-argument-list. + Returns: + Set of type names, or empty set if this does not appear to have + any template parameters. 
+ """ + # Find start of function + func_line = linenum + while func_line > 0: + line = clean_lines.elided[func_line] + if Match(r'^\s*$', line): + return set() + if line.find('(') >= 0: + break + func_line -= 1 + if func_line == 0: + return set() + + # Collapse template-argument-list into a single string + argument_list = '' + match = Match(r'^(\s*template\s*)<', clean_lines.elided[func_line]) + if match: + # template-argument-list on the same line as function name + start_col = len(match.group(1)) + _, end_line, end_col = CloseExpression(clean_lines, func_line, start_col) + if end_col > -1 and end_line == func_line: + start_col += 1 # Skip the opening bracket + argument_list = clean_lines.elided[func_line][start_col:end_col] + + elif func_line > 1: + # template-argument-list one line before function name + match = Match(r'^(.*)>\s*$', clean_lines.elided[func_line - 1]) + if match: + end_col = len(match.group(1)) + _, start_line, start_col = ReverseCloseExpression( + clean_lines, func_line - 1, end_col) + if start_col > -1: + start_col += 1 # Skip the opening bracket + while start_line < func_line - 1: + argument_list += clean_lines.elided[start_line][start_col:] + start_col = 0 + start_line += 1 + argument_list += clean_lines.elided[func_line - 1][start_col:end_col] + + if not argument_list: + return set() + + # Extract type names + typenames = set() + while True: + match = Match(r'^[,\s]*(?:typename|class)(?:\.\.\.)?\s+(\w+)(.*)$', + argument_list) + if not match: + break + typenames.add(match.group(1)) + argument_list = match.group(2) + return typenames + + +def CheckRValueReference(filename, clean_lines, linenum, nesting_state, error): + """Check for rvalue references. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + # Find lines missing spaces around &&. + # TODO(unknown): currently we don't check for rvalue references + # with spaces surrounding the && to avoid false positives with + # boolean expressions. + line = clean_lines.elided[linenum] + match = Match(r'^(.*\S)&&', line) + if not match: + match = Match(r'(.*)&&\S', line) + if (not match) or '(&&)' in line or Search(r'\boperator\s*$', match.group(1)): + return + + # Either poorly formed && or an rvalue reference, check the context + # to get a more accurate error message. Mostly we want to determine + # if what's to the left of "&&" is a type or not. + typenames = GetTemplateArgs(clean_lines, linenum) + and_pos = len(match.group(1)) + if IsRValueType(typenames, clean_lines, nesting_state, linenum, and_pos): + if not IsRValueAllowed(clean_lines, linenum, typenames): + error(filename, linenum, 'build/c++11', 3, + 'RValue references are an unapproved C++ feature.') + else: + error(filename, linenum, 'whitespace/operators', 3, + 'Missing spaces around &&') + + +def CheckSectionSpacing(filename, clean_lines, class_info, linenum, error): + """Checks for additional blank line issues related to sections. + + Currently the only thing checked here is blank line before protected/private. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + class_info: A _ClassInfo objects. + linenum: The number of the line to check. + error: The function to call with any errors found. 
+ """ + # Skip checks if the class is small, where small means 25 lines or less. + # 25 lines seems like a good cutoff since that's the usual height of + # terminals, and any class that can't fit in one screen can't really + # be considered "small". + # + # Also skip checks if we are on the first line. This accounts for + # classes that look like + # class Foo { public: ... }; + # + # If we didn't find the end of the class, last_line would be zero, + # and the check will be skipped by the first condition. + if (class_info.last_line - class_info.starting_linenum <= 24 or + linenum <= class_info.starting_linenum): + return + + matched = Match(r'\s*(public|protected|private):', clean_lines.lines[linenum]) + if matched: + # Issue warning if the line before public/protected/private was + # not a blank line, but don't do this if the previous line contains + # "class" or "struct". This can happen two ways: + # - We are at the beginning of the class. + # - We are forward-declaring an inner class that is semantically + # private, but needed to be public for implementation reasons. + # Also ignores cases where the previous line ends with a backslash as can be + # common when defining classes in C macros. + prev_line = clean_lines.lines[linenum - 1] + if (not IsBlankLine(prev_line) and + not Search(r'\b(class|struct)\b', prev_line) and + not Search(r'\\$', prev_line)): + # Try a bit harder to find the beginning of the class. This is to + # account for multi-line base-specifier lists, e.g.: + # class Derived + # : public Base { + end_class_head = class_info.starting_linenum + for i in range(class_info.starting_linenum, linenum): + if Search(r'\{\s*$', clean_lines.lines[i]): + end_class_head = i + break + if end_class_head < linenum - 1: + error(filename, linenum, 'whitespace/blank_line', 3, + '"%s:" should be preceded by a blank line' % matched.group(1)) + + +def GetPreviousNonBlankLine(clean_lines, linenum): + """Return the most recent non-blank line and its line number. + + Args: + clean_lines: A CleansedLines instance containing the file contents. + linenum: The number of the line to check. + + Returns: + A tuple with two elements. The first element is the contents of the last + non-blank line before the current line, or the empty string if this is the + first non-blank line. The second is the line number of that line, or -1 + if this is the first non-blank line. + """ + + prevlinenum = linenum - 1 + while prevlinenum >= 0: + prevline = clean_lines.elided[prevlinenum] + if not IsBlankLine(prevline): # if not a blank line... + return (prevline, prevlinenum) + prevlinenum -= 1 + return ('', -1) + + +def CheckBraces(filename, clean_lines, linenum, error): + """Looks for misplaced braces (e.g. at the end of line). + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] # get rid of comments and strings + + if Match(r'\s*{\s*$', line): + # We allow an open brace to start a line in the case where someone is using + # braces in a block to explicitly create a new scope, which is commonly used + # to control the lifetime of stack-allocated variables. Braces are also + # used for brace initializers inside function calls. 
We don't detect this + # perfectly: we just don't complain if the last non-whitespace character on + # the previous non-blank line is ',', ';', ':', '(', '{', or '}', or if the + # previous line starts a preprocessor block. + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if (not Search(r'[,;:}{(]\s*$', prevline) and + not Match(r'\s*#', prevline)): + error(filename, linenum, 'whitespace/braces', 4, + '{ should almost always be at the end of the previous line') + + # An else clause should be on the same line as the preceding closing brace. + if Match(r'\s*else\b\s*(?:if\b|\{|$)', line): + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if Match(r'\s*}\s*$', prevline): + error(filename, linenum, 'whitespace/newline', 4, + 'An else should appear on the same line as the preceding }') + + # If braces come on one side of an else, they should be on both. + # However, we have to worry about "else if" that spans multiple lines! + if Search(r'else if\s*\(', line): # could be multi-line if + brace_on_left = bool(Search(r'}\s*else if\s*\(', line)) + # find the ( after the if + pos = line.find('else if') + pos = line.find('(', pos) + if pos > 0: + (endline, _, endpos) = CloseExpression(clean_lines, linenum, pos) + brace_on_right = endline[endpos:].find('{') != -1 + if brace_on_left != brace_on_right: # must be brace after if + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + elif Search(r'}\s*else[^{]*$', line) or Match(r'[^}]*else\s*{', line): + error(filename, linenum, 'readability/braces', 5, + 'If an else has a brace on one side, it should have it on both') + + # Likewise, an else should never have the else clause on the same line + if Search(r'\belse [^\s{]', line) and not Search(r'\belse if\b', line): + error(filename, linenum, 'whitespace/newline', 4, + 'Else clause should never be on same line as else (use 2 lines)') + + # In the same way, a do/while should never be on one line + if Match(r'\s*do [^\s{]', line): + error(filename, linenum, 'whitespace/newline', 4, + 'do/while clauses should not be on a single line') + + # Check single-line if/else bodies. The style guide says 'curly braces are not + # required for single-line statements'. We additionally allow multi-line, + # single statements, but we reject anything with more than one semicolon in + # it. This means that the first semicolon after the if should be at the end of + # its line, and the line after that should have an indent level equal to or + # lower than the if. We also check for ambiguous if/else nesting without + # braces. + if_else_match = Search(r'\b(if\s*\(|else\b)', line) + if if_else_match and not Match(r'\s*#', line): + if_indent = GetIndentLevel(line) + endline, endlinenum, endpos = line, linenum, if_else_match.end() + if_match = Search(r'\bif\s*\(', line) + if if_match: + # This could be a multiline if condition, so find the end first. + pos = if_match.end() - 1 + (endline, endlinenum, endpos) = CloseExpression(clean_lines, linenum, pos) + # Check for an opening brace, either directly after the if or on the next + # line. If found, this isn't a single-statement conditional. 
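+    # For instance (illustrative), "if (x) Foo(); Bar();" is flagged below,
+    # while "if (x) { Foo(); Bar(); }" is not, since the opening brace is
+    # found right after the condition.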
+ if (not Match(r'\s*{', endline[endpos:]) + and not (Match(r'\s*$', endline[endpos:]) + and endlinenum < (len(clean_lines.elided) - 1) + and Match(r'\s*{', clean_lines.elided[endlinenum + 1]))): + while (endlinenum < len(clean_lines.elided) + and ';' not in clean_lines.elided[endlinenum][endpos:]): + endlinenum += 1 + endpos = 0 + if endlinenum < len(clean_lines.elided): + endline = clean_lines.elided[endlinenum] + # We allow a mix of whitespace and closing braces (e.g. for one-liner + # methods) and a single \ after the semicolon (for macros) + endpos = endline.find(';') + if not Match(r';[\s}]*(\\?)$', endline[endpos:]): + # Semicolon isn't the last character, there's something trailing. + # Output a warning if the semicolon is not contained inside + # a lambda expression. + if not Match(r'^[^{};]*\[[^\[\]]*\][^{}]*\{[^{}]*\}\s*\)*[;,]\s*$', + endline): + error(filename, linenum, 'readability/braces', 4, + 'If/else bodies with multiple statements require braces') + elif endlinenum < len(clean_lines.elided) - 1: + # Make sure the next line is dedented + next_line = clean_lines.elided[endlinenum + 1] + next_indent = GetIndentLevel(next_line) + # With ambiguous nested if statements, this will error out on the + # if that *doesn't* match the else, regardless of whether it's the + # inner one or outer one. + if (if_match and Match(r'\s*else\b', next_line) + and next_indent != if_indent): + error(filename, linenum, 'readability/braces', 4, + 'Else clause should be indented at the same level as if. ' + 'Ambiguous nested if/else chains require braces.') + elif next_indent > if_indent: + error(filename, linenum, 'readability/braces', 4, + 'If/else bodies with multiple statements require braces') + + +def CheckTrailingSemicolon(filename, clean_lines, linenum, error): + """Looks for redundant trailing semicolon. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + line = clean_lines.elided[linenum] + + # Block bodies should not be followed by a semicolon. Due to C++11 + # brace initialization, there are more places where semicolons are + # required than not, so we use a whitelist approach to check these + # rather than a blacklist. These are the places where "};" should + # be replaced by just "}": + # 1. Some flavor of block following closing parenthesis: + # for (;;) {}; + # while (...) {}; + # switch (...) {}; + # Function(...) {}; + # if (...) {}; + # if (...) else if (...) {}; + # + # 2. else block: + # if (...) else {}; + # + # 3. const member function: + # Function(...) const {}; + # + # 4. Block following some statement: + # x = 42; + # {}; + # + # 5. Block at the beginning of a function: + # Function(...) { + # {}; + # } + # + # Note that naively checking for the preceding "{" will also match + # braces inside multi-dimensional arrays, but this is fine since + # that expression will not contain semicolons. + # + # 6. Block following another block: + # while (true) {} + # {}; + # + # 7. End of namespaces: + # namespace {}; + # + # These semicolons seems far more common than other kinds of + # redundant semicolons, possibly due to people converting classes + # to namespaces. For now we do not warn for this case. + # + # Try matching case 1 first. + match = Match(r'^(.*\)\s*)\{', line) + if match: + # Matched closing parenthesis (case 1). 
Check the token before the + # matching opening parenthesis, and don't warn if it looks like a + # macro. This avoids these false positives: + # - macro that defines a base class + # - multi-line macro that defines a base class + # - macro that defines the whole class-head + # + # But we still issue warnings for macros that we know are safe to + # warn, specifically: + # - TEST, TEST_F, TEST_P, MATCHER, MATCHER_P + # - TYPED_TEST + # - INTERFACE_DEF + # - EXCLUSIVE_LOCKS_REQUIRED, SHARED_LOCKS_REQUIRED, LOCKS_EXCLUDED: + # + # We implement a whitelist of safe macros instead of a blacklist of + # unsafe macros, even though the latter appears less frequently in + # google code and would have been easier to implement. This is because + # the downside for getting the whitelist wrong means some extra + # semicolons, while the downside for getting the blacklist wrong + # would result in compile errors. + # + # In addition to macros, we also don't want to warn on + # - Compound literals + # - Lambdas + # - alignas specifier with anonymous structs: + closing_brace_pos = match.group(1).rfind(')') + opening_parenthesis = ReverseCloseExpression( + clean_lines, linenum, closing_brace_pos) + if opening_parenthesis[2] > -1: + line_prefix = opening_parenthesis[0][0:opening_parenthesis[2]] + macro = Search(r'\b([A-Z_]+)\s*$', line_prefix) + func = Match(r'^(.*\])\s*$', line_prefix) + if ((macro and + macro.group(1) not in ( + 'TEST', 'TEST_F', 'MATCHER', 'MATCHER_P', 'TYPED_TEST', + 'EXCLUSIVE_LOCKS_REQUIRED', 'SHARED_LOCKS_REQUIRED', + 'LOCKS_EXCLUDED', 'INTERFACE_DEF')) or + (func and not Search(r'\boperator\s*\[\s*\]', func.group(1))) or + Search(r'\b(?:struct|union)\s+alignas\s*$', line_prefix) or + Search(r'\s+=\s*$', line_prefix)): + match = None + if (match and + opening_parenthesis[1] > 1 and + Search(r'\]\s*$', clean_lines.elided[opening_parenthesis[1] - 1])): + # Multi-line lambda-expression + match = None + + else: + # Try matching cases 2-3. + match = Match(r'^(.*(?:else|\)\s*const)\s*)\{', line) + if not match: + # Try matching cases 4-6. These are always matched on separate lines. + # + # Note that we can't simply concatenate the previous line to the + # current line and do a single match, otherwise we may output + # duplicate warnings for the blank line case: + # if (cond) { + # // blank line + # } + prevline = GetPreviousNonBlankLine(clean_lines, linenum)[0] + if prevline and Search(r'[;{}]\s*$', prevline): + match = Match(r'^(\s*)\{', line) + + # Check matching closing brace + if match: + (endline, endlinenum, endpos) = CloseExpression( + clean_lines, linenum, len(match.group(1))) + if endpos > -1 and Match(r'^\s*;', endline[endpos:]): + # Current {} pair is eligible for semicolon check, and we have found + # the redundant semicolon, output warning here. + # + # Note: because we are scanning forward for opening braces, and + # outputting warnings for the matching closing brace, if there are + # nested blocks with trailing semicolons, we will get the error + # messages in reversed order. + error(filename, endlinenum, 'readability/braces', 4, + "You don't need a ; after a }") + + +def CheckEmptyBlockBody(filename, clean_lines, linenum, error): + """Look for empty loop/conditional body with only a single semicolon. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Search for loop keywords at the beginning of the line. 
Because only + # whitespaces are allowed before the keywords, this will also ignore most + # do-while-loops, since those lines should start with closing brace. + # + # We also check "if" blocks here, since an empty conditional block + # is likely an error. + line = clean_lines.elided[linenum] + matched = Match(r'\s*(for|while|if)\s*\(', line) + if matched: + # Find the end of the conditional expression + (end_line, end_linenum, end_pos) = CloseExpression( + clean_lines, linenum, line.find('(')) + + # Output warning if what follows the condition expression is a semicolon. + # No warning for all other cases, including whitespace or newline, since we + # have a separate check for semicolons preceded by whitespace. + if end_pos >= 0 and Match(r';', end_line[end_pos:]): + if matched.group(1) == 'if': + error(filename, end_linenum, 'whitespace/empty_conditional_body', 5, + 'Empty conditional bodies should use {}') + else: + error(filename, end_linenum, 'whitespace/empty_loop_body', 5, + 'Empty loop bodies should use {} or continue') + + +def FindCheckMacro(line): + """Find a replaceable CHECK-like macro. + + Args: + line: line to search on. + Returns: + (macro name, start position), or (None, -1) if no replaceable + macro is found. + """ + for macro in _CHECK_MACROS: + i = line.find(macro) + if i >= 0: + # Find opening parenthesis. Do a regular expression match here + # to make sure that we are matching the expected CHECK macro, as + # opposed to some other macro that happens to contain the CHECK + # substring. + matched = Match(r'^(.*\b' + macro + r'\s*)\(', line) + if not matched: + continue + return (macro, len(matched.group(1))) + return (None, -1) + + +def CheckCheck(filename, clean_lines, linenum, error): + """Checks the use of CHECK and EXPECT macros. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + + # Decide the set of replacement macros that should be suggested + lines = clean_lines.elided + (check_macro, start_pos) = FindCheckMacro(lines[linenum]) + if not check_macro: + return + + # Find end of the boolean expression by matching parentheses + (last_line, end_line, end_pos) = CloseExpression( + clean_lines, linenum, start_pos) + if end_pos < 0: + return + + # If the check macro is followed by something other than a + # semicolon, assume users will log their own custom error messages + # and don't suggest any replacements. + if not Match(r'\s*;', last_line[end_pos:]): + return + + if linenum == end_line: + expression = lines[linenum][start_pos + 1:end_pos - 1] + else: + expression = lines[linenum][start_pos + 1:] + for i in xrange(linenum + 1, end_line): + expression += lines[i] + expression += last_line[0:end_pos - 1] + + # Parse expression so that we can take parentheses into account. + # This avoids false positives for inputs like "CHECK((a < 4) == b)", + # which is not replaceable by CHECK_LE. + lhs = '' + rhs = '' + operator = None + while expression: + matched = Match(r'^\s*(<<|<<=|>>|>>=|->\*|->|&&|\|\||' + r'==|!=|>=|>|<=|<|\()(.*)$', expression) + if matched: + token = matched.group(1) + if token == '(': + # Parenthesized operand + expression = matched.group(2) + (end, _) = FindEndOfExpressionInLine(expression, 0, ['(']) + if end < 0: + return # Unmatched parenthesis + lhs += '(' + expression[0:end] + expression = expression[end:] + elif token in ('&&', '||'): + # Logical and/or operators. 
This means the expression + # contains more than one term, for example: + # CHECK(42 < a && a < b); + # + # These are not replaceable with CHECK_LE, so bail out early. + return + elif token in ('<<', '<<=', '>>', '>>=', '->*', '->'): + # Non-relational operator + lhs += token + expression = matched.group(2) + else: + # Relational operator + operator = token + rhs = matched.group(2) + break + else: + # Unparenthesized operand. Instead of appending to lhs one character + # at a time, we do another regular expression match to consume several + # characters at once if possible. Trivial benchmark shows that this + # is more efficient when the operands are longer than a single + # character, which is generally the case. + matched = Match(r'^([^-=!<>()&|]+)(.*)$', expression) + if not matched: + matched = Match(r'^(\s*\S)(.*)$', expression) + if not matched: + break + lhs += matched.group(1) + expression = matched.group(2) + + # Only apply checks if we got all parts of the boolean expression + if not (lhs and operator and rhs): + return + + # Check that rhs do not contain logical operators. We already know + # that lhs is fine since the loop above parses out && and ||. + if rhs.find('&&') > -1 or rhs.find('||') > -1: + return + + # At least one of the operands must be a constant literal. This is + # to avoid suggesting replacements for unprintable things like + # CHECK(variable != iterator) + # + # The following pattern matches decimal, hex integers, strings, and + # characters (in that order). + lhs = lhs.strip() + rhs = rhs.strip() + match_constant = r'^([-+]?(\d+|0[xX][0-9a-fA-F]+)[lLuU]{0,3}|".*"|\'.*\')$' + if Match(match_constant, lhs) or Match(match_constant, rhs): + # Note: since we know both lhs and rhs, we can provide a more + # descriptive error message like: + # Consider using CHECK_EQ(x, 42) instead of CHECK(x == 42) + # Instead of: + # Consider using CHECK_EQ instead of CHECK(a == b) + # + # We are still keeping the less descriptive message because if lhs + # or rhs gets long, the error message might become unreadable. + error(filename, linenum, 'readability/check', 2, + 'Consider using %s instead of %s(a %s b)' % ( + _CHECK_REPLACEMENT[check_macro][operator], + check_macro, operator)) + + +def CheckAltTokens(filename, clean_lines, linenum, error): + """Check alternative keywords being used in boolean expressions. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Avoid preprocessor lines + if Match(r'^\s*#', line): + return + + # Last ditch effort to avoid multi-line comments. This will not help + # if the comment started before the current line or ended after the + # current line, but it catches most of the false positives. At least, + # it provides a way to workaround this warning for people who use + # multi-line comments in preprocessor macros. + # + # TODO(unknown): remove this once cpplint has better support for + # multi-line comments. + if line.find('/*') >= 0 or line.find('*/') >= 0: + return + + for match in _ALT_TOKEN_REPLACEMENT_PATTERN.finditer(line): + error(filename, linenum, 'readability/alt_tokens', 2, + 'Use operator %s instead of %s' % ( + _ALT_TOKEN_REPLACEMENT[match.group(1)], match.group(1))) + + +def GetLineWidth(line): + """Determines the width of the line in column positions. + + Args: + line: A string, which may be a Unicode string. 
+ + Returns: + The width of the line in column positions, accounting for Unicode + combining characters and wide characters. + """ + if isinstance(line, unicode): + width = 0 + for uc in unicodedata.normalize('NFC', line): + if unicodedata.east_asian_width(uc) in ('W', 'F'): + width += 2 + elif not unicodedata.combining(uc): + width += 1 + return width + else: + return len(line) + + +def CheckStyle(filename, clean_lines, linenum, file_extension, nesting_state, + error): + """Checks rules from the 'C++ style rules' section of cppguide.html. + + Most of these rules are hard to test (naming, comment style), but we + do what we can. In particular we check for 2-space indents, line lengths, + tab usage, spaces inside code, etc. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + + # Don't use "elided" lines here, otherwise we can't check commented lines. + # Don't want to use "raw" either, because we don't want to check inside C++11 + # raw strings, + raw_lines = clean_lines.lines_without_raw_strings + line = raw_lines[linenum] + + if line.find('\t') != -1: + error(filename, linenum, 'whitespace/tab', 1, + 'Tab found; better to use spaces') + + # One or three blank spaces at the beginning of the line is weird; it's + # hard to reconcile that with 2-space indents. + # NOTE: here are the conditions rob pike used for his tests. Mine aren't + # as sophisticated, but it may be worth becoming so: RLENGTH==initial_spaces + # if(RLENGTH > 20) complain = 0; + # if(match($0, " +(error|private|public|protected):")) complain = 0; + # if(match(prev, "&& *$")) complain = 0; + # if(match(prev, "\\|\\| *$")) complain = 0; + # if(match(prev, "[\",=><] *$")) complain = 0; + # if(match($0, " <<")) complain = 0; + # if(match(prev, " +for \\(")) complain = 0; + # if(prevodd && match(prevprev, " +for \\(")) complain = 0; + scope_or_label_pattern = r'\s*\w+\s*:\s*\\?$' + classinfo = nesting_state.InnermostClass() + initial_spaces = 0 + cleansed_line = clean_lines.elided[linenum] + while initial_spaces < len(line) and line[initial_spaces] == ' ': + initial_spaces += 1 + if line and line[-1].isspace(): + error(filename, linenum, 'whitespace/end_of_line', 4, + 'Line ends in whitespace. Consider deleting these extra spaces.') + # There are certain situations we allow one space, notably for + # section labels, and also lines containing multi-line raw strings. + elif ((initial_spaces == 1 or initial_spaces == 3) and + not Match(scope_or_label_pattern, cleansed_line) and + not (clean_lines.raw_lines[linenum] != line and + Match(r'^\s*""', line))): + error(filename, linenum, 'whitespace/indent', 3, + 'Weird number of spaces at line-start. ' + 'Are you using a 2-space indent?') + + # Check if the line is a header guard. + is_header_guard = False + if file_extension == 'h': + cppvar = GetHeaderGuardCPPVariable(filename) + if (line.startswith('#ifndef %s' % cppvar) or + line.startswith('#define %s' % cppvar) or + line.startswith('#endif // %s' % cppvar)): + is_header_guard = True + # #include lines and header guards can be long, since there's no clean way to + # split them. + # + # URLs can be long too. It's possible to split these, but it makes them + # harder to cut&paste. 
+ # + # The "$Id:...$" comment may also get very long without it being the + # developers fault. + if (not line.startswith('#include') and not is_header_guard and + not Match(r'^\s*//.*http(s?)://\S*$', line) and + not Match(r'^// \$Id:.*#[0-9]+ \$$', line)): + line_width = GetLineWidth(line) + extended_length = int((_line_length * 1.25)) + if line_width > extended_length: + error(filename, linenum, 'whitespace/line_length', 4, + 'Lines should very rarely be longer than %i characters' % + extended_length) + elif line_width > _line_length: + error(filename, linenum, 'whitespace/line_length', 2, + 'Lines should be <= %i characters long' % _line_length) + + if (cleansed_line.count(';') > 1 and + # for loops are allowed two ;'s (and may run over two lines). + cleansed_line.find('for') == -1 and + (GetPreviousNonBlankLine(clean_lines, linenum)[0].find('for') == -1 or + GetPreviousNonBlankLine(clean_lines, linenum)[0].find(';') != -1) and + # It's ok to have many commands in a switch case that fits in 1 line + not ((cleansed_line.find('case ') != -1 or + cleansed_line.find('default:') != -1) and + cleansed_line.find('break;') != -1)): + error(filename, linenum, 'whitespace/newline', 0, + 'More than one command on the same line') + + # Some more style checks + CheckBraces(filename, clean_lines, linenum, error) + CheckTrailingSemicolon(filename, clean_lines, linenum, error) + CheckEmptyBlockBody(filename, clean_lines, linenum, error) + CheckAccess(filename, clean_lines, linenum, nesting_state, error) + CheckSpacing(filename, clean_lines, linenum, nesting_state, error) + CheckOperatorSpacing(filename, clean_lines, linenum, error) + CheckParenthesisSpacing(filename, clean_lines, linenum, error) + CheckCommaSpacing(filename, clean_lines, linenum, error) + CheckBracesSpacing(filename, clean_lines, linenum, error) + CheckSpacingForFunctionCall(filename, clean_lines, linenum, error) + CheckRValueReference(filename, clean_lines, linenum, nesting_state, error) + CheckCheck(filename, clean_lines, linenum, error) + CheckAltTokens(filename, clean_lines, linenum, error) + classinfo = nesting_state.InnermostClass() + if classinfo: + CheckSectionSpacing(filename, clean_lines, classinfo, linenum, error) + + +_RE_PATTERN_INCLUDE = re.compile(r'^\s*#\s*include\s*([<"])([^>"]*)[>"].*$') +# Matches the first component of a filename delimited by -s and _s. That is: +# _RE_FIRST_COMPONENT.match('foo').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo-bar_baz.cc').group(0) == 'foo' +# _RE_FIRST_COMPONENT.match('foo_bar-baz.cc').group(0) == 'foo' +_RE_FIRST_COMPONENT = re.compile(r'^[^-_.]+') + + +def _DropCommonSuffixes(filename): + """Drops common suffixes like _test.cc or -inl.h from filename. + + For example: + >>> _DropCommonSuffixes('foo/foo-inl.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/bar/foo.cc') + 'foo/bar/foo' + >>> _DropCommonSuffixes('foo/foo_internal.h') + 'foo/foo' + >>> _DropCommonSuffixes('foo/foo_unusualinternal.h') + 'foo/foo_unusualinternal' + + Args: + filename: The input filename. + + Returns: + The filename with the common suffix removed. 
+ """ + for suffix in ('test.cc', 'regtest.cc', 'unittest.cc', + 'inl.h', 'impl.h', 'internal.h'): + if (filename.endswith(suffix) and len(filename) > len(suffix) and + filename[-len(suffix) - 1] in ('-', '_')): + return filename[:-len(suffix) - 1] + return os.path.splitext(filename)[0] + + +def _IsTestFilename(filename): + """Determines if the given filename has a suffix that identifies it as a test. + + Args: + filename: The input filename. + + Returns: + True if 'filename' looks like a test, False otherwise. + """ + if (filename.endswith('_test.cc') or + filename.endswith('_unittest.cc') or + filename.endswith('_regtest.cc')): + return True + else: + return False + + +def _ClassifyInclude(fileinfo, include, is_system): + """Figures out what kind of header 'include' is. + + Args: + fileinfo: The current file cpplint is running over. A FileInfo instance. + include: The path to a #included file. + is_system: True if the #include used <> rather than "". + + Returns: + One of the _XXX_HEADER constants. + + For example: + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'stdio.h', True) + _C_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'string', True) + _CPP_SYS_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/foo.h', False) + _LIKELY_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo_unknown_extension.cc'), + ... 'bar/foo_other_ext.h', False) + _POSSIBLE_MY_HEADER + >>> _ClassifyInclude(FileInfo('foo/foo.cc'), 'foo/bar.h', False) + _OTHER_HEADER + """ + # This is a list of all standard c++ header files, except + # those already checked for above. + is_cpp_h = include in _CPP_HEADERS + + if is_system: + if is_cpp_h: + return _CPP_SYS_HEADER + else: + return _C_SYS_HEADER + + # If the target file and the include we're checking share a + # basename when we drop common extensions, and the include + # lives in . , then it's likely to be owned by the target file. + target_dir, target_base = ( + os.path.split(_DropCommonSuffixes(fileinfo.RepositoryName()))) + include_dir, include_base = os.path.split(_DropCommonSuffixes(include)) + if target_base == include_base and ( + include_dir == target_dir or + include_dir == os.path.normpath(target_dir + '/../public')): + return _LIKELY_MY_HEADER + + # If the target and include share some initial basename + # component, it's possible the target is implementing the + # include, so it's allowed to be first, but we'll never + # complain if it's not there. + target_first_component = _RE_FIRST_COMPONENT.match(target_base) + include_first_component = _RE_FIRST_COMPONENT.match(include_base) + if (target_first_component and include_first_component and + target_first_component.group(0) == + include_first_component.group(0)): + return _POSSIBLE_MY_HEADER + + return _OTHER_HEADER + + + +def CheckIncludeLine(filename, clean_lines, linenum, include_state, error): + """Check rules that are applicable to #include lines. + + Strings on #include lines are NOT removed from elided line, to make + certain tasks easier. However, to prevent false positives, checks + applicable to #include lines in CheckLanguage must be put here. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + include_state: An _IncludeState instance in which the headers are inserted. + error: The function to call with any errors found. 
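+
+ An illustrative example of the preferred order this check enforces (file
+ and header names below are made up):
+   #include "foo/foo.h"       // 1) the .cc file's own header
+   #include <sys/types.h>     // 2) C system headers
+   #include <vector>          // 3) C++ system headers
+   #include "base/logging.h"  // 4) other project headers
+ Within each section the includes are also expected to be in alphabetical
+ order.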
+ """ + fileinfo = FileInfo(filename) + line = clean_lines.lines[linenum] + + # "include" should use the new style "foo/bar.h" instead of just "bar.h" + # Only do this check if the included header follows google naming + # conventions. If not, assume that it's a 3rd party API that + # requires special include conventions. + # + # We also make an exception for Lua headers, which follow google + # naming convention but not the include convention. + match = Match(r'#include\s*"([^/]+\.h)"', line) + if match and not _THIRD_PARTY_HEADERS_PATTERN.match(match.group(1)): + error(filename, linenum, 'build/include', 4, + 'Include the directory when naming .h files') + + # we shouldn't include a file more than once. actually, there are a + # handful of instances where doing so is okay, but in general it's + # not. + match = _RE_PATTERN_INCLUDE.search(line) + if match: + include = match.group(2) + is_system = (match.group(1) == '<') + duplicate_line = include_state.FindHeader(include) + if duplicate_line >= 0: + error(filename, linenum, 'build/include', 4, + '"%s" already included at %s:%s' % + (include, filename, duplicate_line)) + elif (include.endswith('.cc') and + os.path.dirname(fileinfo.RepositoryName()) != os.path.dirname(include)): + error(filename, linenum, 'build/include', 4, + 'Do not include .cc files from other packages') + elif not _THIRD_PARTY_HEADERS_PATTERN.match(include): + include_state.include_list[-1].append((include, linenum)) + + # We want to ensure that headers appear in the right order: + # 1) for foo.cc, foo.h (preferred location) + # 2) c system files + # 3) cpp system files + # 4) for foo.cc, foo.h (deprecated location) + # 5) other google headers + # + # We classify each include statement as one of those 5 types + # using a number of techniques. The include_state object keeps + # track of the highest type seen, and complains if we see a + # lower type after that. + error_message = include_state.CheckNextIncludeOrder( + _ClassifyInclude(fileinfo, include, is_system)) + if error_message: + error(filename, linenum, 'build/include_order', 4, + '%s. Should be: %s.h, c system, c++ system, other.' % + (error_message, fileinfo.BaseName())) + canonical_include = include_state.CanonicalizeAlphabeticalOrder(include) + if not include_state.IsInAlphabeticalOrder( + clean_lines, linenum, canonical_include): + error(filename, linenum, 'build/include_alpha', 4, + 'Include "%s" not in alphabetical order' % include) + include_state.SetLastHeader(canonical_include) + + + +def _GetTextInside(text, start_pattern): + r"""Retrieves all the text between matching open and close parentheses. + + Given a string of lines and a regular expression string, retrieve all the text + following the expression and between opening punctuation symbols like + (, [, or {, and the matching close-punctuation symbol. This properly nested + occurrences of the punctuations, so for the text like + printf(a(), b(c())); + a call to _GetTextInside(text, r'printf\(') will return 'a(), b(c())'. + start_pattern must match string having an open punctuation symbol at the end. + + Args: + text: The lines to extract text. Its comments and strings must be elided. + It can be single line and can span multiple lines. + start_pattern: The regexp string indicating where to start extracting + the text. + Returns: + The extracted text. + None if either the opening string or ending punctuation could not be found. 
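+
+ A further illustrative example (made-up input): _GetTextInside('foo(bar(baz), qux)',
+ r'foo\(') returns 'bar(baz), qux', while unbalanced input such as 'foo(bar'
+ returns None.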
+ """ + # TODO(unknown): Audit cpplint.py to see what places could be profitably + # rewritten to use _GetTextInside (and use inferior regexp matching today). + + # Give opening punctuations to get the matching close-punctuations. + matching_punctuation = {'(': ')', '{': '}', '[': ']'} + closing_punctuation = set(matching_punctuation.itervalues()) + + # Find the position to start extracting text. + match = re.search(start_pattern, text, re.M) + if not match: # start_pattern not found in text. + return None + start_position = match.end(0) + + assert start_position > 0, ( + 'start_pattern must ends with an opening punctuation.') + assert text[start_position - 1] in matching_punctuation, ( + 'start_pattern must ends with an opening punctuation.') + # Stack of closing punctuations we expect to have in text after position. + punctuation_stack = [matching_punctuation[text[start_position - 1]]] + position = start_position + while punctuation_stack and position < len(text): + if text[position] == punctuation_stack[-1]: + punctuation_stack.pop() + elif text[position] in closing_punctuation: + # A closing punctuation without matching opening punctuations. + return None + elif text[position] in matching_punctuation: + punctuation_stack.append(matching_punctuation[text[position]]) + position += 1 + if punctuation_stack: + # Opening punctuations left without matching close-punctuations. + return None + # punctuations match. + return text[start_position:position - 1] + + +# Patterns for matching call-by-reference parameters. +# +# Supports nested templates up to 2 levels deep using this messy pattern: +# < (?: < (?: < [^<>]* +# > +# | [^<>] )* +# > +# | [^<>] )* +# > +_RE_PATTERN_IDENT = r'[_a-zA-Z]\w*' # =~ [[:alpha:]][[:alnum:]]* +_RE_PATTERN_TYPE = ( + r'(?:const\s+)?(?:typename\s+|class\s+|struct\s+|union\s+|enum\s+)?' + r'(?:\w|' + r'\s*<(?:<(?:<[^<>]*>|[^<>])*>|[^<>])*>|' + r'::)+') +# A call-by-reference parameter ends with '& identifier'. +_RE_PATTERN_REF_PARAM = re.compile( + r'(' + _RE_PATTERN_TYPE + r'(?:\s*(?:\bconst\b|[*]))*\s*' + r'&\s*' + _RE_PATTERN_IDENT + r')\s*(?:=[^,()]+)?[,)]') +# A call-by-const-reference parameter either ends with 'const& identifier' +# or looks like 'const type& identifier' when 'type' is atomic. +_RE_PATTERN_CONST_REF_PARAM = ( + r'(?:.*\s*\bconst\s*&\s*' + _RE_PATTERN_IDENT + + r'|const\s+' + _RE_PATTERN_TYPE + r'\s*&\s*' + _RE_PATTERN_IDENT + r')') + + +def CheckLanguage(filename, clean_lines, linenum, file_extension, + include_state, nesting_state, error): + """Checks rules from the 'C++ language rules' section of cppguide.html. + + Some of these rules are hard to test (function overloading, using + uint32 inappropriately), but we do the best we can. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + file_extension: The extension (without the dot) of the filename. + include_state: An _IncludeState instance in which the headers are inserted. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + # If the line is empty or consists of entirely a comment, no need to + # check it. 
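+ #
+ # A few illustrative examples (made-up lines) of what the checks below
+ # would flag:
+ #   short port = 80;      -> runtime/int: use "unsigned short" for ports
+ #   long long counter;    -> runtime/int: use int16/int64/etc instead
+ #   using namespace std;  -> build/namespaces: prefer using-declarations
+ #   char buf[user_len];   -> runtime/arrays: variable-length array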
+ line = clean_lines.elided[linenum] + if not line: + return + + match = _RE_PATTERN_INCLUDE.search(line) + if match: + CheckIncludeLine(filename, clean_lines, linenum, include_state, error) + return + + # Reset include state across preprocessor directives. This is meant + # to silence warnings for conditional includes. + match = Match(r'^\s*#\s*(if|ifdef|ifndef|elif|else|endif)\b', line) + if match: + include_state.ResetSection(match.group(1)) + + # Make Windows paths like Unix. + fullname = os.path.abspath(filename).replace('\\', '/') + + # Perform other checks now that we are sure that this is not an include line + CheckCasts(filename, clean_lines, linenum, error) + CheckGlobalStatic(filename, clean_lines, linenum, error) + CheckPrintf(filename, clean_lines, linenum, error) + + if file_extension == 'h': + # TODO(unknown): check that 1-arg constructors are explicit. + # How to tell it's a constructor? + # (handled in CheckForNonStandardConstructs for now) + # TODO(unknown): check that classes declare or disable copy/assign + # (level 1 error) + pass + + # Check if people are using the verboten C basic types. The only exception + # we regularly allow is "unsigned short port" for port. + if Search(r'\bshort port\b', line): + if not Search(r'\bunsigned short port\b', line): + error(filename, linenum, 'runtime/int', 4, + 'Use "unsigned short" for ports, not "short"') + else: + match = Search(r'\b(short|long(?! +double)|long long)\b', line) + if match: + error(filename, linenum, 'runtime/int', 4, + 'Use int16/int64/etc, rather than the C type %s' % match.group(1)) + + # Check if some verboten operator overloading is going on + # TODO(unknown): catch out-of-line unary operator&: + # class X {}; + # int operator&(const X& x) { return 42; } // unary operator& + # The trick is it's hard to tell apart from binary operator&: + # class Y { int operator&(const Y& x) { return 23; } }; // binary operator& + if Search(r'\boperator\s*&\s*\(\s*\)', line): + error(filename, linenum, 'runtime/operator', 4, + 'Unary operator& is dangerous. Do not use it.') + + # Check for suspicious usage of "if" like + # } if (a == b) { + if Search(r'\}\s*if\s*\(', line): + error(filename, linenum, 'readability/braces', 4, + 'Did you mean "else if"? If not, start a new line for "if".') + + # Check for potential format string bugs like printf(foo). + # We constrain the pattern not to pick things like DocidForPrintf(foo). + # Not perfect but it can catch printf(foo.c_str()) and printf(foo->c_str()) + # TODO(unknown): Catch the following case. Need to change the calling + # convention of the whole function to process multiple line to handle it. + # printf( + # boy_this_is_a_really_long_variable_that_cannot_fit_on_the_prev_line); + printf_args = _GetTextInside(line, r'(?i)\b(string)?printf\s*\(') + if printf_args: + match = Match(r'([\w.\->()]+)$', printf_args) + if match and match.group(1) != '__VA_ARGS__': + function_name = re.search(r'\b((?:string)?printf)\s*\(', + line, re.I).group(1) + error(filename, linenum, 'runtime/printf', 4, + 'Potential format string bug. Do %s("%%s", %s) instead.' + % (function_name, match.group(1))) + + # Check for potential memset bugs like memset(buf, sizeof(buf), 0). + match = Search(r'memset\s*\(([^,]*),\s*([^,]*),\s*0\s*\)', line) + if match and not Match(r"^''|-?[0-9]+|0x[0-9A-Fa-f]$", match.group(2)): + error(filename, linenum, 'runtime/memset', 4, + 'Did you mean "memset(%s, 0, %s)"?' 
+ % (match.group(1), match.group(2))) + + if Search(r'\busing namespace\b', line): + error(filename, linenum, 'build/namespaces', 5, + 'Do not use namespace using-directives. ' + 'Use using-declarations instead.') + + # Detect variable-length arrays. + match = Match(r'\s*(.+::)?(\w+) [a-z]\w*\[(.+)];', line) + if (match and match.group(2) != 'return' and match.group(2) != 'delete' and + match.group(3).find(']') == -1): + # Split the size using space and arithmetic operators as delimiters. + # If any of the resulting tokens are not compile time constants then + # report the error. + tokens = re.split(r'\s|\+|\-|\*|\/|<<|>>]', match.group(3)) + is_const = True + skip_next = False + for tok in tokens: + if skip_next: + skip_next = False + continue + + if Search(r'sizeof\(.+\)', tok): continue + if Search(r'arraysize\(\w+\)', tok): continue + + tok = tok.lstrip('(') + tok = tok.rstrip(')') + if not tok: continue + if Match(r'\d+', tok): continue + if Match(r'0[xX][0-9a-fA-F]+', tok): continue + if Match(r'k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?k[A-Z0-9]\w*', tok): continue + if Match(r'(.+::)?[A-Z][A-Z0-9_]*', tok): continue + # A catch all for tricky sizeof cases, including 'sizeof expression', + # 'sizeof(*type)', 'sizeof(const type)', 'sizeof(struct StructName)' + # requires skipping the next token because we split on ' ' and '*'. + if tok.startswith('sizeof'): + skip_next = True + continue + is_const = False + break + if not is_const: + error(filename, linenum, 'runtime/arrays', 1, + 'Do not use variable-length arrays. Use an appropriately named ' + "('k' followed by CamelCase) compile-time constant for the size.") + + # Check for use of unnamed namespaces in header files. Registration + # macros are typically OK, so we allow use of "namespace {" on lines + # that end with backslashes. + if (file_extension == 'h' + and Search(r'\bnamespace\s*{', line) + and line[-1] != '\\'): + error(filename, linenum, 'build/namespaces', 4, + 'Do not use unnamed namespaces in header files. See ' + 'http://google-styleguide.googlecode.com/svn/trunk/cppguide.xml#Namespaces' + ' for more information.') + + +def CheckGlobalStatic(filename, clean_lines, linenum, error): + """Check for unsafe global or static objects. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Match two lines at a time to support multiline declarations + if linenum + 1 < clean_lines.NumLines() and not Search(r'[;({]', line): + line += clean_lines.elided[linenum + 1].strip() + + # Check for people declaring static/global STL strings at the top level. + # This is dangerous because the C++ language does not guarantee that + # globals with constructors are initialized before the first access. + match = Match( + r'((?:|static +)(?:|const +))string +([a-zA-Z0-9_:]+)\b(.*)', + line) + + # Remove false positives: + # - String pointers (as opposed to values). + # string *pointer + # const string *pointer + # string const *pointer + # string *const pointer + # + # - Functions and template specializations. + # string Function(... + # string Class::Method(... + # + # - Operators. These are matched separately because operator names + # cross non-word boundaries, and trying to match both operators + # and functions at the same time would decrease accuracy of + # matching identifiers. 
+ # string Class::operator*() + if (match and + not Search(r'\bstring\b(\s+const)?\s*\*\s*(const\s+)?\w', line) and + not Search(r'\boperator\W', line) and + not Match(r'\s*(<.*>)?(::[a-zA-Z0-9_]+)*\s*\(([^"]|$)', match.group(3))): + error(filename, linenum, 'runtime/string', 4, + 'For a static/global string constant, use a C style string instead: ' + '"%schar %s[]".' % + (match.group(1), match.group(2))) + + if Search(r'\b([A-Za-z0-9_]*_)\(\1\)', line): + error(filename, linenum, 'runtime/init', 4, + 'You seem to be initializing a member variable with itself.') + + +def CheckPrintf(filename, clean_lines, linenum, error): + """Check for printf related issues. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # When snprintf is used, the second argument shouldn't be a literal. + match = Search(r'snprintf\s*\(([^,]*),\s*([0-9]*)\s*,', line) + if match and match.group(2) != '0': + # If 2nd arg is zero, snprintf is used to calculate size. + error(filename, linenum, 'runtime/printf', 3, + 'If you can, use sizeof(%s) instead of %s as the 2nd arg ' + 'to snprintf.' % (match.group(1), match.group(2))) + + # Check if some verboten C functions are being used. + if Search(r'\bsprintf\s*\(', line): + error(filename, linenum, 'runtime/printf', 5, + 'Never use sprintf. Use snprintf instead.') + match = Search(r'\b(strcpy|strcat)\s*\(', line) + if match: + error(filename, linenum, 'runtime/printf', 4, + 'Almost always, snprintf is better than %s' % match.group(1)) + + +def IsDerivedFunction(clean_lines, linenum): + """Check if current line contains an inherited function. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + Returns: + True if current line contains a function with "override" + virt-specifier. + """ + # Scan back a few lines for start of current function + for i in xrange(linenum, max(-1, linenum - 10), -1): + match = Match(r'^([^()]*\w+)\(', clean_lines.elided[i]) + if match: + # Look for "override" after the matching closing parenthesis + line, _, closing_paren = CloseExpression( + clean_lines, i, len(match.group(1))) + return (closing_paren >= 0 and + Search(r'\boverride\b', line[closing_paren:])) + return False + + +def IsOutOfLineMethodDefinition(clean_lines, linenum): + """Check if current line contains an out-of-line method definition. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + Returns: + True if current line contains an out-of-line method definition. + """ + # Scan back a few lines for start of current function + for i in xrange(linenum, max(-1, linenum - 10), -1): + if Match(r'^([^()]*\w+)\(', clean_lines.elided[i]): + return Match(r'^[^()]*\w+::\w+\(', clean_lines.elided[i]) is not None + return False + + +def IsInitializerList(clean_lines, linenum): + """Check if current line is inside constructor initializer list. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + Returns: + True if current line appears to be inside constructor initializer + list, False otherwise. 
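+
+ Illustrative example (made up): for a constructor written as
+   MyClass::MyClass(int x)
+       : member_(x), other_(0) {
+ the second line is treated as the start of an initializer list because of
+ the lone ':' followed by 'member_('.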
+ """ + for i in xrange(linenum, 1, -1): + line = clean_lines.elided[i] + if i == linenum: + remove_function_body = Match(r'^(.*)\{\s*$', line) + if remove_function_body: + line = remove_function_body.group(1) + + if Search(r'\s:\s*\w+[({]', line): + # A lone colon tend to indicate the start of a constructor + # initializer list. It could also be a ternary operator, which + # also tend to appear in constructor initializer lists as + # opposed to parameter lists. + return True + if Search(r'\}\s*,\s*$', line): + # A closing brace followed by a comma is probably the end of a + # brace-initialized member in constructor initializer list. + return True + if Search(r'[{};]\s*$', line): + # Found one of the following: + # - A closing brace or semicolon, probably the end of the previous + # function. + # - An opening brace, probably the start of current class or namespace. + # + # Current line is probably not inside an initializer list since + # we saw one of those things without seeing the starting colon. + return False + + # Got to the beginning of the file without seeing the start of + # constructor initializer list. + return False + + +def CheckForNonConstReference(filename, clean_lines, linenum, + nesting_state, error): + """Check for non-const references. + + Separate from CheckLanguage since it scans backwards from current + line, instead of scanning forward. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: The function to call with any errors found. + """ + # Do nothing if there is no '&' on current line. + line = clean_lines.elided[linenum] + if '&' not in line: + return + + # If a function is inherited, current function doesn't have much of + # a choice, so any non-const references should not be blamed on + # derived function. + if IsDerivedFunction(clean_lines, linenum): + return + + # Don't warn on out-of-line method definitions, as we would warn on the + # in-line declaration, if it isn't marked with 'override'. + if IsOutOfLineMethodDefinition(clean_lines, linenum): + return + + # Long type names may be broken across multiple lines, usually in one + # of these forms: + # LongType + # ::LongTypeContinued &identifier + # LongType:: + # LongTypeContinued &identifier + # LongType< + # ...>::LongTypeContinued &identifier + # + # If we detected a type split across two lines, join the previous + # line to current line so that we can match const references + # accordingly. + # + # Note that this only scans back one line, since scanning back + # arbitrary number of lines would be expensive. If you have a type + # that spans more than 2 lines, please use a typedef. 
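+ #
+ # Illustrative examples (made-up declarations) of what this check reports:
+ #   void Process(string& out);        // flagged: non-const reference
+ #   void Process(const string& in);   // OK: const reference
+ #   void swap(Foo& a, Foo& b);        // OK: swap() is whitelisted below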
+ if linenum > 1: + previous = None + if Match(r'\s*::(?:[\w<>]|::)+\s*&\s*\S', line): + # previous_line\n + ::current_line + previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+[\w<>])\s*$', + clean_lines.elided[linenum - 1]) + elif Match(r'\s*[a-zA-Z_]([\w<>]|::)+\s*&\s*\S', line): + # previous_line::\n + current_line + previous = Search(r'\b((?:const\s*)?(?:[\w<>]|::)+::)\s*$', + clean_lines.elided[linenum - 1]) + if previous: + line = previous.group(1) + line.lstrip() + else: + # Check for templated parameter that is split across multiple lines + endpos = line.rfind('>') + if endpos > -1: + (_, startline, startpos) = ReverseCloseExpression( + clean_lines, linenum, endpos) + if startpos > -1 and startline < linenum: + # Found the matching < on an earlier line, collect all + # pieces up to current line. + line = '' + for i in xrange(startline, linenum + 1): + line += clean_lines.elided[i].strip() + + # Check for non-const references in function parameters. A single '&' may + # found in the following places: + # inside expression: binary & for bitwise AND + # inside expression: unary & for taking the address of something + # inside declarators: reference parameter + # We will exclude the first two cases by checking that we are not inside a + # function body, including one that was just introduced by a trailing '{'. + # TODO(unknown): Doesn't account for 'catch(Exception& e)' [rare]. + if (nesting_state.previous_stack_top and + not (isinstance(nesting_state.previous_stack_top, _ClassInfo) or + isinstance(nesting_state.previous_stack_top, _NamespaceInfo))): + # Not at toplevel, not within a class, and not within a namespace + return + + # Avoid initializer lists. We only need to scan back from the + # current line for something that starts with ':'. + # + # We don't need to check the current line, since the '&' would + # appear inside the second set of parentheses on the current line as + # opposed to the first set. + if linenum > 0: + for i in xrange(linenum - 1, max(0, linenum - 10), -1): + previous_line = clean_lines.elided[i] + if not Search(r'[),]\s*$', previous_line): + break + if Match(r'^\s*:\s+\S', previous_line): + return + + # Avoid preprocessors + if Search(r'\\\s*$', line): + return + + # Avoid constructor initializer lists + if IsInitializerList(clean_lines, linenum): + return + + # We allow non-const references in a few standard places, like functions + # called "swap()" or iostream operators like "<<" or ">>". Do not check + # those function parameters. + # + # We also accept & in static_assert, which looks like a function but + # it's actually a declaration expression. + whitelisted_functions = (r'(?:[sS]wap(?:<\w:+>)?|' + r'operator\s*[<>][<>]|' + r'static_assert|COMPILE_ASSERT' + r')\s*\(') + if Search(whitelisted_functions, line): + return + elif not Search(r'\S+\([^)]*$', line): + # Don't see a whitelisted function on this line. Actually we + # didn't see any function name on this line, so this is likely a + # multi-line parameter list. Try a bit harder to catch this case. + for i in xrange(2): + if (linenum > i and + Search(whitelisted_functions, clean_lines.elided[linenum - i - 1])): + return + + decls = ReplaceAll(r'{[^}]*}', ' ', line) # exclude function body + for parameter in re.findall(_RE_PATTERN_REF_PARAM, decls): + if not Match(_RE_PATTERN_CONST_REF_PARAM, parameter): + error(filename, linenum, 'runtime/references', 2, + 'Is this a non-const reference? 
' + 'If so, make const or use a pointer: ' + + ReplaceAll(' *<', '<', parameter)) + + +def CheckCasts(filename, clean_lines, linenum, error): + """Various cast related checks. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Check to see if they're using an conversion function cast. + # I just try to capture the most common basic types, though there are more. + # Parameterless conversion functions, such as bool(), are allowed as they are + # probably a member operator declaration or default constructor. + match = Search( + r'(\bnew\s+|\S<\s*(?:const\s+)?)?\b' + r'(int|float|double|bool|char|int32|uint32|int64|uint64)' + r'(\([^)].*)', line) + expecting_function = ExpectingFunctionArgs(clean_lines, linenum) + if match and not expecting_function: + matched_type = match.group(2) + + # matched_new_or_template is used to silence two false positives: + # - New operators + # - Template arguments with function types + # + # For template arguments, we match on types immediately following + # an opening bracket without any spaces. This is a fast way to + # silence the common case where the function type is the first + # template argument. False negative with less-than comparison is + # avoided because those operators are usually followed by a space. + # + # function // bracket + no space = false positive + # value < double(42) // bracket + space = true positive + matched_new_or_template = match.group(1) + + # Avoid arrays by looking for brackets that come after the closing + # parenthesis. + if Match(r'\([^()]+\)\s*\[', match.group(3)): + return + + # Other things to ignore: + # - Function pointers + # - Casts to pointer types + # - Placement new + # - Alias declarations + matched_funcptr = match.group(3) + if (matched_new_or_template is None and + not (matched_funcptr and + (Match(r'\((?:[^() ]+::\s*\*\s*)?[^() ]+\)\s*\(', + matched_funcptr) or + matched_funcptr.startswith('(*)'))) and + not Match(r'\s*using\s+\S+\s*=\s*' + matched_type, line) and + not Search(r'new\(\S+\)\s*' + matched_type, line)): + error(filename, linenum, 'readability/casting', 4, + 'Using deprecated casting style. ' + 'Use static_cast<%s>(...) instead' % + matched_type) + + if not expecting_function: + CheckCStyleCast(filename, clean_lines, linenum, 'static_cast', + r'\((int|float|double|bool|char|u?int(16|32|64))\)', error) + + # This doesn't catch all cases. Consider (const char * const)"hello". + # + # (char *) "foo" should always be a const_cast (reinterpret_cast won't + # compile). + if CheckCStyleCast(filename, clean_lines, linenum, 'const_cast', + r'\((char\s?\*+\s?)\)\s*"', error): + pass + else: + # Check pointer casts for other than string constants + CheckCStyleCast(filename, clean_lines, linenum, 'reinterpret_cast', + r'\((\w+\s?\*+\s?)\)', error) + + # In addition, we look for people taking the address of a cast. This + # is dangerous -- casts can assign to temporaries, so the pointer doesn't + # point where you think. + # + # Some non-identifier character is required before the '&' for the + # expression to be recognized as a cast. 
These are casts: + # expression = &static_cast(temporary()); + # function(&(int*)(temporary())); + # + # This is not a cast: + # reference_type&(int* function_param); + match = Search( + r'(?:[^\w]&\(([^)*][^)]*)\)[\w(])|' + r'(?:[^\w]&(static|dynamic|down|reinterpret)_cast\b)', line) + if match: + # Try a better error message when the & is bound to something + # dereferenced by the casted pointer, as opposed to the casted + # pointer itself. + parenthesis_error = False + match = Match(r'^(.*&(?:static|dynamic|down|reinterpret)_cast\b)<', line) + if match: + _, y1, x1 = CloseExpression(clean_lines, linenum, len(match.group(1))) + if x1 >= 0 and clean_lines.elided[y1][x1] == '(': + _, y2, x2 = CloseExpression(clean_lines, y1, x1) + if x2 >= 0: + extended_line = clean_lines.elided[y2][x2:] + if y2 < clean_lines.NumLines() - 1: + extended_line += clean_lines.elided[y2 + 1] + if Match(r'\s*(?:->|\[)', extended_line): + parenthesis_error = True + + if parenthesis_error: + error(filename, linenum, 'readability/casting', 4, + ('Are you taking an address of something dereferenced ' + 'from a cast? Wrapping the dereferenced expression in ' + 'parentheses will make the binding more obvious')) + else: + error(filename, linenum, 'runtime/casting', 4, + ('Are you taking an address of a cast? ' + 'This is dangerous: could be a temp var. ' + 'Take the address before doing the cast, rather than after')) + + +def CheckCStyleCast(filename, clean_lines, linenum, cast_type, pattern, error): + """Checks for a C-style cast by looking for the pattern. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + cast_type: The string for the C++ cast to recommend. This is either + reinterpret_cast, static_cast, or const_cast, depending. + pattern: The regular expression used to find C-style casts. + error: The function to call with any errors found. + + Returns: + True if an error was emitted. + False otherwise. + """ + line = clean_lines.elided[linenum] + match = Search(pattern, line) + if not match: + return False + + # Exclude lines with keywords that tend to look like casts + context = line[0:match.start(1) - 1] + if Match(r'.*\b(?:sizeof|alignof|alignas|[_A-Z][_A-Z0-9]*)\s*$', context): + return False + + # Try expanding current context to see if we one level of + # parentheses inside a macro. + if linenum > 0: + for i in xrange(linenum - 1, max(0, linenum - 5), -1): + context = clean_lines.elided[i] + context + if Match(r'.*\b[_A-Z][_A-Z0-9]*\s*\((?:\([^()]*\)|[^()])*$', context): + return False + + # operator++(int) and operator--(int) + if context.endswith(' operator++') or context.endswith(' operator--'): + return False + + # A single unnamed argument for a function tends to look like old + # style cast. If we see those, don't issue warnings for deprecated + # casts, instead issue warnings for unnamed arguments where + # appropriate. + # + # These are things that we want warnings for, since the style guide + # explicitly require all parameters to be named: + # Function(int); + # Function(int) { + # ConstMember(int) const; + # ConstMember(int) const { + # ExceptionMember(int) throw (...); + # ExceptionMember(int) throw (...) 
{ + # PureVirtual(int) = 0; + # [](int) -> bool { + # + # These are functions of some sort, where the compiler would be fine + # if they had named parameters, but people often omit those + # identifiers to reduce clutter: + # (FunctionPointer)(int); + # (FunctionPointer)(int) = value; + # Function((function_pointer_arg)(int)) + # Function((function_pointer_arg)(int), int param) + # ; + # <(FunctionPointerTemplateArgument)(int)>; + remainder = line[match.end(0):] + if Match(r'^\s*(?:;|const\b|throw\b|final\b|override\b|[=>{),]|->)', + remainder): + # Looks like an unnamed parameter. + + # Don't warn on any kind of template arguments. + if Match(r'^\s*>', remainder): + return False + + # Don't warn on assignments to function pointers, but keep warnings for + # unnamed parameters to pure virtual functions. Note that this pattern + # will also pass on assignments of "0" to function pointers, but the + # preferred values for those would be "nullptr" or "NULL". + matched_zero = Match(r'^\s=\s*(\S+)\s*;', remainder) + if matched_zero and matched_zero.group(1) != '0': + return False + + # Don't warn on function pointer declarations. For this we need + # to check what came before the "(type)" string. + if Match(r'.*\)\s*$', line[0:match.start(0)]): + return False + + # Don't warn if the parameter is named with block comments, e.g.: + # Function(int /*unused_param*/); + raw_line = clean_lines.raw_lines[linenum] + if '/*' in raw_line: + return False + + # Passed all filters, issue warning here. + error(filename, linenum, 'readability/function', 3, + 'All parameters should be named in a function') + return True + + # At this point, all that should be left is actual casts. + error(filename, linenum, 'readability/casting', 4, + 'Using C-style cast. Use %s<%s>(...) instead' % + (cast_type, match.group(1))) + + return True + + +def ExpectingFunctionArgs(clean_lines, linenum): + """Checks whether where function type arguments are expected. + + Args: + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + + Returns: + True if the line at 'linenum' is inside something that expects arguments + of function types. + """ + line = clean_lines.elided[linenum] + return (Match(r'^\s*MOCK_(CONST_)?METHOD\d+(_T)?\(', line) or + (linenum >= 2 and + (Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\((?:\S+,)?\s*$', + clean_lines.elided[linenum - 1]) or + Match(r'^\s*MOCK_(?:CONST_)?METHOD\d+(?:_T)?\(\s*$', + clean_lines.elided[linenum - 2]) or + Search(r'\bstd::m?function\s*\<\s*$', + clean_lines.elided[linenum - 1])))) + + +_HEADERS_CONTAINING_TEMPLATES = ( + ('', ('deque',)), + ('', ('unary_function', 'binary_function', + 'plus', 'minus', 'multiplies', 'divides', 'modulus', + 'negate', + 'equal_to', 'not_equal_to', 'greater', 'less', + 'greater_equal', 'less_equal', + 'logical_and', 'logical_or', 'logical_not', + 'unary_negate', 'not1', 'binary_negate', 'not2', + 'bind1st', 'bind2nd', + 'pointer_to_unary_function', + 'pointer_to_binary_function', + 'ptr_fun', + 'mem_fun_t', 'mem_fun', 'mem_fun1_t', 'mem_fun1_ref_t', + 'mem_fun_ref_t', + 'const_mem_fun_t', 'const_mem_fun1_t', + 'const_mem_fun_ref_t', 'const_mem_fun1_ref_t', + 'mem_fun_ref', + )), + ('', ('numeric_limits',)), + ('', ('list',)), + ('', ('map', 'multimap',)), + ('', ('allocator',)), + ('', ('queue', 'priority_queue',)), + ('', ('set', 'multiset',)), + ('', ('stack',)), + ('', ('char_traits', 'basic_string',)), + ('', ('tuple',)), + ('', ('pair',)), + ('', ('vector',)), + + # gcc extensions. 
+ # Note: std::hash is their hash, ::hash is our hash + ('', ('hash_map', 'hash_multimap',)), + ('', ('hash_set', 'hash_multiset',)), + ('', ('slist',)), + ) + +_RE_PATTERN_STRING = re.compile(r'\bstring\b') + +_re_pattern_algorithm_header = [] +for _template in ('copy', 'max', 'min', 'min_element', 'sort', 'swap', + 'transform'): + # Match max(..., ...), max(..., ...), but not foo->max, foo.max or + # type::max(). + _re_pattern_algorithm_header.append( + (re.compile(r'[^>.]\b' + _template + r'(<.*?>)?\([^\)]'), + _template, + '')) + +_re_pattern_templates = [] +for _header, _templates in _HEADERS_CONTAINING_TEMPLATES: + for _template in _templates: + _re_pattern_templates.append( + (re.compile(r'(\<|\b)' + _template + r'\s*\<'), + _template + '<>', + _header)) + + +def FilesBelongToSameModule(filename_cc, filename_h): + """Check if these two filenames belong to the same module. + + The concept of a 'module' here is a as follows: + foo.h, foo-inl.h, foo.cc, foo_test.cc and foo_unittest.cc belong to the + same 'module' if they are in the same directory. + some/path/public/xyzzy and some/path/internal/xyzzy are also considered + to belong to the same module here. + + If the filename_cc contains a longer path than the filename_h, for example, + '/absolute/path/to/base/sysinfo.cc', and this file would include + 'base/sysinfo.h', this function also produces the prefix needed to open the + header. This is used by the caller of this function to more robustly open the + header file. We don't have access to the real include paths in this context, + so we need this guesswork here. + + Known bugs: tools/base/bar.cc and base/bar.h belong to the same module + according to this implementation. Because of this, this function gives + some false positives. This should be sufficiently rare in practice. + + Args: + filename_cc: is the path for the .cc file + filename_h: is the path for the header path + + Returns: + Tuple with a bool and a string: + bool: True if filename_cc and filename_h belong to the same module. + string: the additional prefix needed to open the header file. + """ + + if not filename_cc.endswith('.cc'): + return (False, '') + filename_cc = filename_cc[:-len('.cc')] + if filename_cc.endswith('_unittest'): + filename_cc = filename_cc[:-len('_unittest')] + elif filename_cc.endswith('_test'): + filename_cc = filename_cc[:-len('_test')] + filename_cc = filename_cc.replace('/public/', '/') + filename_cc = filename_cc.replace('/internal/', '/') + + if not filename_h.endswith('.h'): + return (False, '') + filename_h = filename_h[:-len('.h')] + if filename_h.endswith('-inl'): + filename_h = filename_h[:-len('-inl')] + filename_h = filename_h.replace('/public/', '/') + filename_h = filename_h.replace('/internal/', '/') + + files_belong_to_same_module = filename_cc.endswith(filename_h) + common_path = '' + if files_belong_to_same_module: + common_path = filename_cc[:-len(filename_h)] + return files_belong_to_same_module, common_path + + +def UpdateIncludeState(filename, include_dict, io=codecs): + """Fill up the include_dict with new includes found from the file. + + Args: + filename: the name of the header to read. + include_dict: a dictionary in which the headers are inserted. + io: The io factory to use to read the file. Provided for testability. + + Returns: + True if a header was successfully added. False otherwise. 
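+
+ Illustrative sketch (made-up call): after UpdateIncludeState('foo/foo.h',
+ include_dict) succeeds, include_dict maps each header #included by
+ foo/foo.h to the line number of its first occurrence in that file.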
+ """ + headerfile = None + try: + headerfile = io.open(filename, 'r', 'utf8', 'replace') + except IOError: + return False + linenum = 0 + for line in headerfile: + linenum += 1 + clean_line = CleanseComments(line) + match = _RE_PATTERN_INCLUDE.search(clean_line) + if match: + include = match.group(2) + include_dict.setdefault(include, linenum) + return True + + +def CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error, + io=codecs): + """Reports for missing stl includes. + + This function will output warnings to make sure you are including the headers + necessary for the stl containers and functions that you use. We only give one + reason to include a header. For example, if you use both equal_to<> and + less<> in a .h file, only one (the latter in the file) of these will be + reported as a reason to include the . + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + include_state: An _IncludeState instance. + error: The function to call with any errors found. + io: The IO factory to use to read the header file. Provided for unittest + injection. + """ + required = {} # A map of header name to linenumber and the template entity. + # Example of required: { '': (1219, 'less<>') } + + for linenum in xrange(clean_lines.NumLines()): + line = clean_lines.elided[linenum] + if not line or line[0] == '#': + continue + + # String is special -- it is a non-templatized type in STL. + matched = _RE_PATTERN_STRING.search(line) + if matched: + # Don't warn about strings in non-STL namespaces: + # (We check only the first match per line; good enough.) + prefix = line[:matched.start()] + if prefix.endswith('std::') or not prefix.endswith('::'): + required[''] = (linenum, 'string') + + for pattern, template, header in _re_pattern_algorithm_header: + if pattern.search(line): + required[header] = (linenum, template) + + # The following function is just a speed up, no semantics are changed. + if not '<' in line: # Reduces the cpu time usage by skipping lines. + continue + + for pattern, template, header in _re_pattern_templates: + if pattern.search(line): + required[header] = (linenum, template) + + # The policy is that if you #include something in foo.h you don't need to + # include it again in foo.cc. Here, we will look at possible includes. + # Let's flatten the include_state include_list and copy it into a dictionary. + include_dict = dict([item for sublist in include_state.include_list + for item in sublist]) + + # Did we find the header for this file (if any) and successfully load it? + header_found = False + + # Use the absolute path so that matching works properly. + abs_filename = FileInfo(filename).FullName() + + # For Emacs's flymake. + # If cpplint is invoked from Emacs's flymake, a temporary file is generated + # by flymake and that file name might end with '_flymake.cc'. In that case, + # restore original file name here so that the corresponding header file can be + # found. + # e.g. If the file name is 'foo_flymake.cc', we should search for 'foo.h' + # instead of 'foo_flymake.h' + abs_filename = re.sub(r'_flymake\.cc$', '.cc', abs_filename) + + # include_dict is modified during iteration, so we iterate over a copy of + # the keys. 
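+ #
+ # Illustrative example (made up): a foo.cc that calls std::sort() but never
+ # includes the header declaring it, either directly or via its foo.h, ends
+ # up reported below as a 'build/include_what_you_use' error.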
+ header_keys = include_dict.keys() + for header in header_keys: + (same_module, common_path) = FilesBelongToSameModule(abs_filename, header) + fullpath = common_path + header + if same_module and UpdateIncludeState(fullpath, include_dict, io): + header_found = True + + # If we can't find the header file for a .cc, assume it's because we don't + # know where to look. In that case we'll give up as we're not sure they + # didn't include it in the .h file. + # TODO(unknown): Do a better job of finding .h files so we are confident that + # not having the .h file means there isn't one. + if filename.endswith('.cc') and not header_found: + return + + # All the lines have been processed, report the errors found. + for required_header_unstripped in required: + template = required[required_header_unstripped][1] + if required_header_unstripped.strip('<>"') not in include_dict: + error(filename, required[required_header_unstripped][0], + 'build/include_what_you_use', 4, + 'Add #include ' + required_header_unstripped + ' for ' + template) + + +_RE_PATTERN_EXPLICIT_MAKEPAIR = re.compile(r'\bmake_pair\s*<') + + +def CheckMakePairUsesDeduction(filename, clean_lines, linenum, error): + """Check that make_pair's template arguments are deduced. + + G++ 4.6 in C++11 mode fails badly if make_pair's template arguments are + specified explicitly, and such use isn't intended in any case. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + match = _RE_PATTERN_EXPLICIT_MAKEPAIR.search(line) + if match: + error(filename, linenum, 'build/explicit_make_pair', + 4, # 4 = high confidence + 'For C++11-compatibility, omit template arguments from make_pair' + ' OR use pair directly OR if appropriate, construct a pair directly') + + +def CheckDefaultLambdaCaptures(filename, clean_lines, linenum, error): + """Check that default lambda captures are not used. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # A lambda introducer specifies a default capture if it starts with "[=" + # or if it starts with "[&" _not_ followed by an identifier. + match = Match(r'^(.*)\[\s*(?:=|&[^\w])', line) + if match: + # Found a potential error, check what comes after the lambda-introducer. + # If it's not open parenthesis (for lambda-declarator) or open brace + # (for compound-statement), it's not a lambda. + line, _, pos = CloseExpression(clean_lines, linenum, len(match.group(1))) + if pos >= 0 and Match(r'^\s*[{(]', line[pos:]): + error(filename, linenum, 'build/c++11', + 4, # 4 = high confidence + 'Default lambda captures are an unapproved C++ feature.') + + +def CheckRedundantVirtual(filename, clean_lines, linenum, error): + """Check if line contains a redundant "virtual" function-specifier. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Look for "virtual" on current line. + line = clean_lines.elided[linenum] + virtual = Match(r'^(.*)(\bvirtual\b)(.*)$', line) + if not virtual: return + + # Ignore "virtual" keywords that are near access-specifiers. 
These + # are only used in class base-specifier and do not apply to member + # functions. + if (Search(r'\b(public|protected|private)\s+$', virtual.group(1)) or + Match(r'^\s+(public|protected|private)\b', virtual.group(3))): + return + + # Ignore the "virtual" keyword from virtual base classes. Usually + # there is a column on the same line in these cases (virtual base + # classes are rare in google3 because multiple inheritance is rare). + if Match(r'^.*[^:]:[^:].*$', line): return + + # Look for the next opening parenthesis. This is the start of the + # parameter list (possibly on the next line shortly after virtual). + # TODO(unknown): doesn't work if there are virtual functions with + # decltype() or other things that use parentheses, but csearch suggests + # that this is rare. + end_col = -1 + end_line = -1 + start_col = len(virtual.group(2)) + for start_line in xrange(linenum, min(linenum + 3, clean_lines.NumLines())): + line = clean_lines.elided[start_line][start_col:] + parameter_list = Match(r'^([^(]*)\(', line) + if parameter_list: + # Match parentheses to find the end of the parameter list + (_, end_line, end_col) = CloseExpression( + clean_lines, start_line, start_col + len(parameter_list.group(1))) + break + start_col = 0 + + if end_col < 0: + return # Couldn't find end of parameter list, give up + + # Look for "override" or "final" after the parameter list + # (possibly on the next few lines). + for i in xrange(end_line, min(end_line + 3, clean_lines.NumLines())): + line = clean_lines.elided[i][end_col:] + match = Search(r'\b(override|final)\b', line) + if match: + error(filename, linenum, 'readability/inheritance', 4, + ('"virtual" is redundant since function is ' + 'already declared as "%s"' % match.group(1))) + + # Set end_col to check whole lines after we are done with the + # first line. + end_col = 0 + if Search(r'[^\w]\s*$', line): + break + + +def CheckRedundantOverrideOrFinal(filename, clean_lines, linenum, error): + """Check if line contains a redundant "override" or "final" virt-specifier. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + # Look for closing parenthesis nearby. We need one to confirm where + # the declarator ends and where the virt-specifier starts to avoid + # false positives. + line = clean_lines.elided[linenum] + declarator_end = line.rfind(')') + if declarator_end >= 0: + fragment = line[declarator_end:] + else: + if linenum > 1 and clean_lines.elided[linenum - 1].rfind(')') >= 0: + fragment = line + else: + return + + # Check that at most one of "override" or "final" is present, not both + if Search(r'\boverride\b', fragment) and Search(r'\bfinal\b', fragment): + error(filename, linenum, 'readability/inheritance', 4, + ('"override" is redundant since function is ' + 'already declared as "final"')) + + + + +# Returns true if we are at a new block, and it is directly +# inside of a namespace. +def IsBlockInNameSpace(nesting_state, is_forward_declaration): + """Checks that the new block is directly in a namespace. + + Args: + nesting_state: The _NestingState object that contains info about our state. + is_forward_declaration: If the class is a forward declared class. + Returns: + Whether or not the new block is directly in a namespace. 
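+
+ Illustrative example (made up): immediately after a line "namespace caffe2 {",
+ a forward declaration such as "class Blob;" is directly inside the namespace,
+ so this returns True for it.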
+ """ + if is_forward_declaration: + if len(nesting_state.stack) >= 1 and ( + isinstance(nesting_state.stack[-1], _NamespaceInfo)): + return True + else: + return False + + return (len(nesting_state.stack) > 1 and + nesting_state.stack[-1].check_namespace_indentation and + isinstance(nesting_state.stack[-2], _NamespaceInfo)) + + +def ShouldCheckNamespaceIndentation(nesting_state, is_namespace_indent_item, + raw_lines_no_comments, linenum): + """This method determines if we should apply our namespace indentation check. + + Args: + nesting_state: The current nesting state. + is_namespace_indent_item: If we just put a new class on the stack, True. + If the top of the stack is not a class, or we did not recently + add the class, False. + raw_lines_no_comments: The lines without the comments. + linenum: The current line number we are processing. + + Returns: + True if we should apply our namespace indentation check. Currently, it + only works for classes and namespaces inside of a namespace. + """ + + is_forward_declaration = IsForwardClassDeclaration(raw_lines_no_comments, + linenum) + + if not (is_namespace_indent_item or is_forward_declaration): + return False + + # If we are in a macro, we do not want to check the namespace indentation. + if IsMacroDefinition(raw_lines_no_comments, linenum): + return False + + return IsBlockInNameSpace(nesting_state, is_forward_declaration) + + +# Call this method if the line is directly inside of a namespace. +# If the line above is blank (excluding comments) or the start of +# an inner namespace, it cannot be indented. +def CheckItemIndentationInNamespace(filename, raw_lines_no_comments, linenum, + error): + line = raw_lines_no_comments[linenum] + if Match(r'^\s+', line): + error(filename, linenum, 'runtime/indentation_namespace', 4, + 'Do not indent within a namespace') + + +def ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions=[]): + """Processes a single line in the file. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + clean_lines: An array of strings, each representing a line of the file, + with comments stripped. + line: Number of line being processed. + include_state: An _IncludeState instance in which the headers are inserted. + function_state: A _FunctionState instance which counts function lines, etc. + nesting_state: A NestingState instance which maintains information about + the current stack of nested blocks being parsed. + error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. 
Each function takes 4 + arguments: filename, clean_lines, line, error + """ + raw_lines = clean_lines.raw_lines + ParseNolintSuppressions(filename, raw_lines[line], line, error) + nesting_state.Update(filename, clean_lines, line, error) + CheckForNamespaceIndentation(filename, nesting_state, clean_lines, line, + error) + if nesting_state.InAsmBlock(): return + CheckForFunctionLengths(filename, clean_lines, line, function_state, error) + CheckForMultilineCommentsAndStrings(filename, clean_lines, line, error) + CheckStyle(filename, clean_lines, line, file_extension, nesting_state, error) + CheckLanguage(filename, clean_lines, line, file_extension, include_state, + nesting_state, error) + CheckForNonConstReference(filename, clean_lines, line, nesting_state, error) + CheckForNonStandardConstructs(filename, clean_lines, line, + nesting_state, error) + CheckVlogArguments(filename, clean_lines, line, error) + CheckPosixThreading(filename, clean_lines, line, error) + CheckInvalidIncrement(filename, clean_lines, line, error) + CheckMakePairUsesDeduction(filename, clean_lines, line, error) + CheckDefaultLambdaCaptures(filename, clean_lines, line, error) + CheckRedundantVirtual(filename, clean_lines, line, error) + CheckRedundantOverrideOrFinal(filename, clean_lines, line, error) + for check_fn in extra_check_functions: + check_fn(filename, clean_lines, line, error) + +def FlagCxx11Features(filename, clean_lines, linenum, error): + """Flag those c++11 features that we only allow in certain places. + + Args: + filename: The name of the current file. + clean_lines: A CleansedLines instance containing the file. + linenum: The number of the line to check. + error: The function to call with any errors found. + """ + line = clean_lines.elided[linenum] + + # Flag unapproved C++11 headers. + include = Match(r'\s*#\s*include\s+[<"]([^<"]+)[">]', line) + if include and include.group(1) in ('cfenv', + 'condition_variable', + 'fenv.h', + 'future', + 'mutex', + 'thread', + 'chrono', + 'ratio', + 'regex', + 'system_error', + ): + error(filename, linenum, 'build/c++11', 5, + ('<%s> is an unapproved C++11 header.') % include.group(1)) + + # The only place where we need to worry about C++11 keywords and library + # features in preprocessor directives is in macro definitions. + if Match(r'\s*#', line) and not Match(r'\s*#\s*define\b', line): return + + # These are classes and free functions. The classes are always + # mentioned as std::*, but we only catch the free functions if + # they're not found by ADL. They're alphabetical by header. + for top_name in ( + # type_traits + 'alignment_of', + 'aligned_union', + ): + if Search(r'\bstd::%s\b' % top_name, line): + error(filename, linenum, 'build/c++11', 5, + ('std::%s is an unapproved C++11 class or function. Send c-style ' + 'an example of where it would make your code more readable, and ' + 'they may let you use it.') % top_name) + + +def ProcessFileData(filename, file_extension, lines, error, + extra_check_functions=[]): + """Performs lint checks and reports any errors to the given error function. + + Args: + filename: Filename of the file that is being processed. + file_extension: The extension (dot not included) of the file. + lines: An array of strings, each representing a line of the file, with the + last element being empty if the file is terminated with a newline. 
+ error: A callable to which errors are reported, which takes 4 arguments: + filename, line number, error level, and message + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + lines = (['// marker so line numbers and indices both start at 1'] + lines + + ['// marker so line numbers end in a known way']) + + include_state = _IncludeState() + function_state = _FunctionState() + nesting_state = NestingState() + + ResetNolintSuppressions() + + RemoveMultiLineComments(filename, lines, error) + clean_lines = CleansedLines(lines) + + if file_extension == 'h': + CheckForHeaderGuard(filename, clean_lines, error) + + for line in xrange(clean_lines.NumLines()): + ProcessLine(filename, file_extension, clean_lines, line, + include_state, function_state, nesting_state, error, + extra_check_functions) + FlagCxx11Features(filename, clean_lines, line, error) + nesting_state.CheckCompletedBlocks(filename, error) + + # Yangqing: disabled since Caffe2 puts std containers in common.h. + # CheckForIncludeWhatYouUse(filename, clean_lines, include_state, error) + + # Check that the .cc file has included its header if it exists. + if file_extension == 'cc': + CheckHeaderFileIncluded(filename, include_state, error) + + # We check here rather than inside ProcessLine so that we see raw + # lines rather than "cleaned" lines. + CheckForBadCharacters(filename, lines, error) + + CheckForNewlineAtEOF(filename, lines, error) + +def ProcessConfigOverrides(filename): + """ Loads the configuration files and processes the config overrides. + + Args: + filename: The name of the file being processed by the linter. + + Returns: + False if the current |filename| should not be processed further. + """ + + abs_filename = os.path.abspath(filename) + cfg_filters = [] + keep_looking = True + while keep_looking: + abs_path, base_name = os.path.split(abs_filename) + if not base_name: + break # Reached the root directory. + + cfg_file = os.path.join(abs_path, "CPPLINT.cfg") + abs_filename = abs_path + if not os.path.isfile(cfg_file): + continue + + try: + with open(cfg_file) as file_handle: + for line in file_handle: + line, _, _ = line.partition('#') # Remove comments. + if not line.strip(): + continue + + name, _, val = line.partition('=') + name = name.strip() + val = val.strip() + if name == 'set noparent': + keep_looking = False + elif name == 'filter': + cfg_filters.append(val) + elif name == 'exclude_files': + # When matching exclude_files pattern, use the base_name of + # the current file name or the directory name we are processing. + # For example, if we are checking for lint errors in /foo/bar/baz.cc + # and we found the .cfg file at /foo/CPPLINT.cfg, then the config + # file's "exclude_files" filter is meant to be checked against "bar" + # and not "baz" nor "bar/baz.cc". + if base_name: + pattern = re.compile(val) + if pattern.match(base_name): + sys.stderr.write('Ignoring "%s": file excluded by "%s". 
' + 'File path component "%s" matches ' + 'pattern "%s"\n' % + (filename, cfg_file, base_name, val)) + return False + elif name == 'linelength': + global _line_length + try: + _line_length = int(val) + except ValueError: + sys.stderr.write('Line length must be numeric.') + else: + sys.stderr.write( + 'Invalid configuration option (%s) in file %s\n' % + (name, cfg_file)) + + except IOError: + sys.stderr.write( + "Skipping config file '%s': Can't open for reading\n" % cfg_file) + keep_looking = False + + # Apply all the accumulated filters in reverse order (top-level directory + # config options having the least priority). + for filter in reversed(cfg_filters): + _AddFilters(filter) + + return True + + +def ProcessFile(filename, vlevel, extra_check_functions=[]): + """Does google-lint on a single file. + + Args: + filename: The name of the file to parse. + + vlevel: The level of errors to report. Every error of confidence + >= verbose_level will be reported. 0 is a good default. + + extra_check_functions: An array of additional check functions that will be + run on each source line. Each function takes 4 + arguments: filename, clean_lines, line, error + """ + + _SetVerboseLevel(vlevel) + _BackupFilters() + + if not ProcessConfigOverrides(filename): + _RestoreFilters() + return + + lf_lines = [] + crlf_lines = [] + try: + # Support the UNIX convention of using "-" for stdin. Note that + # we are not opening the file with universal newline support + # (which codecs doesn't support anyway), so the resulting lines do + # contain trailing '\r' characters if we are reading a file that + # has CRLF endings. + # If after the split a trailing '\r' is present, it is removed + # below. + if filename == '-': + lines = codecs.StreamReaderWriter(sys.stdin, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace').read().split('\n') + else: + lines = codecs.open(filename, 'r', 'utf8', 'replace').read().split('\n') + + # Remove trailing '\r'. + # The -1 accounts for the extra trailing blank line we get from split() + for linenum in range(len(lines) - 1): + if lines[linenum].endswith('\r'): + lines[linenum] = lines[linenum].rstrip('\r') + crlf_lines.append(linenum + 1) + else: + lf_lines.append(linenum + 1) + + except IOError: + sys.stderr.write( + "Skipping input '%s': Can't open for reading\n" % filename) + _RestoreFilters() + return + + # Note, if no dot is found, this will give the entire filename as the ext. + file_extension = filename[filename.rfind('.') + 1:] + + # When reading from stdin, the extension is unknown, so no cpplint tests + # should rely on the extension. + if filename != '-' and file_extension not in _valid_extensions: + sys.stderr.write('Ignoring %s; not a valid file name ' + '(%s)\n' % (filename, ', '.join(_valid_extensions))) + else: + ProcessFileData(filename, file_extension, lines, Error, + extra_check_functions) + + # If end-of-line sequences are a mix of LF and CR-LF, issue + # warnings on the lines with CR. + # + # Don't issue any warnings if all lines are uniformly LF or CR-LF, + # since critique can handle these just fine, and the style guide + # doesn't dictate a particular end of line sequence. + # + # We can't depend on os.linesep to determine what the desired + # end-of-line sequence should be, since that will return the + # server-side end-of-line sequence. + if lf_lines and crlf_lines: + # Warn on every line with CR. 
An alternative approach might be to + # check whether the file is mostly CRLF or just LF, and warn on the + # minority, we bias toward LF here since most tools prefer LF. + for linenum in crlf_lines: + Error(filename, linenum, 'whitespace/newline', 1, + 'Unexpected \\r (^M) found; better to use only \\n') + + sys.stderr.write('Done processing %s\n' % filename) + _RestoreFilters() + + +def PrintUsage(message): + """Prints a brief usage string and exits, optionally with an error message. + + Args: + message: The optional error message. + """ + sys.stderr.write(_USAGE) + if message: + sys.exit('\nFATAL ERROR: ' + message) + else: + sys.exit(1) + + +def PrintCategories(): + """Prints a list of all the error-categories used by error messages. + + These are the categories used to filter messages via --filter. + """ + sys.stderr.write(''.join(' %s\n' % cat for cat in _ERROR_CATEGORIES)) + sys.exit(0) + + +def ParseArguments(args): + """Parses the command line arguments. + + This may set the output format and verbosity level as side-effects. + + Args: + args: The command line arguments: + + Returns: + The list of filenames to lint. + """ + try: + (opts, filenames) = getopt.getopt(args, '', ['help', 'output=', 'verbose=', + 'counting=', + 'filter=', + 'root=', + 'linelength=', + 'extensions=']) + except getopt.GetoptError: + PrintUsage('Invalid arguments.') + + verbosity = _VerboseLevel() + output_format = _OutputFormat() + filters = '' + counting_style = '' + + for (opt, val) in opts: + if opt == '--help': + PrintUsage(None) + elif opt == '--output': + if val not in ('emacs', 'vs7', 'eclipse'): + PrintUsage('The only allowed output formats are emacs, vs7 and eclipse.') + output_format = val + elif opt == '--verbose': + verbosity = int(val) + elif opt == '--filter': + filters = val + if not filters: + PrintCategories() + elif opt == '--counting': + if val not in ('total', 'toplevel', 'detailed'): + PrintUsage('Valid counting options are total, toplevel, and detailed') + counting_style = val + elif opt == '--root': + global _root + _root = val + elif opt == '--linelength': + global _line_length + try: + _line_length = int(val) + except ValueError: + PrintUsage('Line length must be digits.') + elif opt == '--extensions': + global _valid_extensions + try: + _valid_extensions = set(val.split(',')) + except ValueError: + PrintUsage('Extensions must be comma seperated list.') + + if not filenames: + PrintUsage('No files were specified.') + + _SetOutputFormat(output_format) + _SetVerboseLevel(verbosity) + _SetFilters(filters) + _SetCountingStyle(counting_style) + + return filenames + + +def main(): + filenames = ParseArguments(sys.argv[1:]) + + # Change stderr to write with replacement characters so we don't die + # if we try to print something containing non-ASCII characters. 
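+  # A typical invocation of this script looks something like
+  #   python cpplint.py --counting=detailed --filter=-whitespace foo/bar.cc
+  # (path illustrative); the exit status below is non-zero if any errors
+  # were reported.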
+ sys.stderr = codecs.StreamReaderWriter(sys.stderr, + codecs.getreader('utf8'), + codecs.getwriter('utf8'), + 'replace') + + _cpplint_state.ResetErrorCounts() + for filename in filenames: + ProcessFile(filename, _cpplint_state.verbose_level) + _cpplint_state.PrintErrorCounts() + + sys.exit(_cpplint_state.error_count > 0) + + +if __name__ == '__main__': + main() diff --git a/gtest/BREW b/gtest/BREW new file mode 100644 index 00000000000..ce38f5cd5c1 --- /dev/null +++ b/gtest/BREW @@ -0,0 +1,28 @@ +cc_library( + name = "gtest", + srcs = ["gtest-all.cpp"], + hdrs = ["gtest.h"], + cflags = ["-DGTEST_USE_OWN_TR1_TUPLE=1"], +) + +cc_library( + name = "gtest_main", + srcs = ["gtest_main.cc"], + deps = [ + ":gtest", + "//third_party/gflags:gflags", + "//third_party/glog:glog" + ], + cflags = ["-DGTEST_USE_OWN_TR1_TUPLE=1"], +) + +cc_test( + name = "gtest_main_binary", + srcs = ["gtest_main.cc"], + deps = [ + ":gtest", + "//third_party/gflags:gflags", + "//third_party/glog:glog", + ], + cflags = ["-DGTEST_USE_OWN_TR1_TUPLE=1"], +) \ No newline at end of file diff --git a/gtest/LICENSE b/gtest/LICENSE new file mode 100644 index 00000000000..649083b57d3 --- /dev/null +++ b/gtest/LICENSE @@ -0,0 +1,3 @@ +New BSD License. See + +https://code.google.com/p/googletest/ diff --git a/gtest/gtest-all.cpp b/gtest/gtest-all.cpp new file mode 100644 index 00000000000..926197419fc --- /dev/null +++ b/gtest/gtest-all.cpp @@ -0,0 +1,9117 @@ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: mheule@google.com (Markus Heule) +// +// Google C++ Testing Framework (Google Test) +// +// Sometimes it's desirable to build Google Test by compiling a single file. +// This file serves this purpose. + +// This line ensures that gtest.h can be compiled on its own, even +// when it's fused. +#include "gtest/gtest.h" + +// The following lines pull in the real gtest *.cc files. +// Copyright 2005, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) + +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// Utilities for testing Google Test itself and code that uses Google Test +// (e.g. frameworks built on top of Google Test). 
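+//
+// What follows declares ScopedFakeTestPartResultReporter and
+// SingleFailureChecker, plus the EXPECT_FATAL_FAILURE*() and
+// EXPECT_NONFATAL_FAILURE*() macros that are built on top of them.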
+ +#ifndef GTEST_INCLUDE_GTEST_GTEST_SPI_H_ +#define GTEST_INCLUDE_GTEST_GTEST_SPI_H_ + + +namespace testing { + +// This helper class can be used to mock out Google Test failure reporting +// so that we can test Google Test or code that builds on Google Test. +// +// An object of this class appends a TestPartResult object to the +// TestPartResultArray object given in the constructor whenever a Google Test +// failure is reported. It can either intercept only failures that are +// generated in the same thread that created this object or it can intercept +// all generated failures. The scope of this mock object can be controlled with +// the second argument to the two arguments constructor. +class GTEST_API_ ScopedFakeTestPartResultReporter + : public TestPartResultReporterInterface { + public: + // The two possible mocking modes of this object. + enum InterceptMode { + INTERCEPT_ONLY_CURRENT_THREAD, // Intercepts only thread local failures. + INTERCEPT_ALL_THREADS // Intercepts all failures. + }; + + // The c'tor sets this object as the test part result reporter used + // by Google Test. The 'result' parameter specifies where to report the + // results. This reporter will only catch failures generated in the current + // thread. DEPRECATED + explicit ScopedFakeTestPartResultReporter(TestPartResultArray* result); + + // Same as above, but you can choose the interception scope of this object. + ScopedFakeTestPartResultReporter(InterceptMode intercept_mode, + TestPartResultArray* result); + + // The d'tor restores the previous test part result reporter. + virtual ~ScopedFakeTestPartResultReporter(); + + // Appends the TestPartResult object to the TestPartResultArray + // received in the constructor. + // + // This method is from the TestPartResultReporterInterface + // interface. + virtual void ReportTestPartResult(const TestPartResult& result); + private: + void Init(); + + const InterceptMode intercept_mode_; + TestPartResultReporterInterface* old_reporter_; + TestPartResultArray* const result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedFakeTestPartResultReporter); +}; + +namespace internal { + +// A helper class for implementing EXPECT_FATAL_FAILURE() and +// EXPECT_NONFATAL_FAILURE(). Its destructor verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. +class GTEST_API_ SingleFailureChecker { + public: + // The constructor remembers the arguments. + SingleFailureChecker(const TestPartResultArray* results, + TestPartResult::Type type, + const string& substr); + ~SingleFailureChecker(); + private: + const TestPartResultArray* const results_; + const TestPartResult::Type type_; + const string substr_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(SingleFailureChecker); +}; + +} // namespace internal + +} // namespace testing + +// A set of macros for testing Google Test assertions or code that's expected +// to generate Google Test fatal failures. It verifies that the given +// statement will cause exactly one fatal Google Test failure with 'substr' +// being part of the failure message. +// +// There are two different versions of this macro. EXPECT_FATAL_FAILURE only +// affects and considers failures generated in the current thread and +// EXPECT_FATAL_FAILURE_ON_ALL_THREADS does the same but for all threads. +// +// The verification of the assertion is done correctly even when the statement +// throws an exception or aborts the current function. 
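+//
+// For example, a check along the lines of
+//   EXPECT_FATAL_FAILURE(ASSERT_TRUE(false), "false");
+// succeeds only if the wrapped statement generates exactly one fatal
+// failure whose message contains the given substring.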
+//
+// Known restrictions:
+//   - 'statement' cannot reference local non-static variables or
+//     non-static members of the current object.
+//   - 'statement' cannot return a value.
+//   - You cannot stream a failure message to this macro.
+//
+// Note that even though the implementations of the following two
+// macros are much alike, we cannot refactor them to use a common
+// helper macro, due to some peculiarity in how the preprocessor
+// works.  The AcceptsMacroThatExpandsToUnprotectedComma test in
+// gtest_unittest.cc will fail to compile if we do that.
+#define EXPECT_FATAL_FAILURE(statement, substr) \
+  do { \
+    class GTestExpectFatalFailureHelper {\
+     public:\
+      static void Execute() { statement; }\
+    };\
+    ::testing::TestPartResultArray gtest_failures;\
+    ::testing::internal::SingleFailureChecker gtest_checker(\
+        &gtest_failures, ::testing::TestPartResult::kFatalFailure, (substr));\
+    {\
+      ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\
+          ::testing::ScopedFakeTestPartResultReporter:: \
+          INTERCEPT_ONLY_CURRENT_THREAD, &gtest_failures);\
+      GTestExpectFatalFailureHelper::Execute();\
+    }\
+  } while (::testing::internal::AlwaysFalse())
+
+#define EXPECT_FATAL_FAILURE_ON_ALL_THREADS(statement, substr) \
+  do { \
+    class GTestExpectFatalFailureHelper {\
+     public:\
+      static void Execute() { statement; }\
+    };\
+    ::testing::TestPartResultArray gtest_failures;\
+    ::testing::internal::SingleFailureChecker gtest_checker(\
+        &gtest_failures, ::testing::TestPartResult::kFatalFailure, (substr));\
+    {\
+      ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\
+          ::testing::ScopedFakeTestPartResultReporter:: \
+          INTERCEPT_ALL_THREADS, &gtest_failures);\
+      GTestExpectFatalFailureHelper::Execute();\
+    }\
+  } while (::testing::internal::AlwaysFalse())
+
+// A macro for testing Google Test assertions or code that's expected to
+// generate Google Test non-fatal failures.  It asserts that the given
+// statement will cause exactly one non-fatal Google Test failure with 'substr'
+// being part of the failure message.
+//
+// There are two different versions of this macro. EXPECT_NONFATAL_FAILURE only
+// affects and considers failures generated in the current thread and
+// EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS does the same but for all threads.
+//
+// 'statement' is allowed to reference local variables and members of
+// the current object.
+//
+// The verification of the assertion is done correctly even when the statement
+// throws an exception or aborts the current function.
+//
+// Known restrictions:
+//   - You cannot stream a failure message to this macro.
+//
+// Note that even though the implementations of the following two
+// macros are much alike, we cannot refactor them to use a common
+// helper macro, due to some peculiarity in how the preprocessor
+// works.  If we do that, the code won't compile when the user gives
+// EXPECT_NONFATAL_FAILURE() a statement that contains a macro that
+// expands to code containing an unprotected comma.  The
+// AcceptsMacroThatExpandsToUnprotectedComma test in gtest_unittest.cc
+// catches that.
+//
+// For the same reason, we have to write
+//   if (::testing::internal::AlwaysTrue()) { statement; }
+// instead of
+//   GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement)
+// to avoid an MSVC warning on unreachable code.
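+//
+// For example, a check along the lines of
+//   EXPECT_NONFATAL_FAILURE(ADD_FAILURE() << "unexpected call", "unexpected");
+// succeeds only if the wrapped statement generates exactly one non-fatal
+// failure whose message contains the given substring.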
+#define EXPECT_NONFATAL_FAILURE(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter:: \ + INTERCEPT_ONLY_CURRENT_THREAD, >est_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#define EXPECT_NONFATAL_FAILURE_ON_ALL_THREADS(statement, substr) \ + do {\ + ::testing::TestPartResultArray gtest_failures;\ + ::testing::internal::SingleFailureChecker gtest_checker(\ + >est_failures, ::testing::TestPartResult::kNonFatalFailure, \ + (substr));\ + {\ + ::testing::ScopedFakeTestPartResultReporter gtest_reporter(\ + ::testing::ScopedFakeTestPartResultReporter::INTERCEPT_ALL_THREADS,\ + >est_failures);\ + if (::testing::internal::AlwaysTrue()) { statement; }\ + }\ + } while (::testing::internal::AlwaysFalse()) + +#endif // GTEST_INCLUDE_GTEST_GTEST_SPI_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include // NOLINT +#include +#include + +#if GTEST_OS_LINUX + +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +# define GTEST_HAS_GETTIMEOFDAY_ 1 + +# include // NOLINT +# include // NOLINT +# include // NOLINT +// Declares vsnprintf(). This header is not available on Windows. +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include + +#elif GTEST_OS_SYMBIAN +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include // NOLINT + +#elif GTEST_OS_ZOS +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include // NOLINT + +// On z/OS we additionally need strings.h for strcasecmp. +# include // NOLINT + +#elif GTEST_OS_WINDOWS_MOBILE // We are on Windows CE. + +# include // NOLINT + +#elif GTEST_OS_WINDOWS // We are on Windows proper. + +# include // NOLINT +# include // NOLINT +# include // NOLINT +# include // NOLINT + +# if GTEST_OS_WINDOWS_MINGW +// MinGW has gettimeofday() but not _ftime64(). +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +// TODO(kenton@google.com): There are other ways to get the time on +// Windows, like GetTickCount() or GetSystemTimeAsFileTime(). MinGW +// supports these. consider using them instead. +# define GTEST_HAS_GETTIMEOFDAY_ 1 +# include // NOLINT +# endif // GTEST_OS_WINDOWS_MINGW + +// cpplint thinks that the header is already included, so we want to +// silence it. +# include // NOLINT + +#else + +// Assume other platforms have gettimeofday(). +// TODO(kenton@google.com): Use autoconf to detect availability of +// gettimeofday(). +# define GTEST_HAS_GETTIMEOFDAY_ 1 + +// cpplint thinks that the header is already included, so we want to +// silence it. +# include // NOLINT +# include // NOLINT + +#endif // GTEST_OS_LINUX + +#if GTEST_HAS_EXCEPTIONS +# include +#endif + +#if GTEST_CAN_STREAM_RESULTS_ +# include // NOLINT +# include // NOLINT +#endif + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +// Copyright 2005, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Utility functions and classes used by the Google C++ testing framework. +// +// Author: wan@google.com (Zhanyong Wan) +// +// This file contains purely Google Test's internal implementation. Please +// DO NOT #INCLUDE IT IN A USER PROGRAM. + +#ifndef GTEST_SRC_GTEST_INTERNAL_INL_H_ +#define GTEST_SRC_GTEST_INTERNAL_INL_H_ + +// GTEST_IMPLEMENTATION_ is defined to 1 iff the current translation unit is +// part of Google Test's implementation; otherwise it's undefined. +#if !GTEST_IMPLEMENTATION_ +// A user is trying to include this from his code - just say no. +# error "gtest-internal-inl.h is part of Google Test's internal implementation." +# error "It must not be included except by Google Test itself." +#endif // GTEST_IMPLEMENTATION_ + +#ifndef _WIN32_WCE +# include +#endif // !_WIN32_WCE +#include +#include // For strtoll/_strtoul64/malloc/free. +#include // For memmove. + +#include +#include +#include + + +#if GTEST_OS_WINDOWS +# include // NOLINT +#endif // GTEST_OS_WINDOWS + + +namespace testing { + +// Declares the flags. +// +// We don't want the users to modify this flag in the code, but want +// Google Test's own unit tests to be able to access it. Therefore we +// declare it here as opposed to in gtest.h. +GTEST_DECLARE_bool_(death_test_use_fork); + +namespace internal { + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +GTEST_API_ extern const TypeId kTestTypeIdInGoogleTest; + +// Names of the flags (needed for parsing Google Test flags). 
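+// Each name here corresponds to a --gtest_<name> command line flag (e.g.
+// kFilterFlag maps to --gtest_filter) and to the matching GTEST_* environment
+// variable.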
+const char kAlsoRunDisabledTestsFlag[] = "also_run_disabled_tests"; +const char kBreakOnFailureFlag[] = "break_on_failure"; +const char kCatchExceptionsFlag[] = "catch_exceptions"; +const char kColorFlag[] = "color"; +const char kFilterFlag[] = "filter"; +const char kListTestsFlag[] = "list_tests"; +const char kOutputFlag[] = "output"; +const char kPrintTimeFlag[] = "print_time"; +const char kRandomSeedFlag[] = "random_seed"; +const char kRepeatFlag[] = "repeat"; +const char kShuffleFlag[] = "shuffle"; +const char kStackTraceDepthFlag[] = "stack_trace_depth"; +const char kStreamResultToFlag[] = "stream_result_to"; +const char kThrowOnFailureFlag[] = "throw_on_failure"; + +// A valid random seed must be in [1, kMaxRandomSeed]. +const int kMaxRandomSeed = 99999; + +// g_help_flag is true iff the --help flag or an equivalent form is +// specified on the command line. +GTEST_API_ extern bool g_help_flag; + +// Returns the current time in milliseconds. +GTEST_API_ TimeInMillis GetTimeInMillis(); + +// Returns true iff Google Test should use colors in the output. +GTEST_API_ bool ShouldUseColor(bool stdout_is_tty); + +// Formats the given time in milliseconds as seconds. +GTEST_API_ std::string FormatTimeInMillisAsSeconds(TimeInMillis ms); + +// Parses a string for an Int32 flag, in the form of "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +GTEST_API_ bool ParseInt32Flag( + const char* str, const char* flag, Int32* value); + +// Returns a random seed in range [1, kMaxRandomSeed] based on the +// given --gtest_random_seed flag value. +inline int GetRandomSeedFromFlag(Int32 random_seed_flag) { + const unsigned int raw_seed = (random_seed_flag == 0) ? + static_cast(GetTimeInMillis()) : + static_cast(random_seed_flag); + + // Normalizes the actual seed to range [1, kMaxRandomSeed] such that + // it's easy to type. + const int normalized_seed = + static_cast((raw_seed - 1U) % + static_cast(kMaxRandomSeed)) + 1; + return normalized_seed; +} + +// Returns the first valid random seed after 'seed'. The behavior is +// undefined if 'seed' is invalid. The seed after kMaxRandomSeed is +// considered to be 1. +inline int GetNextRandomSeed(int seed) { + GTEST_CHECK_(1 <= seed && seed <= kMaxRandomSeed) + << "Invalid random seed " << seed << " - must be in [1, " + << kMaxRandomSeed << "]."; + const int next_seed = seed + 1; + return (next_seed > kMaxRandomSeed) ? 1 : next_seed; +} + +// This class saves the values of all Google Test flags in its c'tor, and +// restores them in its d'tor. +class GTestFlagSaver { + public: + // The c'tor. + GTestFlagSaver() { + also_run_disabled_tests_ = GTEST_FLAG(also_run_disabled_tests); + break_on_failure_ = GTEST_FLAG(break_on_failure); + catch_exceptions_ = GTEST_FLAG(catch_exceptions); + color_ = GTEST_FLAG(color); + death_test_style_ = GTEST_FLAG(death_test_style); + death_test_use_fork_ = GTEST_FLAG(death_test_use_fork); + filter_ = GTEST_FLAG(filter); + internal_run_death_test_ = GTEST_FLAG(internal_run_death_test); + list_tests_ = GTEST_FLAG(list_tests); + output_ = GTEST_FLAG(output); + print_time_ = GTEST_FLAG(print_time); + random_seed_ = GTEST_FLAG(random_seed); + repeat_ = GTEST_FLAG(repeat); + shuffle_ = GTEST_FLAG(shuffle); + stack_trace_depth_ = GTEST_FLAG(stack_trace_depth); + stream_result_to_ = GTEST_FLAG(stream_result_to); + throw_on_failure_ = GTEST_FLAG(throw_on_failure); + } + + // The d'tor is not virtual. DO NOT INHERIT FROM THIS CLASS. 
+ ~GTestFlagSaver() { + GTEST_FLAG(also_run_disabled_tests) = also_run_disabled_tests_; + GTEST_FLAG(break_on_failure) = break_on_failure_; + GTEST_FLAG(catch_exceptions) = catch_exceptions_; + GTEST_FLAG(color) = color_; + GTEST_FLAG(death_test_style) = death_test_style_; + GTEST_FLAG(death_test_use_fork) = death_test_use_fork_; + GTEST_FLAG(filter) = filter_; + GTEST_FLAG(internal_run_death_test) = internal_run_death_test_; + GTEST_FLAG(list_tests) = list_tests_; + GTEST_FLAG(output) = output_; + GTEST_FLAG(print_time) = print_time_; + GTEST_FLAG(random_seed) = random_seed_; + GTEST_FLAG(repeat) = repeat_; + GTEST_FLAG(shuffle) = shuffle_; + GTEST_FLAG(stack_trace_depth) = stack_trace_depth_; + GTEST_FLAG(stream_result_to) = stream_result_to_; + GTEST_FLAG(throw_on_failure) = throw_on_failure_; + } + private: + // Fields for saving the original values of flags. + bool also_run_disabled_tests_; + bool break_on_failure_; + bool catch_exceptions_; + String color_; + String death_test_style_; + bool death_test_use_fork_; + String filter_; + String internal_run_death_test_; + bool list_tests_; + String output_; + bool print_time_; + internal::Int32 random_seed_; + internal::Int32 repeat_; + bool shuffle_; + internal::Int32 stack_trace_depth_; + String stream_result_to_; + bool throw_on_failure_; +} GTEST_ATTRIBUTE_UNUSED_; + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type UInt32 because wchar_t may not be +// wide enough to contain a code point. +// The output buffer str must containt at least 32 characters. +// The function returns the address of the output buffer. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. +GTEST_API_ char* CodePointToUtf8(UInt32 code_point, char* str); + +// Converts a wide string to a narrow string in UTF-8 encoding. +// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin, Symbian OS) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from Basic Normal Plane. +GTEST_API_ String WideStringToUtf8(const wchar_t* str, int num_chars); + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded(); + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (e.g., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. 
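+// For example, launching the test binary with GTEST_TOTAL_SHARDS=4 and
+// GTEST_SHARD_INDEX=1 in the environment makes this process run roughly a
+// quarter of the tests, leaving the rest to the other three shards.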
+GTEST_API_ bool ShouldShard(const char* total_shards_str, + const char* shard_index_str, + bool in_subprocess_for_death_test); + +// Parses the environment variable var as an Int32. If it is unset, +// returns default_val. If it is not an Int32, prints an error and +// and aborts. +GTEST_API_ Int32 Int32FromEnvOrDie(const char* env_var, Int32 default_val); + +// Given the total number of shards, the shard index, and the test id, +// returns true iff the test should be run on this shard. The test id is +// some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +GTEST_API_ bool ShouldRunTestOnShard( + int total_shards, int shard_index, int test_id); + +// STL container utilities. + +// Returns the number of elements in the given container that satisfy +// the given predicate. +template +inline int CountIf(const Container& c, Predicate predicate) { + // Implemented as an explicit loop since std::count_if() in libCstd on + // Solaris has a non-standard signature. + int count = 0; + for (typename Container::const_iterator it = c.begin(); it != c.end(); ++it) { + if (predicate(*it)) + ++count; + } + return count; +} + +// Applies a function/functor to each element in the container. +template +void ForEach(const Container& c, Functor functor) { + std::for_each(c.begin(), c.end(), functor); +} + +// Returns the i-th element of the vector, or default_value if i is not +// in range [0, v.size()). +template +inline E GetElementOr(const std::vector& v, int i, E default_value) { + return (i < 0 || i >= static_cast(v.size())) ? default_value : v[i]; +} + +// Performs an in-place shuffle of a range of the vector's elements. +// 'begin' and 'end' are element indices as an STL-style range; +// i.e. [begin, end) are shuffled, where 'end' == size() means to +// shuffle to the end of the vector. +template +void ShuffleRange(internal::Random* random, int begin, int end, + std::vector* v) { + const int size = static_cast(v->size()); + GTEST_CHECK_(0 <= begin && begin <= size) + << "Invalid shuffle range start " << begin << ": must be in range [0, " + << size << "]."; + GTEST_CHECK_(begin <= end && end <= size) + << "Invalid shuffle range finish " << end << ": must be in range [" + << begin << ", " << size << "]."; + + // Fisher-Yates shuffle, from + // http://en.wikipedia.org/wiki/Fisher-Yates_shuffle + for (int range_width = end - begin; range_width >= 2; range_width--) { + const int last_in_range = begin + range_width - 1; + const int selected = begin + random->Generate(range_width); + std::swap((*v)[selected], (*v)[last_in_range]); + } +} + +// Performs an in-place shuffle of the vector's elements. +template +inline void Shuffle(internal::Random* random, std::vector* v) { + ShuffleRange(random, 0, static_cast(v->size()), v); +} + +// A function for deleting an object. Handy for being used as a +// functor. +template +static void Delete(T* x) { + delete x; +} + +// A predicate that checks the key of a TestProperty against a known key. +// +// TestPropertyKeyIs is copyable. +class TestPropertyKeyIs { + public: + // Constructor. + // + // TestPropertyKeyIs has NO default constructor. + explicit TestPropertyKeyIs(const char* key) + : key_(key) {} + + // Returns true iff the test name of test property matches on key_. + bool operator()(const TestProperty& test_property) const { + return String(test_property.key()).Compare(key_) == 0; + } + + private: + String key_; +}; + +// Class UnitTestOptions. 
+// +// This class contains functions for processing options the user +// specifies when running the tests. It has only static members. +// +// In most cases, the user can specify an option using either an +// environment variable or a command line flag. E.g. you can set the +// test filter using either GTEST_FILTER or --gtest_filter. If both +// the variable and the flag are present, the latter overrides the +// former. +class GTEST_API_ UnitTestOptions { + public: + // Functions for processing the gtest_output flag. + + // Returns the output format, or "" for normal printed output. + static String GetOutputFormat(); + + // Returns the absolute path of the requested output file, or the + // default (test_detail.xml in the original working directory) if + // none was explicitly specified. + static String GetAbsolutePathToOutputFile(); + + // Functions for processing the gtest_filter flag. + + // Returns true iff the wildcard pattern matches the string. The + // first ':' or '\0' character in pattern marks the end of it. + // + // This recursive algorithm isn't very efficient, but is clear and + // works well enough for matching test names, which are short. + static bool PatternMatchesString(const char *pattern, const char *str); + + // Returns true iff the user-specified filter matches the test case + // name and the test name. + static bool FilterMatchesTest(const String &test_case_name, + const String &test_name); + +#if GTEST_OS_WINDOWS + // Function for supporting the gtest_catch_exception flag. + + // Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the + // given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. + // This function is useful as an __except condition. + static int GTestShouldProcessSEH(DWORD exception_code); +#endif // GTEST_OS_WINDOWS + + // Returns true if "name" matches the ':' separated list of glob-style + // filters in "filter". + static bool MatchesFilter(const String& name, const char* filter); +}; + +// Returns the current application's name, removing directory path if that +// is present. Used by UnitTestOptions::GetOutputFile. +GTEST_API_ FilePath GetCurrentExecutableName(); + +// The role interface for getting the OS stack trace as a string. +class OsStackTraceGetterInterface { + public: + OsStackTraceGetterInterface() {} + virtual ~OsStackTraceGetterInterface() {} + + // Returns the current OS stack trace as a String. Parameters: + // + // max_depth - the maximum number of stack frames to be included + // in the trace. + // skip_count - the number of top frames to be skipped; doesn't count + // against max_depth. + virtual String CurrentStackTrace(int max_depth, int skip_count) = 0; + + // UponLeavingGTest() should be called immediately before Google Test calls + // user code. It saves some information about the current stack that + // CurrentStackTrace() will use to find and hide Google Test stack frames. + virtual void UponLeavingGTest() = 0; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetterInterface); +}; + +// A working implementation of the OsStackTraceGetterInterface interface. +class OsStackTraceGetter : public OsStackTraceGetterInterface { + public: + OsStackTraceGetter() : caller_frame_(NULL) {} + virtual String CurrentStackTrace(int max_depth, int skip_count); + virtual void UponLeavingGTest(); + + // This string is inserted in place of stack frames that are part of + // Google Test's implementation. 
+ static const char* const kElidedFramesMarker; + + private: + Mutex mutex_; // protects all internal state + + // We save the stack frame below the frame that calls user code. + // We do this because the address of the frame immediately below + // the user code changes between the call to UponLeavingGTest() + // and any calls to CurrentStackTrace() from within the user code. + void* caller_frame_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(OsStackTraceGetter); +}; + +// Information about a Google Test trace point. +struct TraceInfo { + const char* file; + int line; + String message; +}; + +// This is the default global test part result reporter used in UnitTestImpl. +// This class should only be used by UnitTestImpl. +class DefaultGlobalTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultGlobalTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. Reports the test part + // result in the current test. + virtual void ReportTestPartResult(const TestPartResult& result); + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultGlobalTestPartResultReporter); +}; + +// This is the default per thread test part result reporter used in +// UnitTestImpl. This class should only be used by UnitTestImpl. +class DefaultPerThreadTestPartResultReporter + : public TestPartResultReporterInterface { + public: + explicit DefaultPerThreadTestPartResultReporter(UnitTestImpl* unit_test); + // Implements the TestPartResultReporterInterface. The implementation just + // delegates to the current global test part result reporter of *unit_test_. + virtual void ReportTestPartResult(const TestPartResult& result); + + private: + UnitTestImpl* const unit_test_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DefaultPerThreadTestPartResultReporter); +}; + +// The private implementation of the UnitTest class. We don't protect +// the methods under a mutex, as this class is not accessible by a +// user and the UnitTest class that delegates work to this class does +// proper locking. +class GTEST_API_ UnitTestImpl { + public: + explicit UnitTestImpl(UnitTest* parent); + virtual ~UnitTestImpl(); + + // There are two different ways to register your own TestPartResultReporter. + // You can register your own repoter to listen either only for test results + // from the current thread or for results from all threads. + // By default, each per-thread test result repoter just passes a new + // TestPartResult to the global test result reporter, which registers the + // test part result for the currently running test. + + // Returns the global test part result reporter. + TestPartResultReporterInterface* GetGlobalTestPartResultReporter(); + + // Sets the global test part result reporter. + void SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter); + + // Returns the test part result reporter for the current thread. + TestPartResultReporterInterface* GetTestPartResultReporterForCurrentThread(); + + // Sets the test part result reporter for the current thread. + void SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter); + + // Gets the number of successful test cases. + int successful_test_case_count() const; + + // Gets the number of failed test cases. + int failed_test_case_count() const; + + // Gets the number of all test cases. + int total_test_case_count() const; + + // Gets the number of all test cases that contain at least one test + // that should run. 
+ int test_case_to_run_count() const; + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns true iff the unit test passed (i.e. all test cases passed). + bool Passed() const { return !Failed(); } + + // Returns true iff the unit test failed (i.e. some test case failed + // or something outside of all tests failed). + bool Failed() const { + return failed_test_case_count() > 0 || ad_hoc_test_result()->Failed(); + } + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + const TestCase* GetTestCase(int i) const { + const int index = GetElementOr(test_case_indices_, i, -1); + return index < 0 ? NULL : test_cases_[i]; + } + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + TestCase* GetMutableTestCase(int i) { + const int index = GetElementOr(test_case_indices_, i, -1); + return index < 0 ? NULL : test_cases_[index]; + } + + // Provides access to the event listener list. + TestEventListeners* listeners() { return &listeners_; } + + // Returns the TestResult for the test that's currently running, or + // the TestResult for the ad hoc test if no test is running. + TestResult* current_test_result(); + + // Returns the TestResult for the ad hoc test. + const TestResult* ad_hoc_test_result() const { return &ad_hoc_test_result_; } + + // Sets the OS stack trace getter. + // + // Does nothing if the input and the current OS stack trace getter + // are the same; otherwise, deletes the old getter and makes the + // input the current getter. + void set_os_stack_trace_getter(OsStackTraceGetterInterface* getter); + + // Returns the current OS stack trace getter if it is not NULL; + // otherwise, creates an OsStackTraceGetter, makes it the current + // getter, and returns it. + OsStackTraceGetterInterface* os_stack_trace_getter(); + + // Returns the current OS stack trace as a String. + // + // The maximum number of stack frames to be included is specified by + // the gtest_stack_trace_depth flag. The skip_count parameter + // specifies the number of top frames to be skipped, which doesn't + // count against the number of frames to be included. + // + // For example, if Foo() calls Bar(), which in turn calls + // CurrentOsStackTraceExceptTop(1), Foo() will be included in the + // trace but Bar() and CurrentOsStackTraceExceptTop() won't. + String CurrentOsStackTraceExceptTop(int skip_count); + + // Finds and returns a TestCase with the given name. If one doesn't + // exist, creates one and returns it. + // + // Arguments: + // + // test_case_name: name of the test case + // type_param: the name of the test's type parameter, or NULL if + // this is not a typed or a type-parameterized test. 
+ // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + TestCase* GetTestCase(const char* test_case_name, + const char* type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc); + + // Adds a TestInfo to the unit test. + // + // Arguments: + // + // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + // test_info: the TestInfo object + void AddTestInfo(Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc, + TestInfo* test_info) { + // In order to support thread-safe death tests, we need to + // remember the original working directory when the test program + // was first invoked. We cannot do this in RUN_ALL_TESTS(), as + // the user may have changed the current directory before calling + // RUN_ALL_TESTS(). Therefore we capture the current directory in + // AddTestInfo(), which is called to register a TEST or TEST_F + // before main() is reached. + if (original_working_dir_.IsEmpty()) { + original_working_dir_.Set(FilePath::GetCurrentDir()); + GTEST_CHECK_(!original_working_dir_.IsEmpty()) + << "Failed to get the current working directory."; + } + + GetTestCase(test_info->test_case_name(), + test_info->type_param(), + set_up_tc, + tear_down_tc)->AddTestInfo(test_info); + } + +#if GTEST_HAS_PARAM_TEST + // Returns ParameterizedTestCaseRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. + internal::ParameterizedTestCaseRegistry& parameterized_test_registry() { + return parameterized_test_registry_; + } +#endif // GTEST_HAS_PARAM_TEST + + // Sets the TestCase object for the test that's currently running. + void set_current_test_case(TestCase* a_current_test_case) { + current_test_case_ = a_current_test_case; + } + + // Sets the TestInfo object for the test that's currently running. If + // current_test_info is NULL, the assertion results will be stored in + // ad_hoc_test_result_. + void set_current_test_info(TestInfo* a_current_test_info) { + current_test_info_ = a_current_test_info; + } + + // Registers all parameterized tests defined using TEST_P and + // INSTANTIATE_TEST_CASE_P, creating regular tests for each test/parameter + // combination. This method can be called more then once; it has guards + // protecting from registering the tests more then once. If + // value-parameterized tests are disabled, RegisterParameterizedTests is + // present but does nothing. + void RegisterParameterizedTests(); + + // Runs all tests in this UnitTest object, prints the result, and + // returns true if all tests are successful. If any exception is + // thrown during a test, this test is considered to be failed, but + // the rest of the tests will still be run. + bool RunAllTests(); + + // Clears the results of all tests, except the ad hoc tests. + void ClearNonAdHocTestResult() { + ForEach(test_cases_, TestCase::ClearTestCaseResult); + } + + // Clears the results of ad-hoc test assertions. + void ClearAdHocTestResult() { + ad_hoc_test_result_.Clear(); + } + + enum ReactionToSharding { + HONOR_SHARDING_PROTOCOL, + IGNORE_SHARDING_PROTOCOL + }; + + // Matches the full name of each test against the user-specified + // filter to decide whether the test should run, then records the + // result in each TestCase and TestInfo object. 
+ // If shard_tests == HONOR_SHARDING_PROTOCOL, further filters tests + // based on sharding variables in the environment. + // Returns the number of tests that should run. + int FilterTests(ReactionToSharding shard_tests); + + // Prints the names of the tests matching the user-specified filter flag. + void ListTestsMatchingFilter(); + + const TestCase* current_test_case() const { return current_test_case_; } + TestInfo* current_test_info() { return current_test_info_; } + const TestInfo* current_test_info() const { return current_test_info_; } + + // Returns the vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector& environments() { return environments_; } + + // Getters for the per-thread Google Test trace stack. + std::vector& gtest_trace_stack() { + return *(gtest_trace_stack_.pointer()); + } + const std::vector& gtest_trace_stack() const { + return gtest_trace_stack_.get(); + } + +#if GTEST_HAS_DEATH_TEST + void InitDeathTestSubprocessControlInfo() { + internal_run_death_test_flag_.reset(ParseInternalRunDeathTestFlag()); + } + // Returns a pointer to the parsed --gtest_internal_run_death_test + // flag, or NULL if that flag was not specified. + // This information is useful only in a death test child process. + // Must not be called before a call to InitGoogleTest. + const InternalRunDeathTestFlag* internal_run_death_test_flag() const { + return internal_run_death_test_flag_.get(); + } + + // Returns a pointer to the current death test factory. + internal::DeathTestFactory* death_test_factory() { + return death_test_factory_.get(); + } + + void SuppressTestEventsIfInSubprocess(); + + friend class ReplaceDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + + // Initializes the event listener performing XML output as specified by + // UnitTestOptions. Must not be called before InitGoogleTest. + void ConfigureXmlOutput(); + +#if GTEST_CAN_STREAM_RESULTS_ + // Initializes the event listener for streaming test results to a socket. + // Must not be called before InitGoogleTest. + void ConfigureStreamingOutput(); +#endif + + // Performs initialization dependent upon flag values obtained in + // ParseGoogleTestFlagsOnly. Is called from InitGoogleTest after the call to + // ParseGoogleTestFlagsOnly. In case a user neglects to call InitGoogleTest + // this function is also called from RunAllTests. Since this function can be + // called more than once, it has to be idempotent. + void PostFlagParsingInit(); + + // Gets the random seed used at the start of the current test iteration. + int random_seed() const { return random_seed_; } + + // Gets the random number generator. + internal::Random* random() { return &random_; } + + // Shuffles all test cases, and the tests within each test case, + // making sure that death tests are still run first. + void ShuffleTests(); + + // Restores the test cases and tests to their order before the first shuffle. + void UnshuffleTests(); + + // Returns the value of GTEST_FLAG(catch_exceptions) at the moment + // UnitTest::Run() starts. + bool catch_exceptions() const { return catch_exceptions_; } + + private: + friend class ::testing::UnitTest; + + // Used by UnitTest::Run() to capture the state of + // GTEST_FLAG(catch_exceptions) at the moment it starts. + void set_catch_exceptions(bool value) { catch_exceptions_ = value; } + + // The UnitTest object that owns this implementation object. + UnitTest* const parent_; + + // The working directory when the first TEST() or TEST_F() was + // executed. 
+ internal::FilePath original_working_dir_; + + // The default test part result reporters. + DefaultGlobalTestPartResultReporter default_global_test_part_result_reporter_; + DefaultPerThreadTestPartResultReporter + default_per_thread_test_part_result_reporter_; + + // Points to (but doesn't own) the global test part result reporter. + TestPartResultReporterInterface* global_test_part_result_repoter_; + + // Protects read and write access to global_test_part_result_reporter_. + internal::Mutex global_test_part_result_reporter_mutex_; + + // Points to (but doesn't own) the per-thread test part result reporter. + internal::ThreadLocal + per_thread_test_part_result_reporter_; + + // The vector of environments that need to be set-up/torn-down + // before/after the tests are run. + std::vector environments_; + + // The vector of TestCases in their original order. It owns the + // elements in the vector. + std::vector test_cases_; + + // Provides a level of indirection for the test case list to allow + // easy shuffling and restoring the test case order. The i-th + // element of this vector is the index of the i-th test case in the + // shuffled order. + std::vector test_case_indices_; + +#if GTEST_HAS_PARAM_TEST + // ParameterizedTestRegistry object used to register value-parameterized + // tests. + internal::ParameterizedTestCaseRegistry parameterized_test_registry_; + + // Indicates whether RegisterParameterizedTests() has been called already. + bool parameterized_tests_registered_; +#endif // GTEST_HAS_PARAM_TEST + + // Index of the last death test case registered. Initially -1. + int last_death_test_case_; + + // This points to the TestCase for the currently running test. It + // changes as Google Test goes through one test case after another. + // When no test is running, this is set to NULL and Google Test + // stores assertion results in ad_hoc_test_result_. Initially NULL. + TestCase* current_test_case_; + + // This points to the TestInfo for the currently running test. It + // changes as Google Test goes through one test after another. When + // no test is running, this is set to NULL and Google Test stores + // assertion results in ad_hoc_test_result_. Initially NULL. + TestInfo* current_test_info_; + + // Normally, a user only writes assertions inside a TEST or TEST_F, + // or inside a function called by a TEST or TEST_F. Since Google + // Test keeps track of which test is current running, it can + // associate such an assertion with the test it belongs to. + // + // If an assertion is encountered when no TEST or TEST_F is running, + // Google Test attributes the assertion result to an imaginary "ad hoc" + // test, and records the result in ad_hoc_test_result_. + TestResult ad_hoc_test_result_; + + // The list of event listeners that can be used to track events inside + // Google Test. + TestEventListeners listeners_; + + // The OS stack trace getter. Will be deleted when the UnitTest + // object is destructed. By default, an OsStackTraceGetter is used, + // but the user can set this field to use a custom getter if that is + // desired. + OsStackTraceGetterInterface* os_stack_trace_getter_; + + // True iff PostFlagParsingInit() has been called. + bool post_flag_parse_init_performed_; + + // The random number seed used at the beginning of the test run. + int random_seed_; + + // Our random number generator. + internal::Random random_; + + // How long the test took to run, in milliseconds. 
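// A minimal sketch of the index-indirection trick behind test_case_indices_
// above: the owning vector is never reordered; only a parallel vector of
// indices is shuffled, so the original order can be restored simply by
// resetting the indices. The names below are illustrative, not gtest's.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <random>
#include <string>
#include <vector>

int main() {
  const std::vector<std::string> test_cases = {"FooTest", "BarTest", "BazTest"};
  std::vector<int> indices(test_cases.size());
  std::iota(indices.begin(), indices.end(), 0);       // 0, 1, 2, ...

  std::mt19937 rng(42);                               // fixed seed for repeatability
  std::shuffle(indices.begin(), indices.end(), rng);  // shuffle the view only

  for (int i : indices) std::printf("shuffled: %s\n", test_cases[i].c_str());

  std::iota(indices.begin(), indices.end(), 0);       // "unshuffle": restore order
  for (int i : indices) std::printf("restored: %s\n", test_cases[i].c_str());
  return 0;
}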
+ TimeInMillis elapsed_time_; + +#if GTEST_HAS_DEATH_TEST + // The decomposed components of the gtest_internal_run_death_test flag, + // parsed when RUN_ALL_TESTS is called. + internal::scoped_ptr internal_run_death_test_flag_; + internal::scoped_ptr death_test_factory_; +#endif // GTEST_HAS_DEATH_TEST + + // A per-thread stack of traces created by the SCOPED_TRACE() macro. + internal::ThreadLocal > gtest_trace_stack_; + + // The value of GTEST_FLAG(catch_exceptions) at the moment RunAllTests() + // starts. + bool catch_exceptions_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTestImpl); +}; // class UnitTestImpl + +// Convenience function for accessing the global UnitTest +// implementation object. +inline UnitTestImpl* GetUnitTestImpl() { + return UnitTest::GetInstance()->impl(); +} + +#if GTEST_USES_SIMPLE_RE + +// Internal helper functions for implementing the simple regular +// expression matcher. +GTEST_API_ bool IsInSet(char ch, const char* str); +GTEST_API_ bool IsAsciiDigit(char ch); +GTEST_API_ bool IsAsciiPunct(char ch); +GTEST_API_ bool IsRepeat(char ch); +GTEST_API_ bool IsAsciiWhiteSpace(char ch); +GTEST_API_ bool IsAsciiWordChar(char ch); +GTEST_API_ bool IsValidEscape(char ch); +GTEST_API_ bool AtomMatchesChar(bool escaped, char pattern, char ch); +GTEST_API_ bool ValidateRegex(const char* regex); +GTEST_API_ bool MatchRegexAtHead(const char* regex, const char* str); +GTEST_API_ bool MatchRepetitionAndRegexAtHead( + bool escaped, char ch, char repeat, const char* regex, const char* str); +GTEST_API_ bool MatchRegexAnywhere(const char* regex, const char* str); + +#endif // GTEST_USES_SIMPLE_RE + +// Parses the command line for Google Test flags, without initializing +// other parts of Google Test. +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, char** argv); +GTEST_API_ void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv); + +#if GTEST_HAS_DEATH_TEST + +// Returns the message describing the last system error, regardless of the +// platform. +GTEST_API_ String GetLastErrnoDescription(); + +# if GTEST_OS_WINDOWS +// Provides leak-safe Windows kernel handle ownership. +class AutoHandle { + public: + AutoHandle() : handle_(INVALID_HANDLE_VALUE) {} + explicit AutoHandle(HANDLE handle) : handle_(handle) {} + + ~AutoHandle() { Reset(); } + + HANDLE Get() const { return handle_; } + void Reset() { Reset(INVALID_HANDLE_VALUE); } + void Reset(HANDLE handle) { + if (handle != handle_) { + if (handle_ != INVALID_HANDLE_VALUE) + ::CloseHandle(handle_); + handle_ = handle; + } + } + + private: + HANDLE handle_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AutoHandle); +}; +# endif // GTEST_OS_WINDOWS + +// Attempts to parse a string into a positive integer pointed to by the +// number parameter. Returns true if that is possible. +// GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we can use +// it here. +template +bool ParseNaturalNumber(const ::std::string& str, Integer* number) { + // Fail fast if the given string does not begin with a digit; + // this bypasses strtoXXX's "optional leading whitespace and plus + // or minus sign" semantics, which are undesirable here. + if (str.empty() || !IsDigit(str[0])) { + return false; + } + errno = 0; + + char* end; + // BiggestConvertible is the largest integer type that system-provided + // string-to-number conversion routines can return. + +# if GTEST_OS_WINDOWS && !defined(__GNUC__) + + // MSVC and C++ Builder define __int64 instead of the standard long long. 
+ typedef unsigned __int64 BiggestConvertible; + const BiggestConvertible parsed = _strtoui64(str.c_str(), &end, 10); + +# else + + typedef unsigned long long BiggestConvertible; // NOLINT + const BiggestConvertible parsed = strtoull(str.c_str(), &end, 10); + +# endif // GTEST_OS_WINDOWS && !defined(__GNUC__) + + const bool parse_success = *end == '\0' && errno == 0; + + // TODO(vladl@google.com): Convert this to compile time assertion when it is + // available. + GTEST_CHECK_(sizeof(Integer) <= sizeof(parsed)); + + const Integer result = static_cast(parsed); + if (parse_success && static_cast(result) == parsed) { + *number = result; + return true; + } + return false; +} +#endif // GTEST_HAS_DEATH_TEST + +// TestResult contains some private methods that should be hidden from +// Google Test user but are required for testing. This class allow our tests +// to access them. +// +// This class is supplied only for the purpose of testing Google Test's own +// constructs. Do not use it in user tests, either directly or indirectly. +class TestResultAccessor { + public: + static void RecordProperty(TestResult* test_result, + const TestProperty& property) { + test_result->RecordProperty(property); + } + + static void ClearTestPartResults(TestResult* test_result) { + test_result->ClearTestPartResults(); + } + + static const std::vector& test_part_results( + const TestResult& test_result) { + return test_result.test_part_results(); + } +}; + +} // namespace internal +} // namespace testing + +#endif // GTEST_SRC_GTEST_INTERNAL_INL_H_ +#undef GTEST_IMPLEMENTATION_ + +#if GTEST_OS_WINDOWS +# define vsnprintf _vsnprintf +#endif // GTEST_OS_WINDOWS + +namespace testing { + +using internal::CountIf; +using internal::ForEach; +using internal::GetElementOr; +using internal::Shuffle; + +// Constants. + +// A test whose test case name or test name matches this filter is +// disabled and not run. +static const char kDisableTestFilter[] = "DISABLED_*:*/DISABLED_*"; + +// A test case whose name matches this filter is considered a death +// test case and will be run before test cases whose name doesn't +// match this filter. +static const char kDeathTestCaseFilter[] = "*DeathTest:*DeathTest/*"; + +// A test filter that matches everything. +static const char kUniversalFilter[] = "*"; + +// The default output file for XML output. +static const char kDefaultOutputFile[] = "test_detail.xml"; + +// The environment variable name for the test shard index. +static const char kTestShardIndex[] = "GTEST_SHARD_INDEX"; +// The environment variable name for the total number of test shards. +static const char kTestTotalShards[] = "GTEST_TOTAL_SHARDS"; +// The environment variable name for the test shard status file. +static const char kTestShardStatusFile[] = "GTEST_SHARD_STATUS_FILE"; + +namespace internal { + +// The text used in failure messages to indicate the start of the +// stack trace. +const char kStackTraceMarker[] = "\nStack trace:\n"; + +// g_help_flag is true iff the --help flag or an equivalent form is +// specified on the command line. 
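// A minimal standalone sketch of the ParseNaturalNumber strategy above: reject
// anything that does not start with a digit (sidestepping strtoull's optional
// sign and leading-whitespace handling), require full consumption and
// errno == 0, and accept the value only if it round-trips through the
// destination type unchanged. This template is an illustrative rewrite using
// std::string, not gtest's own function.
#include <cctype>
#include <cerrno>
#include <cstdio>
#include <cstdlib>
#include <string>

template <typename Integer>
bool ParseNatural(const std::string& s, Integer* out) {
  if (s.empty() || !std::isdigit(static_cast<unsigned char>(s[0]))) return false;
  errno = 0;
  char* end = NULL;
  const unsigned long long parsed = std::strtoull(s.c_str(), &end, 10);
  if (*end != '\0' || errno != 0) return false;  // trailing junk or overflow
  const Integer result = static_cast<Integer>(parsed);
  if (static_cast<unsigned long long>(result) != parsed) return false;  // narrowed
  *out = result;
  return true;
}

int main() {
  int shard = 0;
  const bool ok = ParseNatural<int>("42", &shard);
  std::printf("\"42\"  -> ok=%d, value=%d\n", ok, shard);
  std::printf("\"-1\"  -> ok=%d\n", ParseNatural<int>("-1", &shard));   // sign rejected
  std::printf("\"4x2\" -> ok=%d\n", ParseNatural<int>("4x2", &shard));  // junk rejected
  return 0;
}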
+bool g_help_flag = false; + +} // namespace internal + +GTEST_DEFINE_bool_( + also_run_disabled_tests, + internal::BoolFromGTestEnv("also_run_disabled_tests", false), + "Run disabled tests too, in addition to the tests normally being run."); + +GTEST_DEFINE_bool_( + break_on_failure, + internal::BoolFromGTestEnv("break_on_failure", false), + "True iff a failed assertion should be a debugger break-point."); + +GTEST_DEFINE_bool_( + catch_exceptions, + internal::BoolFromGTestEnv("catch_exceptions", true), + "True iff " GTEST_NAME_ + " should catch exceptions and treat them as test failures."); + +GTEST_DEFINE_string_( + color, + internal::StringFromGTestEnv("color", "auto"), + "Whether to use colors in the output. Valid values: yes, no, " + "and auto. 'auto' means to use colors if the output is " + "being sent to a terminal and the TERM environment variable " + "is set to xterm, xterm-color, xterm-256color, linux or cygwin."); + +GTEST_DEFINE_string_( + filter, + internal::StringFromGTestEnv("filter", kUniversalFilter), + "A colon-separated list of glob (not regex) patterns " + "for filtering the tests to run, optionally followed by a " + "'-' and a : separated list of negative patterns (tests to " + "exclude). A test is run if it matches one of the positive " + "patterns and does not match any of the negative patterns."); + +GTEST_DEFINE_bool_(list_tests, false, + "List all tests without running them."); + +GTEST_DEFINE_string_( + output, + internal::StringFromGTestEnv("output", ""), + "A format (currently must be \"xml\"), optionally followed " + "by a colon and an output file name or directory. A directory " + "is indicated by a trailing pathname separator. " + "Examples: \"xml:filename.xml\", \"xml::directoryname/\". " + "If a directory is specified, output files will be created " + "within that directory, with file-names based on the test " + "executable's name and, if necessary, made unique by adding " + "digits."); + +GTEST_DEFINE_bool_( + print_time, + internal::BoolFromGTestEnv("print_time", true), + "True iff " GTEST_NAME_ + " should display elapsed time in text output."); + +GTEST_DEFINE_int32_( + random_seed, + internal::Int32FromGTestEnv("random_seed", 0), + "Random number seed to use when shuffling test orders. Must be in range " + "[1, 99999], or 0 to use a seed based on the current time."); + +GTEST_DEFINE_int32_( + repeat, + internal::Int32FromGTestEnv("repeat", 1), + "How many times to repeat each test. Specify a negative number " + "for repeating forever. Useful for shaking out flaky tests."); + +GTEST_DEFINE_bool_( + show_internal_stack_frames, false, + "True iff " GTEST_NAME_ " should include internal stack frames when " + "printing test failure stack traces."); + +GTEST_DEFINE_bool_( + shuffle, + internal::BoolFromGTestEnv("shuffle", false), + "True iff " GTEST_NAME_ + " should randomize tests' order on every run."); + +GTEST_DEFINE_int32_( + stack_trace_depth, + internal::Int32FromGTestEnv("stack_trace_depth", kMaxStackTraceDepth), + "The maximum number of stack frames to print when an " + "assertion fails. The valid range is 0 through 100, inclusive."); + +GTEST_DEFINE_string_( + stream_result_to, + internal::StringFromGTestEnv("stream_result_to", ""), + "This flag specifies the host name and the port number on which to stream " + "test results. Example: \"localhost:555\". 
The flag is effective only on " + "Linux."); + +GTEST_DEFINE_bool_( + throw_on_failure, + internal::BoolFromGTestEnv("throw_on_failure", false), + "When this flag is specified, a failed assertion will throw an exception " + "if exceptions are enabled or exit the program with a non-zero code " + "otherwise."); + +namespace internal { + +// Generates a random number from [0, range), using a Linear +// Congruential Generator (LCG). Crashes if 'range' is 0 or greater +// than kMaxRange. +UInt32 Random::Generate(UInt32 range) { + // These constants are the same as are used in glibc's rand(3). + state_ = (1103515245U*state_ + 12345U) % kMaxRange; + + GTEST_CHECK_(range > 0) + << "Cannot generate a number in the range [0, 0)."; + GTEST_CHECK_(range <= kMaxRange) + << "Generation of a number in [0, " << range << ") was requested, " + << "but this can only generate numbers in [0, " << kMaxRange << ")."; + + // Converting via modulus introduces a bit of downward bias, but + // it's simple, and a linear congruential generator isn't too good + // to begin with. + return state_ % range; +} + +// GTestIsInitialized() returns true iff the user has initialized +// Google Test. Useful for catching the user mistake of not initializing +// Google Test before calling RUN_ALL_TESTS(). +// +// A user must call testing::InitGoogleTest() to initialize Google +// Test. g_init_gtest_count is set to the number of times +// InitGoogleTest() has been called. We don't protect this variable +// under a mutex as it is only accessed in the main thread. +int g_init_gtest_count = 0; +static bool GTestIsInitialized() { return g_init_gtest_count != 0; } + +// Iterates over a vector of TestCases, keeping a running sum of the +// results of calling a given int-returning method on each. +// Returns the sum. +static int SumOverTestCaseList(const std::vector& case_list, + int (TestCase::*method)() const) { + int sum = 0; + for (size_t i = 0; i < case_list.size(); i++) { + sum += (case_list[i]->*method)(); + } + return sum; +} + +// Returns true iff the test case passed. +static bool TestCasePassed(const TestCase* test_case) { + return test_case->should_run() && test_case->Passed(); +} + +// Returns true iff the test case failed. +static bool TestCaseFailed(const TestCase* test_case) { + return test_case->should_run() && test_case->Failed(); +} + +// Returns true iff test_case contains at least one test that should +// run. +static bool ShouldRunTestCase(const TestCase* test_case) { + return test_case->should_run(); +} + +// AssertHelper constructor. +AssertHelper::AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message) + : data_(new AssertHelperData(type, file, line, message)) { +} + +AssertHelper::~AssertHelper() { + delete data_; +} + +// Message assignment, for assertion streaming support. +void AssertHelper::operator=(const Message& message) const { + UnitTest::GetInstance()-> + AddTestPartResult(data_->type, data_->file, data_->line, + AppendUserMessage(data_->message, message), + UnitTest::GetInstance()->impl() + ->CurrentOsStackTraceExceptTop(1) + // Skips the stack frame for this function itself. + ); // NOLINT +} + +// Mutex for linked pointers. +GTEST_DEFINE_STATIC_MUTEX_(g_linked_ptr_mutex); + +// Application pathname gotten in InitGoogleTest. +String g_executable_path; + +// Returns the current application's name, removing directory path if that +// is present. 
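// A minimal standalone version of the linear congruential generator used by
// Random::Generate above, with the same glibc rand(3) constants. The value of
// kMaxRange is not shown in this excerpt; 2^31 is assumed here, and the class
// name is illustrative.
#include <cstdio>

class TinyLcg {
 public:
  explicit TinyLcg(unsigned int seed) : state_(seed % kMaxRange) {}
  // Returns a pseudo-random number in [0, range); range must be in (0, kMaxRange].
  unsigned int Generate(unsigned int range) {
    state_ = (1103515245u * state_ + 12345u) % kMaxRange;
    return state_ % range;  // slight downward bias, acceptable for test shuffling
  }

 private:
  static const unsigned int kMaxRange = 1u << 31;
  unsigned int state_;
};

int main() {
  TinyLcg rng(20150625);  // e.g. a date-based seed, as with --gtest_random_seed=0
  for (int i = 0; i < 5; ++i) std::printf("%u ", rng.Generate(10));
  std::printf("\n");
  return 0;
}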
+FilePath GetCurrentExecutableName() { + FilePath result; + +#if GTEST_OS_WINDOWS + result.Set(FilePath(g_executable_path).RemoveExtension("exe")); +#else + result.Set(FilePath(g_executable_path)); +#endif // GTEST_OS_WINDOWS + + return result.RemoveDirectoryName(); +} + +// Functions for processing the gtest_output flag. + +// Returns the output format, or "" for normal printed output. +String UnitTestOptions::GetOutputFormat() { + const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); + if (gtest_output_flag == NULL) return String(""); + + const char* const colon = strchr(gtest_output_flag, ':'); + return (colon == NULL) ? + String(gtest_output_flag) : + String(gtest_output_flag, colon - gtest_output_flag); +} + +// Returns the name of the requested output file, or the default if none +// was explicitly specified. +String UnitTestOptions::GetAbsolutePathToOutputFile() { + const char* const gtest_output_flag = GTEST_FLAG(output).c_str(); + if (gtest_output_flag == NULL) + return String(""); + + const char* const colon = strchr(gtest_output_flag, ':'); + if (colon == NULL) + return String(internal::FilePath::ConcatPaths( + internal::FilePath( + UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(kDefaultOutputFile)).ToString() ); + + internal::FilePath output_name(colon + 1); + if (!output_name.IsAbsolutePath()) + // TODO(wan@google.com): on Windows \some\path is not an absolute + // path (as its meaning depends on the current drive), yet the + // following logic for turning it into an absolute path is wrong. + // Fix it. + output_name = internal::FilePath::ConcatPaths( + internal::FilePath(UnitTest::GetInstance()->original_working_dir()), + internal::FilePath(colon + 1)); + + if (!output_name.IsDirectory()) + return output_name.ToString(); + + internal::FilePath result(internal::FilePath::GenerateUniqueFileName( + output_name, internal::GetCurrentExecutableName(), + GetOutputFormat().c_str())); + return result.ToString(); +} + +// Returns true iff the wildcard pattern matches the string. The +// first ':' or '\0' character in pattern marks the end of it. +// +// This recursive algorithm isn't very efficient, but is clear and +// works well enough for matching test names, which are short. +bool UnitTestOptions::PatternMatchesString(const char *pattern, + const char *str) { + switch (*pattern) { + case '\0': + case ':': // Either ':' or '\0' marks the end of the pattern. + return *str == '\0'; + case '?': // Matches any single character. + return *str != '\0' && PatternMatchesString(pattern + 1, str + 1); + case '*': // Matches any string (possibly empty) of characters. + return (*str != '\0' && PatternMatchesString(pattern, str + 1)) || + PatternMatchesString(pattern + 1, str); + default: // Non-special character. Matches itself. + return *pattern == *str && + PatternMatchesString(pattern + 1, str + 1); + } +} + +bool UnitTestOptions::MatchesFilter(const String& name, const char* filter) { + const char *cur_pattern = filter; + for (;;) { + if (PatternMatchesString(cur_pattern, name.c_str())) { + return true; + } + + // Finds the next pattern in the filter. + cur_pattern = strchr(cur_pattern, ':'); + + // Returns if no more pattern can be found. + if (cur_pattern == NULL) { + return false; + } + + // Skips the pattern separater (the ':' character). + cur_pattern++; + } +} + +// TODO(keithray): move String function implementations to gtest-string.cc. + +// Returns true iff the user-specified filter matches the test case +// name and the test name. 
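// A minimal standalone copy of the recursive wildcard-matching idea behind
// PatternMatchesString above: '?' matches any single character and '*' matches
// any (possibly empty) run. This simplified sketch terminates patterns only at
// '\0', whereas gtest also treats ':' as a terminator so that filters can be
// colon-separated lists of patterns.
#include <cstdio>

static bool Matches(const char* pattern, const char* str) {
  switch (*pattern) {
    case '\0':
      return *str == '\0';
    case '?':  // any single character
      return *str != '\0' && Matches(pattern + 1, str + 1);
    case '*':  // consume one character of str, or drop the '*'
      return (*str != '\0' && Matches(pattern, str + 1)) ||
             Matches(pattern + 1, str);
    default:   // literal character: must match exactly
      return *pattern == *str && Matches(pattern + 1, str + 1);
  }
}

int main() {
  std::printf("%d\n", Matches("*DeathTest.*", "MyDeathTest.Crashes"));  // 1
  std::printf("%d\n", Matches("Foo.?ar", "Foo.Bar"));                   // 1
  std::printf("%d\n", Matches("Foo.*", "Bar.Baz"));                     // 0
  return 0;
}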
+bool UnitTestOptions::FilterMatchesTest(const String &test_case_name, + const String &test_name) { + const String& full_name = String::Format("%s.%s", + test_case_name.c_str(), + test_name.c_str()); + + // Split --gtest_filter at '-', if there is one, to separate into + // positive filter and negative filter portions + const char* const p = GTEST_FLAG(filter).c_str(); + const char* const dash = strchr(p, '-'); + String positive; + String negative; + if (dash == NULL) { + positive = GTEST_FLAG(filter).c_str(); // Whole string is a positive filter + negative = String(""); + } else { + positive = String(p, dash - p); // Everything up to the dash + negative = String(dash+1); // Everything after the dash + if (positive.empty()) { + // Treat '-test1' as the same as '*-test1' + positive = kUniversalFilter; + } + } + + // A filter is a colon-separated list of patterns. It matches a + // test if any pattern in it matches the test. + return (MatchesFilter(full_name, positive.c_str()) && + !MatchesFilter(full_name, negative.c_str())); +} + +#if GTEST_HAS_SEH +// Returns EXCEPTION_EXECUTE_HANDLER if Google Test should handle the +// given SEH exception, or EXCEPTION_CONTINUE_SEARCH otherwise. +// This function is useful as an __except condition. +int UnitTestOptions::GTestShouldProcessSEH(DWORD exception_code) { + // Google Test should handle a SEH exception if: + // 1. the user wants it to, AND + // 2. this is not a breakpoint exception, AND + // 3. this is not a C++ exception (VC++ implements them via SEH, + // apparently). + // + // SEH exception code for C++ exceptions. + // (see http://support.microsoft.com/kb/185294 for more information). + const DWORD kCxxExceptionCode = 0xe06d7363; + + bool should_handle = true; + + if (!GTEST_FLAG(catch_exceptions)) + should_handle = false; + else if (exception_code == EXCEPTION_BREAKPOINT) + should_handle = false; + else if (exception_code == kCxxExceptionCode) + should_handle = false; + + return should_handle ? EXCEPTION_EXECUTE_HANDLER : EXCEPTION_CONTINUE_SEARCH; +} +#endif // GTEST_HAS_SEH + +} // namespace internal + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. Intercepts only failures from the current thread. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + TestPartResultArray* result) + : intercept_mode_(INTERCEPT_ONLY_CURRENT_THREAD), + result_(result) { + Init(); +} + +// The c'tor sets this object as the test part result reporter used by +// Google Test. The 'result' parameter specifies where to report the +// results. +ScopedFakeTestPartResultReporter::ScopedFakeTestPartResultReporter( + InterceptMode intercept_mode, TestPartResultArray* result) + : intercept_mode_(intercept_mode), + result_(result) { + Init(); +} + +void ScopedFakeTestPartResultReporter::Init() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + old_reporter_ = impl->GetGlobalTestPartResultReporter(); + impl->SetGlobalTestPartResultReporter(this); + } else { + old_reporter_ = impl->GetTestPartResultReporterForCurrentThread(); + impl->SetTestPartResultReporterForCurrentThread(this); + } +} + +// The d'tor restores the test part result reporter used by Google Test +// before. 
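// A minimal sketch of the --gtest_filter splitting rule implemented by
// FilterMatchesTest above: everything before the first '-' is the positive
// pattern list, everything after it is the negative list, and an empty
// positive half is treated as "*" (match everything). The struct and function
// names below are illustrative stand-ins using std::string.
#include <cstdio>
#include <string>

struct SplitFilter {
  std::string positive;
  std::string negative;
};

static SplitFilter SplitGTestFilter(const std::string& filter) {
  SplitFilter out;
  const std::string::size_type dash = filter.find('-');
  if (dash == std::string::npos) {
    out.positive = filter;  // the whole string is a positive filter
  } else {
    out.positive = filter.substr(0, dash);
    out.negative = filter.substr(dash + 1);
    if (out.positive.empty()) out.positive = "*";  // '-Foo' means '*-Foo'
  }
  return out;
}

int main() {
  const SplitFilter f = SplitGTestFilter("-FooTest.*:BarTest.Slow");
  std::printf("positive: %s\n", f.positive.c_str());  // "*"
  std::printf("negative: %s\n", f.negative.c_str());  // "FooTest.*:BarTest.Slow"
  return 0;
}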
+ScopedFakeTestPartResultReporter::~ScopedFakeTestPartResultReporter() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + if (intercept_mode_ == INTERCEPT_ALL_THREADS) { + impl->SetGlobalTestPartResultReporter(old_reporter_); + } else { + impl->SetTestPartResultReporterForCurrentThread(old_reporter_); + } +} + +// Increments the test part result count and remembers the result. +// This method is from the TestPartResultReporterInterface interface. +void ScopedFakeTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + result_->Append(result); +} + +namespace internal { + +// Returns the type ID of ::testing::Test. We should always call this +// instead of GetTypeId< ::testing::Test>() to get the type ID of +// testing::Test. This is to work around a suspected linker bug when +// using Google Test as a framework on Mac OS X. The bug causes +// GetTypeId< ::testing::Test>() to return different values depending +// on whether the call is from the Google Test framework itself or +// from user test code. GetTestTypeId() is guaranteed to always +// return the same value, as it always calls GetTypeId<>() from the +// gtest.cc, which is within the Google Test framework. +TypeId GetTestTypeId() { + return GetTypeId(); +} + +// The value of GetTestTypeId() as seen from within the Google Test +// library. This is solely for testing GetTestTypeId(). +extern const TypeId kTestTypeIdInGoogleTest = GetTestTypeId(); + +// This predicate-formatter checks that 'results' contains a test part +// failure of the given type and that the failure message contains the +// given substring. +AssertionResult HasOneFailure(const char* /* results_expr */, + const char* /* type_expr */, + const char* /* substr_expr */, + const TestPartResultArray& results, + TestPartResult::Type type, + const string& substr) { + const String expected(type == TestPartResult::kFatalFailure ? + "1 fatal failure" : + "1 non-fatal failure"); + Message msg; + if (results.size() != 1) { + msg << "Expected: " << expected << "\n" + << " Actual: " << results.size() << " failures"; + for (int i = 0; i < results.size(); i++) { + msg << "\n" << results.GetTestPartResult(i); + } + return AssertionFailure() << msg; + } + + const TestPartResult& r = results.GetTestPartResult(0); + if (r.type() != type) { + return AssertionFailure() << "Expected: " << expected << "\n" + << " Actual:\n" + << r; + } + + if (strstr(r.message(), substr.c_str()) == NULL) { + return AssertionFailure() << "Expected: " << expected << " containing \"" + << substr << "\"\n" + << " Actual:\n" + << r; + } + + return AssertionSuccess(); +} + +// The constructor of SingleFailureChecker remembers where to look up +// test part results, what type of failure we expect, and what +// substring the failure message should contain. +SingleFailureChecker:: SingleFailureChecker( + const TestPartResultArray* results, + TestPartResult::Type type, + const string& substr) + : results_(results), + type_(type), + substr_(substr) {} + +// The destructor of SingleFailureChecker verifies that the given +// TestPartResultArray contains exactly one failure that has the given +// type and contains the given substring. If that's not the case, a +// non-fatal failure will be generated. 
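// A minimal sketch of the save-and-restore interception pattern used by
// ScopedFakeTestPartResultReporter above: the constructor swaps itself in as
// the active reporter while remembering the previous one, and the destructor
// puts the previous reporter back, so interception lasts exactly one scope.
// The Reporter interface and g_reporter global below are illustrative
// stand-ins, not gtest types.
#include <cstdio>
#include <string>
#include <vector>

struct Reporter {
  virtual ~Reporter() {}
  virtual void Report(const std::string& msg) = 0;
};

struct PrintingReporter : Reporter {
  void Report(const std::string& msg) { std::printf("FAILURE: %s\n", msg.c_str()); }
};

static PrintingReporter g_default_reporter;
static Reporter* g_reporter = &g_default_reporter;

class ScopedCaptureReporter : public Reporter {
 public:
  ScopedCaptureReporter() : old_(g_reporter) { g_reporter = this; }
  ~ScopedCaptureReporter() { g_reporter = old_; }  // restore on scope exit
  void Report(const std::string& msg) { captured_.push_back(msg); }
  const std::vector<std::string>& captured() const { return captured_; }

 private:
  Reporter* const old_;
  std::vector<std::string> captured_;
};

int main() {
  {
    ScopedCaptureReporter capture;
    g_reporter->Report("intercepted, not printed");
    std::printf("captured %zu result(s)\n", capture.captured().size());
  }
  g_reporter->Report("printed again after the scope ends");
  return 0;
}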
+SingleFailureChecker::~SingleFailureChecker() { + EXPECT_PRED_FORMAT3(HasOneFailure, *results_, type_, substr_); +} + +DefaultGlobalTestPartResultReporter::DefaultGlobalTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultGlobalTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->current_test_result()->AddTestPartResult(result); + unit_test_->listeners()->repeater()->OnTestPartResult(result); +} + +DefaultPerThreadTestPartResultReporter::DefaultPerThreadTestPartResultReporter( + UnitTestImpl* unit_test) : unit_test_(unit_test) {} + +void DefaultPerThreadTestPartResultReporter::ReportTestPartResult( + const TestPartResult& result) { + unit_test_->GetGlobalTestPartResultReporter()->ReportTestPartResult(result); +} + +// Returns the global test part result reporter. +TestPartResultReporterInterface* +UnitTestImpl::GetGlobalTestPartResultReporter() { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + return global_test_part_result_repoter_; +} + +// Sets the global test part result reporter. +void UnitTestImpl::SetGlobalTestPartResultReporter( + TestPartResultReporterInterface* reporter) { + internal::MutexLock lock(&global_test_part_result_reporter_mutex_); + global_test_part_result_repoter_ = reporter; +} + +// Returns the test part result reporter for the current thread. +TestPartResultReporterInterface* +UnitTestImpl::GetTestPartResultReporterForCurrentThread() { + return per_thread_test_part_result_reporter_.get(); +} + +// Sets the test part result reporter for the current thread. +void UnitTestImpl::SetTestPartResultReporterForCurrentThread( + TestPartResultReporterInterface* reporter) { + per_thread_test_part_result_reporter_.set(reporter); +} + +// Gets the number of successful test cases. +int UnitTestImpl::successful_test_case_count() const { + return CountIf(test_cases_, TestCasePassed); +} + +// Gets the number of failed test cases. +int UnitTestImpl::failed_test_case_count() const { + return CountIf(test_cases_, TestCaseFailed); +} + +// Gets the number of all test cases. +int UnitTestImpl::total_test_case_count() const { + return static_cast(test_cases_.size()); +} + +// Gets the number of all test cases that contain at least one test +// that should run. +int UnitTestImpl::test_case_to_run_count() const { + return CountIf(test_cases_, ShouldRunTestCase); +} + +// Gets the number of successful tests. +int UnitTestImpl::successful_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::successful_test_count); +} + +// Gets the number of failed tests. +int UnitTestImpl::failed_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::failed_test_count); +} + +// Gets the number of disabled tests. +int UnitTestImpl::disabled_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::disabled_test_count); +} + +// Gets the number of all tests. +int UnitTestImpl::total_test_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::total_test_count); +} + +// Gets the number of tests that should run. +int UnitTestImpl::test_to_run_count() const { + return SumOverTestCaseList(test_cases_, &TestCase::test_to_run_count); +} + +// Returns the current OS stack trace as a String. +// +// The maximum number of stack frames to be included is specified by +// the gtest_stack_trace_depth flag. 
The skip_count parameter +// specifies the number of top frames to be skipped, which doesn't +// count against the number of frames to be included. +// +// For example, if Foo() calls Bar(), which in turn calls +// CurrentOsStackTraceExceptTop(1), Foo() will be included in the +// trace but Bar() and CurrentOsStackTraceExceptTop() won't. +String UnitTestImpl::CurrentOsStackTraceExceptTop(int skip_count) { + (void)skip_count; + return String(""); +} + +// Returns the current time in milliseconds. +TimeInMillis GetTimeInMillis() { +#if GTEST_OS_WINDOWS_MOBILE || defined(__BORLANDC__) + // Difference between 1970-01-01 and 1601-01-01 in milliseconds. + // http://analogous.blogspot.com/2005/04/epoch.html + const TimeInMillis kJavaEpochToWinFileTimeDelta = + static_cast(116444736UL) * 100000UL; + const DWORD kTenthMicrosInMilliSecond = 10000; + + SYSTEMTIME now_systime; + FILETIME now_filetime; + ULARGE_INTEGER now_int64; + // TODO(kenton@google.com): Shouldn't this just use + // GetSystemTimeAsFileTime()? + GetSystemTime(&now_systime); + if (SystemTimeToFileTime(&now_systime, &now_filetime)) { + now_int64.LowPart = now_filetime.dwLowDateTime; + now_int64.HighPart = now_filetime.dwHighDateTime; + now_int64.QuadPart = (now_int64.QuadPart / kTenthMicrosInMilliSecond) - + kJavaEpochToWinFileTimeDelta; + return now_int64.QuadPart; + } + return 0; +#elif GTEST_OS_WINDOWS && !GTEST_HAS_GETTIMEOFDAY_ + __timeb64 now; + +# ifdef _MSC_VER + + // MSVC 8 deprecates _ftime64(), so we want to suppress warning 4996 + // (deprecated function) there. + // TODO(kenton@google.com): Use GetTickCount()? Or use + // SystemTimeToFileTime() +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4996) // Temporarily disables warning 4996. + _ftime64(&now); +# pragma warning(pop) // Restores the warning state. +# else + + _ftime64(&now); + +# endif // _MSC_VER + + return static_cast(now.time) * 1000 + now.millitm; +#elif GTEST_HAS_GETTIMEOFDAY_ + struct timeval now; + gettimeofday(&now, NULL); + return static_cast(now.tv_sec) * 1000 + now.tv_usec / 1000; +#else +# error "Don't know how to get the current time on your system." +#endif +} + +// Utilities + +// class String + +// Returns the input enclosed in double quotes if it's not NULL; +// otherwise returns "(null)". For example, "\"Hello\"" is returned +// for input "Hello". +// +// This is useful for printing a C string in the syntax of a literal. +// +// Known issue: escape sequences are not handled yet. +String String::ShowCStringQuoted(const char* c_str) { + return c_str ? String::Format("\"%s\"", c_str) : String("(null)"); +} + +// Copies at most length characters from str into a newly-allocated +// piece of memory of size length+1. The memory is allocated with new[]. +// A terminating null byte is written to the memory, and a pointer to it +// is returned. If str is NULL, NULL is returned. +static char* CloneString(const char* str, size_t length) { + if (str == NULL) { + return NULL; + } else { + char* const clone = new char[length + 1]; + posix::StrNCpy(clone, str, length); + clone[length] = '\0'; + return clone; + } +} + +// Clones a 0-terminated C string, allocating memory using new. The +// caller is responsible for deleting[] the return value. Returns the +// cloned string, or NULL if the input is NULL. +const char * String::CloneCString(const char* c_str) { + return (c_str == NULL) ? 
+ NULL : CloneString(c_str, strlen(c_str)); +} + +#if GTEST_OS_WINDOWS_MOBILE +// Creates a UTF-16 wide string from the given ANSI string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the wide string, or NULL if the +// input is NULL. +LPCWSTR String::AnsiToUtf16(const char* ansi) { + if (!ansi) return NULL; + const int length = strlen(ansi); + const int unicode_length = + MultiByteToWideChar(CP_ACP, 0, ansi, length, + NULL, 0); + WCHAR* unicode = new WCHAR[unicode_length + 1]; + MultiByteToWideChar(CP_ACP, 0, ansi, length, + unicode, unicode_length); + unicode[unicode_length] = 0; + return unicode; +} + +// Creates an ANSI string from the given wide string, allocating +// memory using new. The caller is responsible for deleting the return +// value using delete[]. Returns the ANSI string, or NULL if the +// input is NULL. +const char* String::Utf16ToAnsi(LPCWSTR utf16_str) { + if (!utf16_str) return NULL; + const int ansi_length = + WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, + NULL, 0, NULL, NULL); + char* ansi = new char[ansi_length + 1]; + WideCharToMultiByte(CP_ACP, 0, utf16_str, -1, + ansi, ansi_length, NULL, NULL); + ansi[ansi_length] = 0; + return ansi; +} + +#endif // GTEST_OS_WINDOWS_MOBILE + +// Compares two C strings. Returns true iff they have the same content. +// +// Unlike strcmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CStringEquals(const char * lhs, const char * rhs) { + if ( lhs == NULL ) return rhs == NULL; + + if ( rhs == NULL ) return false; + + return strcmp(lhs, rhs) == 0; +} + +#if GTEST_HAS_STD_WSTRING || GTEST_HAS_GLOBAL_WSTRING + +// Converts an array of wide chars to a narrow string using the UTF-8 +// encoding, and streams the result to the given Message object. +static void StreamWideCharsToMessage(const wchar_t* wstr, size_t length, + Message* msg) { + // TODO(wan): consider allowing a testing::String object to + // contain '\0'. This will make it behave more like std::string, + // and will allow ToUtf8String() to return the correct encoding + // for '\0' s.t. we can get rid of the conditional here (and in + // several other places). + for (size_t i = 0; i != length; ) { // NOLINT + if (wstr[i] != L'\0') { + *msg << WideStringToUtf8(wstr + i, static_cast(length - i)); + while (i != length && wstr[i] != L'\0') + i++; + } else { + *msg << '\0'; + i++; + } + } +} + +#endif // GTEST_HAS_STD_WSTRING || GTEST_HAS_GLOBAL_WSTRING + +} // namespace internal + +#if GTEST_HAS_STD_WSTRING +// Converts the given wide string to a narrow string using the UTF-8 +// encoding, and streams the result to this Message object. +Message& Message::operator <<(const ::std::wstring& wstr) { + internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); + return *this; +} +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_WSTRING +// Converts the given wide string to a narrow string using the UTF-8 +// encoding, and streams the result to this Message object. +Message& Message::operator <<(const ::wstring& wstr) { + internal::StreamWideCharsToMessage(wstr.c_str(), wstr.length(), this); + return *this; +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +// AssertionResult constructors. +// Used in EXPECT_TRUE/FALSE(assertion_result). +AssertionResult::AssertionResult(const AssertionResult& other) + : success_(other.success_), + message_(other.message_.get() != NULL ? 
+ new ::std::string(*other.message_) : + static_cast< ::std::string*>(NULL)) { +} + +// Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. +AssertionResult AssertionResult::operator!() const { + AssertionResult negation(!success_); + if (message_.get() != NULL) + negation << *message_; + return negation; +} + +// Makes a successful assertion result. +AssertionResult AssertionSuccess() { + return AssertionResult(true); +} + +// Makes a failed assertion result. +AssertionResult AssertionFailure() { + return AssertionResult(false); +} + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << message. +AssertionResult AssertionFailure(const Message& message) { + return AssertionFailure() << message; +} + +namespace internal { + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// expected_expression: "foo" +// actual_expression: "bar" +// expected_value: "5" +// actual_value: "6" +// +// The ignoring_case parameter is true iff the assertion is a +// *_STRCASEEQ*. When it's true, the string " (ignoring case)" will +// be inserted into the message. +AssertionResult EqFailure(const char* expected_expression, + const char* actual_expression, + const String& expected_value, + const String& actual_value, + bool ignoring_case) { + Message msg; + msg << "Value of: " << actual_expression; + if (actual_value != actual_expression) { + msg << "\n Actual: " << actual_value; + } + + msg << "\nExpected: " << expected_expression; + if (ignoring_case) { + msg << " (ignoring case)"; + } + if (expected_value != expected_expression) { + msg << "\nWhich is: " << expected_value; + } + + return AssertionFailure() << msg; +} + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +String GetBoolAssertionFailureMessage(const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value) { + const char* actual_message = assertion_result.message(); + Message msg; + msg << "Value of: " << expression_text + << "\n Actual: " << actual_predicate_value; + if (actual_message[0] != '\0') + msg << " (" << actual_message << ")"; + msg << "\nExpected: " << expected_predicate_value; + return msg.GetString(); +} + +// Helper function for implementing ASSERT_NEAR. +AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error) { + const double diff = fabs(val1 - val2); + if (diff <= abs_error) return AssertionSuccess(); + + // TODO(wan): do not print the value of an expression if it's + // already a literal. + return AssertionFailure() + << "The difference between " << expr1 << " and " << expr2 + << " is " << diff << ", which exceeds " << abs_error_expr << ", where\n" + << expr1 << " evaluates to " << val1 << ",\n" + << expr2 << " evaluates to " << val2 << ", and\n" + << abs_error_expr << " evaluates to " << abs_error << "."; +} + + +// Helper template for implementing FloatLE() and DoubleLE(). 
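// A minimal standalone version of the ASSERT_NEAR predicate implemented by
// DoubleNearPredFormat above: the check is simply |val1 - val2| <= abs_error.
// Because any comparison involving NaN is false, a NaN on either side makes
// the check fail, which is exactly what the assertion wants. The function name
// below is illustrative.
#include <cmath>
#include <cstdio>

static bool NearlyEqual(double val1, double val2, double abs_error) {
  return std::fabs(val1 - val2) <= abs_error;
}

int main() {
  std::printf("%d\n", NearlyEqual(3.14159, 3.1416, 1e-4));  // 1
  std::printf("%d\n", NearlyEqual(1.0, 1.1, 1e-3));         // 0
  std::printf("%d\n", NearlyEqual(NAN, 1.0, 1.0));          // 0: NaN never passes
  return 0;
}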
+template +AssertionResult FloatingPointLE(const char* expr1, + const char* expr2, + RawType val1, + RawType val2) { + // Returns success if val1 is less than val2, + if (val1 < val2) { + return AssertionSuccess(); + } + + // or if val1 is almost equal to val2. + const FloatingPoint lhs(val1), rhs(val2); + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + // Note that the above two checks will both fail if either val1 or + // val2 is NaN, as the IEEE floating-point standard requires that + // any predicate involving a NaN must return false. + + ::std::stringstream val1_ss; + val1_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << val1; + + ::std::stringstream val2_ss; + val2_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << val2; + + return AssertionFailure() + << "Expected: (" << expr1 << ") <= (" << expr2 << ")\n" + << " Actual: " << StringStreamToString(&val1_ss) << " vs " + << StringStreamToString(&val2_ss); +} + +} // namespace internal + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2) { + return internal::FloatingPointLE(expr1, expr2, val1, val2); +} + +namespace internal { + +// The helper function for {ASSERT|EXPECT}_EQ with int or enum +// arguments. +AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual) { + if (expected == actual) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + FormatForComparisonFailureMessage(expected, actual), + FormatForComparisonFailureMessage(actual, expected), + false); +} + +// A macro for implementing the helper functions needed to implement +// ASSERT_?? and EXPECT_?? with integer or enum arguments. It is here +// just to avoid copy-and-paste of similar code. +#define GTEST_IMPL_CMP_HELPER_(op_name, op)\ +AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ + BiggestInt val1, BiggestInt val2) {\ + if (val1 op val2) {\ + return AssertionSuccess();\ + } else {\ + return AssertionFailure() \ + << "Expected: (" << expr1 << ") " #op " (" << expr2\ + << "), actual: " << FormatForComparisonFailureMessage(val1, val2)\ + << " vs " << FormatForComparisonFailureMessage(val2, val1);\ + }\ +} + +// Implements the helper function for {ASSERT|EXPECT}_NE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(NE, !=) +// Implements the helper function for {ASSERT|EXPECT}_LE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(LE, <=) +// Implements the helper function for {ASSERT|EXPECT}_LT with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(LT, < ) +// Implements the helper function for {ASSERT|EXPECT}_GE with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(GE, >=) +// Implements the helper function for {ASSERT|EXPECT}_GT with int or +// enum arguments. +GTEST_IMPL_CMP_HELPER_(GT, > ) + +#undef GTEST_IMPL_CMP_HELPER_ + +// The helper function for {ASSERT|EXPECT}_STREQ. 
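// A minimal sketch of the ULP-based almost-equality that FloatingPointLE above
// relies on (FloatingPoint::AlmostEquals): both operands are mapped onto a
// "biased" unsigned number line on which adjacent representable floats differ
// by exactly 1, and they compare equal when they are at most a few
// units-in-the-last-place apart. The 4-ULP bound and the helper names below
// are illustrative assumptions about gtest's defaults, not taken from this
// excerpt; NaN compares equal to nothing.
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// Maps the sign-magnitude bit pattern of a float onto an unsigned scale where
// neighbouring floats are 1 apart (and +0.0 and -0.0 coincide).
static uint32_t BiasedBits(float f) {
  uint32_t bits;
  std::memcpy(&bits, &f, sizeof(bits));
  const uint32_t kSignMask = 0x80000000u;
  return (bits & kSignMask) ? (kSignMask - (bits & ~kSignMask))  // negative
                            : (kSignMask + bits);                // positive or zero
}

static bool AlmostEquals(float a, float b, uint32_t max_ulps = 4) {
  if (std::isnan(a) || std::isnan(b)) return false;
  const uint32_t ba = BiasedBits(a), bb = BiasedBits(b);
  const uint32_t distance = (ba >= bb) ? (ba - bb) : (bb - ba);
  return distance <= max_ulps;
}

int main() {
  float sum = 0.0f;
  for (int i = 0; i < 10; ++i) sum += 0.1f;            // accumulates rounding error
  std::printf("sum == 1.0f  : %d\n", sum == 1.0f);     // typically 0
  std::printf("AlmostEquals : %d\n", AlmostEquals(sum, 1.0f));  // 1
  return 0;
}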
+AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual) { + if (String::CStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowCStringQuoted(expected), + String::ShowCStringQuoted(actual), + false); +} + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +AssertionResult CmpHelperSTRCASEEQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual) { + if (String::CaseInsensitiveCStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowCStringQuoted(expected), + String::ShowCStringQuoted(actual), + true); +} + +// The helper function for {ASSERT|EXPECT}_STRNE. +AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2) { + if (!String::CaseInsensitiveCStringEquals(s1, s2)) { + return AssertionSuccess(); + } else { + return AssertionFailure() + << "Expected: (" << s1_expression << ") != (" + << s2_expression << ") (ignoring case), actual: \"" + << s1 << "\" vs \"" << s2 << "\""; + } +} + +} // namespace internal + +namespace { + +// Helper functions for implementing IsSubString() and IsNotSubstring(). + +// This group of overloaded functions return true iff needle is a +// substring of haystack. NULL is considered a substring of itself +// only. + +bool IsSubstringPred(const char* needle, const char* haystack) { + if (needle == NULL || haystack == NULL) + return needle == haystack; + + return strstr(haystack, needle) != NULL; +} + +bool IsSubstringPred(const wchar_t* needle, const wchar_t* haystack) { + if (needle == NULL || haystack == NULL) + return needle == haystack; + + return wcsstr(haystack, needle) != NULL; +} + +// StringType here can be either ::std::string or ::std::wstring. +template +bool IsSubstringPred(const StringType& needle, + const StringType& haystack) { + return haystack.find(needle) != StringType::npos; +} + +// This function implements either IsSubstring() or IsNotSubstring(), +// depending on the value of the expected_to_be_substring parameter. +// StringType here can be const char*, const wchar_t*, ::std::string, +// or ::std::wstring. +template +AssertionResult IsSubstringImpl( + bool expected_to_be_substring, + const char* needle_expr, const char* haystack_expr, + const StringType& needle, const StringType& haystack) { + if (IsSubstringPred(needle, haystack) == expected_to_be_substring) + return AssertionSuccess(); + + const bool is_wide_string = sizeof(needle[0]) > 1; + const char* const begin_string_quote = is_wide_string ? "L\"" : "\""; + return AssertionFailure() + << "Value of: " << needle_expr << "\n" + << " Actual: " << begin_string_quote << needle << "\"\n" + << "Expected: " << (expected_to_be_substring ? 
"" : "not ") + << "a substring of " << haystack_expr << "\n" + << "Which is: " << begin_string_quote << haystack << "\""; +} + +} // namespace + +// IsSubstring() and IsNotSubstring() check whether needle is a +// substring of haystack (NULL is considered a substring of itself +// only), and return an appropriate error message when they fail. + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} + +#if GTEST_HAS_STD_WSTRING +AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(true, needle_expr, haystack_expr, needle, haystack); +} + +AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack) { + return IsSubstringImpl(false, needle_expr, haystack_expr, needle, haystack); +} +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +#if GTEST_OS_WINDOWS + +namespace { + +// Helper function for IsHRESULT{SuccessFailure} predicates +AssertionResult HRESULTFailureHelper(const char* expr, + const char* expected, + long hr) { // NOLINT +# if GTEST_OS_WINDOWS_MOBILE + + // Windows CE doesn't support FormatMessage. + const char error_text[] = ""; + +# else + + // Looks up the human-readable system message for the HRESULT code + // and since we're not passing any params to FormatMessage, we don't + // want inserts expanded. + const DWORD kFlags = FORMAT_MESSAGE_FROM_SYSTEM | + FORMAT_MESSAGE_IGNORE_INSERTS; + const DWORD kBufSize = 4096; // String::Format can't exceed this length. + // Gets the system's human readable message string for this HRESULT. 
+ char error_text[kBufSize] = { '\0' }; + DWORD message_length = ::FormatMessageA(kFlags, + 0, // no source, we're asking system + hr, // the error + 0, // no line width restrictions + error_text, // output buffer + kBufSize, // buf size + NULL); // no arguments for inserts + // Trims tailing white space (FormatMessage leaves a trailing cr-lf) + for (; message_length && IsSpace(error_text[message_length - 1]); + --message_length) { + error_text[message_length - 1] = '\0'; + } + +# endif // GTEST_OS_WINDOWS_MOBILE + + const String error_hex(String::Format("0x%08X ", hr)); + return ::testing::AssertionFailure() + << "Expected: " << expr << " " << expected << ".\n" + << " Actual: " << error_hex << error_text << "\n"; +} + +} // namespace + +AssertionResult IsHRESULTSuccess(const char* expr, long hr) { // NOLINT + if (SUCCEEDED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "succeeds", hr); +} + +AssertionResult IsHRESULTFailure(const char* expr, long hr) { // NOLINT + if (FAILED(hr)) { + return AssertionSuccess(); + } + return HRESULTFailureHelper(expr, "fails", hr); +} + +#endif // GTEST_OS_WINDOWS + +// Utility functions for encoding Unicode text (wide strings) in +// UTF-8. + +// A Unicode code-point can have upto 21 bits, and is encoded in UTF-8 +// like this: +// +// Code-point length Encoding +// 0 - 7 bits 0xxxxxxx +// 8 - 11 bits 110xxxxx 10xxxxxx +// 12 - 16 bits 1110xxxx 10xxxxxx 10xxxxxx +// 17 - 21 bits 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + +// The maximum code-point a one-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint1 = (static_cast(1) << 7) - 1; + +// The maximum code-point a two-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint2 = (static_cast(1) << (5 + 6)) - 1; + +// The maximum code-point a three-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint3 = (static_cast(1) << (4 + 2*6)) - 1; + +// The maximum code-point a four-byte UTF-8 sequence can represent. +const UInt32 kMaxCodePoint4 = (static_cast(1) << (3 + 3*6)) - 1; + +// Chops off the n lowest bits from a bit pattern. Returns the n +// lowest bits. As a side effect, the original bit pattern will be +// shifted to the right by n bits. +inline UInt32 ChopLowBits(UInt32* bits, int n) { + const UInt32 low_bits = *bits & ((static_cast(1) << n) - 1); + *bits >>= n; + return low_bits; +} + +// Converts a Unicode code point to a narrow string in UTF-8 encoding. +// code_point parameter is of type UInt32 because wchar_t may not be +// wide enough to contain a code point. +// The output buffer str must containt at least 32 characters. +// The function returns the address of the output buffer. +// If the code_point is not a valid Unicode code point +// (i.e. outside of Unicode range U+0 to U+10FFFF) it will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. 
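// A small standalone check of the UTF-8 capacity table above: a 1-byte
// sequence carries 7 payload bits, 2 bytes carry 5+6, 3 bytes carry 4+6+6 and
// 4 bytes carry 3+6+6+6, which is where the kMaxCodePoint constants come from.
#include <cstdio>

int main() {
  const unsigned int max1 = (1u << 7) - 1;            // 0x7F
  const unsigned int max2 = (1u << (5 + 6)) - 1;      // 0x7FF
  const unsigned int max3 = (1u << (4 + 2 * 6)) - 1;  // 0xFFFF
  const unsigned int max4 = (1u << (3 + 3 * 6)) - 1;  // 0x1FFFFF
  std::printf("1-byte max U+%04X\n", max1);
  std::printf("2-byte max U+%04X\n", max2);
  std::printf("3-byte max U+%04X\n", max3);
  std::printf("4-byte max U+%04X\n", max4);
  return 0;
}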
+char* CodePointToUtf8(UInt32 code_point, char* str) { + if (code_point <= kMaxCodePoint1) { + str[1] = '\0'; + str[0] = static_cast(code_point); // 0xxxxxxx + } else if (code_point <= kMaxCodePoint2) { + str[2] = '\0'; + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xC0 | code_point); // 110xxxxx + } else if (code_point <= kMaxCodePoint3) { + str[3] = '\0'; + str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xE0 | code_point); // 1110xxxx + } else if (code_point <= kMaxCodePoint4) { + str[4] = '\0'; + str[3] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[2] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[1] = static_cast(0x80 | ChopLowBits(&code_point, 6)); // 10xxxxxx + str[0] = static_cast(0xF0 | code_point); // 11110xxx + } else { + // The longest string String::Format can produce when invoked + // with these parameters is 28 character long (not including + // the terminating nul character). We are asking for 32 character + // buffer just in case. This is also enough for strncpy to + // null-terminate the destination string. + posix::StrNCpy( + str, String::Format("(Invalid Unicode 0x%X)", code_point).c_str(), 32); + str[31] = '\0'; // Makes sure no change in the format to strncpy leaves + // the result unterminated. + } + return str; +} + +// The following two functions only make sense if the the system +// uses UTF-16 for wide string encoding. All supported systems +// with 16 bit wchar_t (Windows, Cygwin, Symbian OS) do use UTF-16. + +// Determines if the arguments constitute UTF-16 surrogate pair +// and thus should be combined into a single Unicode code point +// using CreateCodePointFromUtf16SurrogatePair. +inline bool IsUtf16SurrogatePair(wchar_t first, wchar_t second) { + return sizeof(wchar_t) == 2 && + (first & 0xFC00) == 0xD800 && (second & 0xFC00) == 0xDC00; +} + +// Creates a Unicode code point from UTF16 surrogate pair. +inline UInt32 CreateCodePointFromUtf16SurrogatePair(wchar_t first, + wchar_t second) { + const UInt32 mask = (1 << 10) - 1; + return (sizeof(wchar_t) == 2) ? + (((first & mask) << 10) | (second & mask)) + 0x10000 : + // This function should not be called when the condition is + // false, but we provide a sensible default in case it is. + static_cast(first); +} + +// Converts a wide string to a narrow string in UTF-8 encoding. +// The wide string is assumed to have the following encoding: +// UTF-16 if sizeof(wchar_t) == 2 (on Windows, Cygwin, Symbian OS) +// UTF-32 if sizeof(wchar_t) == 4 (on Linux) +// Parameter str points to a null-terminated wide string. +// Parameter num_chars may additionally limit the number +// of wchar_t characters processed. -1 is used when the entire string +// should be processed. +// If the string contains code points that are not valid Unicode code points +// (i.e. outside of Unicode range U+0 to U+10FFFF) they will be output +// as '(Invalid Unicode 0xXXXXXXXX)'. If the string is in UTF16 encoding +// and contains invalid UTF-16 surrogate pairs, values in those pairs +// will be encoded as individual Unicode characters from Basic Normal Plane. 
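// A minimal standalone version of the surrogate-pair recombination done by
// CreateCodePointFromUtf16SurrogatePair above: a lead surrogate in
// [0xD800, 0xDBFF] supplies the high 10 bits, a trail surrogate in
// [0xDC00, 0xDFFF] the low 10 bits, and 0x10000 is added back. The function
// names and the U+1F600 example are illustrative.
#include <cstdio>

static bool IsSurrogatePair(unsigned int lead, unsigned int trail) {
  return (lead & 0xFC00u) == 0xD800u && (trail & 0xFC00u) == 0xDC00u;
}

static unsigned int CombineSurrogates(unsigned int lead, unsigned int trail) {
  const unsigned int mask = (1u << 10) - 1;  // low 10 bits of each code unit
  return (((lead & mask) << 10) | (trail & mask)) + 0x10000u;
}

int main() {
  const unsigned int lead = 0xD83D, trail = 0xDE00;  // UTF-16 for U+1F600
  if (IsSurrogatePair(lead, trail)) {
    std::printf("U+%X\n", CombineSurrogates(lead, trail));  // prints U+1F600
  }
  return 0;
}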
+String WideStringToUtf8(const wchar_t* str, int num_chars) { + if (num_chars == -1) + num_chars = static_cast(wcslen(str)); + + ::std::stringstream stream; + for (int i = 0; i < num_chars; ++i) { + UInt32 unicode_code_point; + + if (str[i] == L'\0') { + break; + } else if (i + 1 < num_chars && IsUtf16SurrogatePair(str[i], str[i + 1])) { + unicode_code_point = CreateCodePointFromUtf16SurrogatePair(str[i], + str[i + 1]); + i++; + } else { + unicode_code_point = static_cast(str[i]); + } + + char buffer[32]; // CodePointToUtf8 requires a buffer this big. + stream << CodePointToUtf8(unicode_code_point, buffer); + } + return StringStreamToString(&stream); +} + +// Converts a wide C string to a String using the UTF-8 encoding. +// NULL will be converted to "(null)". +String String::ShowWideCString(const wchar_t * wide_c_str) { + if (wide_c_str == NULL) return String("(null)"); + + return String(internal::WideStringToUtf8(wide_c_str, -1).c_str()); +} + +// Similar to ShowWideCString(), except that this function encloses +// the converted string in double quotes. +String String::ShowWideCStringQuoted(const wchar_t* wide_c_str) { + if (wide_c_str == NULL) return String("(null)"); + + return String::Format("L\"%s\"", + String::ShowWideCString(wide_c_str).c_str()); +} + +// Compares two wide C strings. Returns true iff they have the same +// content. +// +// Unlike wcscmp(), this function can handle NULL argument(s). A NULL +// C string is considered different to any non-NULL C string, +// including the empty string. +bool String::WideCStringEquals(const wchar_t * lhs, const wchar_t * rhs) { + if (lhs == NULL) return rhs == NULL; + + if (rhs == NULL) return false; + + return wcscmp(lhs, rhs) == 0; +} + +// Helper function for *_STREQ on wide strings. +AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const wchar_t* expected, + const wchar_t* actual) { + if (String::WideCStringEquals(expected, actual)) { + return AssertionSuccess(); + } + + return EqFailure(expected_expression, + actual_expression, + String::ShowWideCStringQuoted(expected), + String::ShowWideCStringQuoted(actual), + false); +} + +// Helper function for *_STRNE on wide strings. +AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2) { + if (!String::WideCStringEquals(s1, s2)) { + return AssertionSuccess(); + } + + return AssertionFailure() << "Expected: (" << s1_expression << ") != (" + << s2_expression << "), actual: " + << String::ShowWideCStringQuoted(s1) + << " vs " << String::ShowWideCStringQuoted(s2); +} + +// Compares two C strings, ignoring case. Returns true iff they have +// the same content. +// +// Unlike strcasecmp(), this function can handle NULL argument(s). A +// NULL C string is considered different to any non-NULL C string, +// including the empty string. +bool String::CaseInsensitiveCStringEquals(const char * lhs, const char * rhs) { + if (lhs == NULL) + return rhs == NULL; + if (rhs == NULL) + return false; + return posix::StrCaseCmp(lhs, rhs) == 0; +} + + // Compares two wide C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike wcscasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL wide C string, + // including the empty string. + // NB: The implementations on different platforms slightly differ. 
+ // On windows, this method uses _wcsicmp which compares according to LC_CTYPE + // environment variable. On GNU platform this method uses wcscasecmp + // which compares according to LC_CTYPE category of the current locale. + // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the + // current locale. +bool String::CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs) { + if (lhs == NULL) return rhs == NULL; + + if (rhs == NULL) return false; + +#if GTEST_OS_WINDOWS + return _wcsicmp(lhs, rhs) == 0; +#elif GTEST_OS_LINUX && !GTEST_OS_LINUX_ANDROID + return wcscasecmp(lhs, rhs) == 0; +#else + // Android, Mac OS X and Cygwin don't define wcscasecmp. + // Other unknown OSes may not define it either. + wint_t left, right; + do { + left = towlower(*lhs++); + right = towlower(*rhs++); + } while (left && left == right); + return left == right; +#endif // OS selector +} + +// Compares this with another String. +// Returns < 0 if this is less than rhs, 0 if this is equal to rhs, or > 0 +// if this is greater than rhs. +int String::Compare(const String & rhs) const { + const char* const lhs_c_str = c_str(); + const char* const rhs_c_str = rhs.c_str(); + + if (lhs_c_str == NULL) { + return rhs_c_str == NULL ? 0 : -1; // NULL < anything except NULL + } else if (rhs_c_str == NULL) { + return 1; + } + + const size_t shorter_str_len = + length() <= rhs.length() ? length() : rhs.length(); + for (size_t i = 0; i != shorter_str_len; i++) { + if (lhs_c_str[i] < rhs_c_str[i]) { + return -1; + } else if (lhs_c_str[i] > rhs_c_str[i]) { + return 1; + } + } + return (length() < rhs.length()) ? -1 : + (length() > rhs.length()) ? 1 : 0; +} + +// Returns true iff this String ends with the given suffix. *Any* +// String is considered to end with a NULL or empty suffix. +bool String::EndsWith(const char* suffix) const { + if (suffix == NULL || CStringEquals(suffix, "")) return true; + + if (c_str() == NULL) return false; + + const size_t this_len = strlen(c_str()); + const size_t suffix_len = strlen(suffix); + return (this_len >= suffix_len) && + CStringEquals(c_str() + this_len - suffix_len, suffix); +} + +// Returns true iff this String ends with the given suffix, ignoring case. +// Any String is considered to end with a NULL or empty suffix. +bool String::EndsWithCaseInsensitive(const char* suffix) const { + if (suffix == NULL || CStringEquals(suffix, "")) return true; + + if (c_str() == NULL) return false; + + const size_t this_len = strlen(c_str()); + const size_t suffix_len = strlen(suffix); + return (this_len >= suffix_len) && + CaseInsensitiveCStringEquals(c_str() + this_len - suffix_len, suffix); +} + +// Formats a list of arguments to a String, using the same format +// spec string as for printf. +// +// We do not use the StringPrintf class as it is not universally +// available. +// +// The result is limited to 4096 characters (including the tailing 0). +// If 4096 characters are not enough to format the input, or if +// there's an error, "" is +// returned. +String String::Format(const char * format, ...) { + va_list args; + va_start(args, format); + + char buffer[4096]; + const int kBufferSize = sizeof(buffer)/sizeof(buffer[0]); + + // MSVC 8 deprecates vsnprintf(), so we want to suppress warning + // 4996 (deprecated function) there. +#ifdef _MSC_VER // We are using MSVC. +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4996) // Temporarily disables warning 4996. 
+ + const int size = vsnprintf(buffer, kBufferSize, format, args); + +# pragma warning(pop) // Restores the warning state. +#else // We are not using MSVC. + const int size = vsnprintf(buffer, kBufferSize, format, args); +#endif // _MSC_VER + va_end(args); + + // vsnprintf()'s behavior is not portable. When the buffer is not + // big enough, it returns a negative value in MSVC, and returns the + // needed buffer size on Linux. When there is an output error, it + // always returns a negative value. For simplicity, we lump the two + // error cases together. + if (size < 0 || size >= kBufferSize) { + return String(""); + } else { + return String(buffer, size); + } +} + +// Converts the buffer in a stringstream to a String, converting NUL +// bytes to "\\0" along the way. +String StringStreamToString(::std::stringstream* ss) { + const ::std::string& str = ss->str(); + const char* const start = str.c_str(); + const char* const end = start + str.length(); + + // We need to use a helper stringstream to do this transformation + // because String doesn't support push_back(). + ::std::stringstream helper; + for (const char* ch = start; ch != end; ++ch) { + if (*ch == '\0') { + helper << "\\0"; // Replaces NUL with "\\0"; + } else { + helper.put(*ch); + } + } + + return String(helper.str().c_str()); +} + +// Appends the user-supplied message to the Google-Test-generated message. +String AppendUserMessage(const String& gtest_msg, + const Message& user_msg) { + // Appends the user message if it's non-empty. + const String user_msg_string = user_msg.GetString(); + if (user_msg_string.empty()) { + return gtest_msg; + } + + Message msg; + msg << gtest_msg << "\n" << user_msg_string; + + return msg.GetString(); +} + +} // namespace internal + +// class TestResult + +// Creates an empty TestResult. +TestResult::TestResult() + : death_test_count_(0), + elapsed_time_(0) { +} + +// D'tor. +TestResult::~TestResult() { +} + +// Returns the i-th test part result among all the results. i can +// range from 0 to total_part_count() - 1. If i is not in that range, +// aborts the program. +const TestPartResult& TestResult::GetTestPartResult(int i) const { + if (i < 0 || i >= total_part_count()) + internal::posix::Abort(); + return test_part_results_.at(i); +} + +// Returns the i-th test property. i can range from 0 to +// test_property_count() - 1. If i is not in that range, aborts the +// program. +const TestProperty& TestResult::GetTestProperty(int i) const { + if (i < 0 || i >= test_property_count()) + internal::posix::Abort(); + return test_properties_.at(i); +} + +// Clears the test part results. +void TestResult::ClearTestPartResults() { + test_part_results_.clear(); +} + +// Adds a test part result to the list. +void TestResult::AddTestPartResult(const TestPartResult& test_part_result) { + test_part_results_.push_back(test_part_result); +} + +// Adds a test property to the list. If a property with the same key as the +// supplied property is already represented, the value of this test_property +// replaces the old value for that key. 
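Back-referencing String::Format above: formatting goes through a fixed 4096-byte stack buffer, and both vsnprintf errors and overflow collapse into a single error result rather than a truncated string. A standalone sketch of the same pattern (the function name and placeholder text are this sketch's own, not gtest API):

#include <cstdarg>
#include <cstdio>
#include <string>

// Fixed-buffer printf-style formatting; overflow and encoding errors are
// lumped together into one failure result, as in String::Format above.
std::string FormatCapped(const char* format, ...) {
  char buffer[4096];
  va_list args;
  va_start(args, format);
  const int size = std::vsnprintf(buffer, sizeof(buffer), format, args);
  va_end(args);
  if (size < 0 || size >= static_cast<int>(sizeof(buffer))) {
    return "<formatting error or buffer exceeded>";  // no partial output
  }
  return std::string(buffer, size);
}

int main() {
  std::printf("%s\n", FormatCapped("%d %s", 42, "tests").c_str());  // "42 tests"
  // A 5000-character argument exceeds the 4096-byte cap, so the whole
  // result is replaced by the error placeholder rather than truncated.
  const std::string big(5000, 'x');
  std::printf("%s\n", FormatCapped("%s", big.c_str()).c_str());
  return 0;
}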
+void TestResult::RecordProperty(const TestProperty& test_property) {
+  if (!ValidateTestProperty(test_property)) {
+    return;
+  }
+  internal::MutexLock lock(&test_properites_mutex_);
+  const std::vector<TestProperty>::iterator property_with_matching_key =
+      std::find_if(test_properties_.begin(), test_properties_.end(),
+                   internal::TestPropertyKeyIs(test_property.key()));
+  if (property_with_matching_key == test_properties_.end()) {
+    test_properties_.push_back(test_property);
+    return;
+  }
+  property_with_matching_key->SetValue(test_property.value());
+}
+
+// Adds a failure if the key is a reserved attribute of Google Test
+// testcase tags. Returns true if the property is valid.
+bool TestResult::ValidateTestProperty(const TestProperty& test_property) {
+  internal::String key(test_property.key());
+  if (key == "name" || key == "status" || key == "time" || key == "classname") {
+    ADD_FAILURE()
+        << "Reserved key used in RecordProperty(): "
+        << key
+        << " ('name', 'status', 'time', and 'classname' are reserved by "
+        << GTEST_NAME_ << ")";
+    return false;
+  }
+  return true;
+}
+
+// Clears the object.
+void TestResult::Clear() {
+  test_part_results_.clear();
+  test_properties_.clear();
+  death_test_count_ = 0;
+  elapsed_time_ = 0;
+}
+
+// Returns true iff the test failed.
+bool TestResult::Failed() const {
+  for (int i = 0; i < total_part_count(); ++i) {
+    if (GetTestPartResult(i).failed())
+      return true;
+  }
+  return false;
+}
+
+// Returns true iff the test part fatally failed.
+static bool TestPartFatallyFailed(const TestPartResult& result) {
+  return result.fatally_failed();
+}
+
+// Returns true iff the test fatally failed.
+bool TestResult::HasFatalFailure() const {
+  return CountIf(test_part_results_, TestPartFatallyFailed) > 0;
+}
+
+// Returns true iff the test part non-fatally failed.
+static bool TestPartNonfatallyFailed(const TestPartResult& result) {
+  return result.nonfatally_failed();
+}
+
+// Returns true iff the test has a non-fatal failure.
+bool TestResult::HasNonfatalFailure() const {
+  return CountIf(test_part_results_, TestPartNonfatallyFailed) > 0;
+}
+
+// Gets the number of all test parts. This is the sum of the number
+// of successful test parts and the number of failed test parts.
+int TestResult::total_part_count() const {
+  return static_cast<int>(test_part_results_.size());
+}
+
+// Returns the number of the test properties.
+int TestResult::test_property_count() const {
+  return static_cast<int>(test_properties_.size());
+}
+
+// class Test
+
+// Creates a Test object.
+
+// The c'tor saves the values of all Google Test flags.
+Test::Test()
+    : gtest_flag_saver_(new internal::GTestFlagSaver) {
+}
+
+// The d'tor restores the values of all Google Test flags.
+Test::~Test() {
+  delete gtest_flag_saver_;
+}
+
+// Sets up the test fixture.
+//
+// A sub-class may override this.
+void Test::SetUp() {
+}
+
+// Tears down the test fixture.
+//
+// A sub-class may override this.
+void Test::TearDown() {
+}
+
+// Allows user supplied key value pairs to be recorded for later output.
+void Test::RecordProperty(const char* key, const char* value) {
+  UnitTest::GetInstance()->RecordPropertyForCurrentTest(key, value);
+}
+
+// Allows user supplied key value pairs to be recorded for later output.
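For context on the property machinery above, this is how it is normally driven from user code: RecordProperty inside a test body routes through UnitTest::RecordPropertyForCurrentTest to the current TestResult, and the pair is later emitted as an extra attribute on that test's XML element. The test and key names below are invented; the include path may differ in this tree.

#include "gtest/gtest.h"

TEST(WidgetTest, CountsWidgets) {
  const int widgets_created = 3;  // stand-in for a real measurement
  // Recorded on this test's TestResult and written to the XML report as
  // widgets="3". The keys "name", "status", "time" and "classname" are
  // reserved and would trigger the ADD_FAILURE() in ValidateTestProperty.
  RecordProperty("widgets", widgets_created);
  EXPECT_EQ(3, widgets_created);
}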
+void Test::RecordProperty(const char* key, int value) { + Message value_message; + value_message << value; + RecordProperty(key, value_message.GetString().c_str()); +} + +namespace internal { + +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const String& message) { + // This function is a friend of UnitTest and as such has access to + // AddTestPartResult. + UnitTest::GetInstance()->AddTestPartResult( + result_type, + NULL, // No info about the source file where the exception occurred. + -1, // We have no info on which line caused the exception. + message, + String()); // No stack trace, either. +} + +} // namespace internal + +// Google Test requires all tests in the same test case to use the same test +// fixture class. This function checks if the current test has the +// same fixture class as the first test in the current test case. If +// yes, it returns true; otherwise it generates a Google Test failure and +// returns false. +bool Test::HasSameFixtureClass() { + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + const TestCase* const test_case = impl->current_test_case(); + + // Info about the first test in the current test case. + const TestInfo* const first_test_info = test_case->test_info_list()[0]; + const internal::TypeId first_fixture_id = first_test_info->fixture_class_id_; + const char* const first_test_name = first_test_info->name(); + + // Info about the current test. + const TestInfo* const this_test_info = impl->current_test_info(); + const internal::TypeId this_fixture_id = this_test_info->fixture_class_id_; + const char* const this_test_name = this_test_info->name(); + + if (this_fixture_id != first_fixture_id) { + // Is the first test defined using TEST? + const bool first_is_TEST = first_fixture_id == internal::GetTestTypeId(); + // Is this test defined using TEST? + const bool this_is_TEST = this_fixture_id == internal::GetTestTypeId(); + + if (first_is_TEST || this_is_TEST) { + // The user mixed TEST and TEST_F in this test case - we'll tell + // him/her how to fix it. + + // Gets the name of the TEST and the name of the TEST_F. Note + // that first_is_TEST and this_is_TEST cannot both be true, as + // the fixture IDs are different for the two tests. + const char* const TEST_name = + first_is_TEST ? first_test_name : this_test_name; + const char* const TEST_F_name = + first_is_TEST ? this_test_name : first_test_name; + + ADD_FAILURE() + << "All tests in the same test case must use the same test fixture\n" + << "class, so mixing TEST_F and TEST in the same test case is\n" + << "illegal. In test case " << this_test_info->test_case_name() + << ",\n" + << "test " << TEST_F_name << " is defined using TEST_F but\n" + << "test " << TEST_name << " is defined using TEST. You probably\n" + << "want to change the TEST to TEST_F or move it to another test\n" + << "case."; + } else { + // The user defined two fixture classes with the same name in + // two namespaces - we'll tell him/her how to fix it. + ADD_FAILURE() + << "All tests in the same test case must use the same test fixture\n" + << "class. However, in test case " + << this_test_info->test_case_name() << ",\n" + << "you defined test " << first_test_name + << " and test " << this_test_name << "\n" + << "using two different test fixture classes. This can happen if\n" + << "the two classes are from different namespaces or translation\n" + << "units and have the same name. 
You should probably rename one\n" + << "of the classes to put the tests into different test cases."; + } + return false; + } + + return true; +} + +#if GTEST_HAS_SEH + +// Adds an "exception thrown" fatal failure to the current test. This +// function returns its result via an output parameter pointer because VC++ +// prohibits creation of objects with destructors on stack in functions +// using __try (see error C2712). +static internal::String* FormatSehExceptionMessage(DWORD exception_code, + const char* location) { + Message message; + message << "SEH exception with code 0x" << std::setbase(16) << + exception_code << std::setbase(10) << " thrown in " << location << "."; + + return new internal::String(message.GetString()); +} + +#endif // GTEST_HAS_SEH + +#if GTEST_HAS_EXCEPTIONS + +// Adds an "exception thrown" fatal failure to the current test. +static internal::String FormatCxxExceptionMessage(const char* description, + const char* location) { + Message message; + if (description != NULL) { + message << "C++ exception with description \"" << description << "\""; + } else { + message << "Unknown C++ exception"; + } + message << " thrown in " << location << "."; + + return message.GetString(); +} + +static internal::String PrintTestPartResultToString( + const TestPartResult& test_part_result); + +// A failed Google Test assertion will throw an exception of this type when +// GTEST_FLAG(throw_on_failure) is true (if exceptions are enabled). We +// derive it from std::runtime_error, which is for errors presumably +// detectable only at run time. Since std::runtime_error inherits from +// std::exception, many testing frameworks know how to extract and print the +// message inside it. +class GoogleTestFailureException : public ::std::runtime_error { + public: + explicit GoogleTestFailureException(const TestPartResult& failure) + : ::std::runtime_error(PrintTestPartResultToString(failure).c_str()) {} +}; +#endif // GTEST_HAS_EXCEPTIONS + +namespace internal { +// We put these helper functions in the internal namespace as IBM's xlC +// compiler rejects the code if they were declared static. + +// Runs the given method and handles SEH exceptions it throws, when +// SEH is supported; returns the 0-value for type Result in case of an +// SEH exception. (Microsoft compilers cannot handle SEH and C++ +// exceptions in the same function. Therefore, we provide a separate +// wrapper function for handling SEH exceptions.) +template +Result HandleSehExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { +#if GTEST_HAS_SEH + __try { + return (object->*method)(); + } __except (internal::UnitTestOptions::GTestShouldProcessSEH( // NOLINT + GetExceptionCode())) { + // We create the exception message on the heap because VC++ prohibits + // creation of objects with destructors on stack in functions using __try + // (see error C2712). + internal::String* exception_message = FormatSehExceptionMessage( + GetExceptionCode(), location); + internal::ReportFailureInUnknownLocation(TestPartResult::kFatalFailure, + *exception_message); + delete exception_message; + return static_cast(0); + } +#else + (void)location; + return (object->*method)(); +#endif // GTEST_HAS_SEH +} + +// Runs the given method and catches and reports C++ and/or SEH-style +// exceptions, if they are supported; returns the 0-value for type +// Result in case of an SEH exception. 
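The long failure message in HasSameFixtureClass above is easier to read next to a concrete offender. In the sketch below (names invented), both tests claim the test case name FooTest, but one uses the fixture class and the other the implicit ::testing::Test, which is exactly the TEST/TEST_F mix the check rejects.

#include "gtest/gtest.h"

class FooTest : public ::testing::Test {
 protected:
  virtual void SetUp() { shared_value_ = 0; }
  int shared_value_;
};

// Registered under test case "FooTest" with fixture class FooTest.
TEST_F(FooTest, UsesFixture) { EXPECT_EQ(0, shared_value_); }

// Same test case name, but the implicit ::testing::Test fixture; running
// both produces the "mixing TEST_F and TEST ... is illegal" failure above.
// The fix: rename one test case, or turn this TEST into a TEST_F.
TEST(FooTest, DoesNotUseFixture) { EXPECT_TRUE(true); }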
+template +Result HandleExceptionsInMethodIfSupported( + T* object, Result (T::*method)(), const char* location) { + // NOTE: The user code can affect the way in which Google Test handles + // exceptions by setting GTEST_FLAG(catch_exceptions), but only before + // RUN_ALL_TESTS() starts. It is technically possible to check the flag + // after the exception is caught and either report or re-throw the + // exception based on the flag's value: + // + // try { + // // Perform the test method. + // } catch (...) { + // if (GTEST_FLAG(catch_exceptions)) + // // Report the exception as failure. + // else + // throw; // Re-throws the original exception. + // } + // + // However, the purpose of this flag is to allow the program to drop into + // the debugger when the exception is thrown. On most platforms, once the + // control enters the catch block, the exception origin information is + // lost and the debugger will stop the program at the point of the + // re-throw in this function -- instead of at the point of the original + // throw statement in the code under test. For this reason, we perform + // the check early, sacrificing the ability to affect Google Test's + // exception handling in the method where the exception is thrown. + if (internal::GetUnitTestImpl()->catch_exceptions()) { +#if GTEST_HAS_EXCEPTIONS + try { + return HandleSehExceptionsInMethodIfSupported(object, method, location); + } catch (const GoogleTestFailureException&) { // NOLINT + // This exception doesn't originate in code under test. It makes no + // sense to report it as a test failure. + throw; + } catch (const std::exception& e) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(e.what(), location)); + } catch (...) { // NOLINT + internal::ReportFailureInUnknownLocation( + TestPartResult::kFatalFailure, + FormatCxxExceptionMessage(NULL, location)); + } + return static_cast(0); +#else + return HandleSehExceptionsInMethodIfSupported(object, method, location); +#endif // GTEST_HAS_EXCEPTIONS + } else { + return (object->*method)(); + } +} + +} // namespace internal + +// Runs the test and updates the test result. +void Test::Run() { + if (!HasSameFixtureClass()) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported(this, &Test::SetUp, "SetUp()"); + // We will run the test only if SetUp() was successful. + if (!HasFatalFailure()) { + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TestBody, "the test body"); + } + + // However, we want to clean up as much as possible. Hence we will + // always call TearDown(), even if SetUp() or the test body has + // failed. + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &Test::TearDown, "TearDown()"); +} + +// Returns true iff the current test has a fatal failure. +bool Test::HasFatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()->HasFatalFailure(); +} + +// Returns true iff the current test has a non-fatal failure. +bool Test::HasNonfatalFailure() { + return internal::GetUnitTestImpl()->current_test_result()-> + HasNonfatalFailure(); +} + +// class TestInfo + +// Constructs a TestInfo object. It assumes ownership of the test factory +// object. 
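To see the wrappers above from the outside: with exceptions enabled and --gtest_catch_exceptions at its default, a throw escaping a test body is converted by HandleExceptionsInMethodIfSupported into a fatal failure whose text comes from FormatCxxExceptionMessage; with --gtest_catch_exceptions=0 it propagates so a debugger can stop at the original throw. A minimal illustration (test name and message invented):

#include <stdexcept>
#include "gtest/gtest.h"

TEST(ExceptionDemo, ThrowingBodyBecomesFatalFailure) {
  // Reported roughly as:
  //   C++ exception with description "widget store is empty" thrown in
  //   the test body.
  // and the remaining tests keep running instead of the program aborting.
  throw std::runtime_error("widget store is empty");
}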
+// TODO(vladl@google.com): Make a_test_case_name and a_name const string&'s +// to signify they cannot be NULLs. +TestInfo::TestInfo(const char* a_test_case_name, + const char* a_name, + const char* a_type_param, + const char* a_value_param, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory) + : test_case_name_(a_test_case_name), + name_(a_name), + type_param_(a_type_param ? new std::string(a_type_param) : NULL), + value_param_(a_value_param ? new std::string(a_value_param) : NULL), + fixture_class_id_(fixture_class_id), + should_run_(false), + is_disabled_(false), + matches_filter_(false), + factory_(factory), + result_() {} + +// Destructs a TestInfo object. +TestInfo::~TestInfo() { delete factory_; } + +namespace internal { + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_case_name: name of the test case +// name: name of the test +// type_param: the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param: text representation of the test's value parameter, +// or NULL if this is not a value-parameterized test. +// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object. +TestInfo* MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + TypeId fixture_class_id, + SetUpTestCaseFunc set_up_tc, + TearDownTestCaseFunc tear_down_tc, + TestFactoryBase* factory) { + TestInfo* const test_info = + new TestInfo(test_case_name, name, type_param, value_param, + fixture_class_id, factory); + GetUnitTestImpl()->AddTestInfo(set_up_tc, tear_down_tc, test_info); + return test_info; +} + +#if GTEST_HAS_PARAM_TEST +void ReportInvalidTestCaseType(const char* test_case_name, + const char* file, int line) { + Message errors; + errors + << "Attempted redefinition of test case " << test_case_name << ".\n" + << "All tests in the same test case must use the same test fixture\n" + << "class. However, in test case " << test_case_name << ", you tried\n" + << "to define a test using a fixture class different from the one\n" + << "used earlier. This can happen if the two fixture classes are\n" + << "from different namespaces and have the same name. You should\n" + << "probably rename one of the classes to put the tests into different\n" + << "test cases."; + + fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), + errors.GetString().c_str()); +} +#endif // GTEST_HAS_PARAM_TEST + +} // namespace internal + +namespace { + +// A predicate that checks the test name of a TestInfo against a known +// value. +// +// This is used for implementation of the TestCase class only. We put +// it in the anonymous namespace to prevent polluting the outer +// namespace. +// +// TestNameIs is copyable. +class TestNameIs { + public: + // Constructor. + // + // TestNameIs has NO default constructor. + explicit TestNameIs(const char* name) + : name_(name) {} + + // Returns true iff the test name of test_info matches name_. 
+ bool operator()(const TestInfo * test_info) const { + return test_info && internal::String(test_info->name()).Compare(name_) == 0; + } + + private: + internal::String name_; +}; + +} // namespace + +namespace internal { + +// This method expands all parameterized tests registered with macros TEST_P +// and INSTANTIATE_TEST_CASE_P into regular tests and registers those. +// This will be done just once during the program runtime. +void UnitTestImpl::RegisterParameterizedTests() { +#if GTEST_HAS_PARAM_TEST + if (!parameterized_tests_registered_) { + parameterized_test_registry_.RegisterTests(); + parameterized_tests_registered_ = true; + } +#endif +} + +} // namespace internal + +// Creates the test object, runs it, records its result, and then +// deletes it. +void TestInfo::Run() { + if (!should_run_) return; + + // Tells UnitTest where to store test result. + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_info(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + // Notifies the unit test event listeners that a test is about to start. + repeater->OnTestStart(*this); + + const TimeInMillis start = internal::GetTimeInMillis(); + + impl->os_stack_trace_getter()->UponLeavingGTest(); + + // Creates the test object. + Test* const test = internal::HandleExceptionsInMethodIfSupported( + factory_, &internal::TestFactoryBase::CreateTest, + "the test fixture's constructor"); + + // Runs the test only if the test object was created and its + // constructor didn't generate a fatal failure. + if ((test != NULL) && !Test::HasFatalFailure()) { + // This doesn't throw as all user code that can throw are wrapped into + // exception handling code. + test->Run(); + } + + // Deletes the test object. + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + test, &Test::DeleteSelf_, "the test fixture's destructor"); + + result_.set_elapsed_time(internal::GetTimeInMillis() - start); + + // Notifies the unit test event listener that a test has just finished. + repeater->OnTestEnd(*this); + + // Tells UnitTest to stop associating assertion results to this + // test. + impl->set_current_test_info(NULL); +} + +// class TestCase + +// Gets the number of successful tests in this test case. +int TestCase::successful_test_count() const { + return CountIf(test_info_list_, TestPassed); +} + +// Gets the number of failed tests in this test case. +int TestCase::failed_test_count() const { + return CountIf(test_info_list_, TestFailed); +} + +int TestCase::disabled_test_count() const { + return CountIf(test_info_list_, TestDisabled); +} + +// Get the number of tests in this test case that should run. +int TestCase::test_to_run_count() const { + return CountIf(test_info_list_, ShouldRunTest); +} + +// Gets the number of all tests. +int TestCase::total_test_count() const { + return static_cast(test_info_list_.size()); +} + +// Creates a TestCase with the given name. +// +// Arguments: +// +// name: name of the test case +// a_type_param: the name of the test case's type parameter, or NULL if +// this is not a typed or a type-parameterized test case. +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +TestCase::TestCase(const char* a_name, const char* a_type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc) + : name_(a_name), + type_param_(a_type_param ? 
new std::string(a_type_param) : NULL), + set_up_tc_(set_up_tc), + tear_down_tc_(tear_down_tc), + should_run_(false), + elapsed_time_(0) { +} + +// Destructor of TestCase. +TestCase::~TestCase() { + // Deletes every Test in the collection. + ForEach(test_info_list_, internal::Delete); +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +const TestInfo* TestCase::GetTestInfo(int i) const { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? NULL : test_info_list_[index]; +} + +// Returns the i-th test among all the tests. i can range from 0 to +// total_test_count() - 1. If i is not in that range, returns NULL. +TestInfo* TestCase::GetMutableTestInfo(int i) { + const int index = GetElementOr(test_indices_, i, -1); + return index < 0 ? NULL : test_info_list_[index]; +} + +// Adds a test to this test case. Will delete the test upon +// destruction of the TestCase object. +void TestCase::AddTestInfo(TestInfo * test_info) { + test_info_list_.push_back(test_info); + test_indices_.push_back(static_cast(test_indices_.size())); +} + +// Runs every test in this TestCase. +void TestCase::Run() { + if (!should_run_) return; + + internal::UnitTestImpl* const impl = internal::GetUnitTestImpl(); + impl->set_current_test_case(this); + + TestEventListener* repeater = UnitTest::GetInstance()->listeners().repeater(); + + repeater->OnTestCaseStart(*this); + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestCase::RunSetUpTestCase, "SetUpTestCase()"); + + const internal::TimeInMillis start = internal::GetTimeInMillis(); + for (int i = 0; i < total_test_count(); i++) { + GetMutableTestInfo(i)->Run(); + } + elapsed_time_ = internal::GetTimeInMillis() - start; + + impl->os_stack_trace_getter()->UponLeavingGTest(); + internal::HandleExceptionsInMethodIfSupported( + this, &TestCase::RunTearDownTestCase, "TearDownTestCase()"); + + repeater->OnTestCaseEnd(*this); + impl->set_current_test_case(NULL); +} + +// Clears the results of all tests in this test case. +void TestCase::ClearResult() { + ForEach(test_info_list_, TestInfo::ClearTestResult); +} + +// Shuffles the tests in this test case. +void TestCase::ShuffleTests(internal::Random* random) { + Shuffle(random, &test_indices_); +} + +// Restores the test order to before the first shuffle. +void TestCase::UnshuffleTests() { + for (size_t i = 0; i < test_indices_.size(); i++) { + test_indices_[i] = static_cast(i); + } +} + +// Formats a countable noun. Depending on its quantity, either the +// singular form or the plural form is used. e.g. +// +// FormatCountableNoun(1, "formula", "formuli") returns "1 formula". +// FormatCountableNoun(5, "book", "books") returns "5 books". +static internal::String FormatCountableNoun(int count, + const char * singular_form, + const char * plural_form) { + return internal::String::Format("%d %s", count, + count == 1 ? singular_form : plural_form); +} + +// Formats the count of tests. +static internal::String FormatTestCount(int test_count) { + return FormatCountableNoun(test_count, "test", "tests"); +} + +// Formats the count of test cases. +static internal::String FormatTestCaseCount(int test_case_count) { + return FormatCountableNoun(test_case_count, "test case", "test cases"); +} + +// Converts a TestPartResult::Type enum to human-friendly string +// representation. 
Both kNonFatalFailure and kFatalFailure are translated +// to "Failure", as the user usually doesn't care about the difference +// between the two when viewing the test result. +static const char * TestPartResultTypeToString(TestPartResult::Type type) { + switch (type) { + case TestPartResult::kSuccess: + return "Success"; + + case TestPartResult::kNonFatalFailure: + case TestPartResult::kFatalFailure: +#ifdef _MSC_VER + return "error: "; +#else + return "Failure\n"; +#endif + default: + return "Unknown result type"; + } +} + +// Prints a TestPartResult to a String. +static internal::String PrintTestPartResultToString( + const TestPartResult& test_part_result) { + return (Message() + << internal::FormatFileLocation(test_part_result.file_name(), + test_part_result.line_number()) + << " " << TestPartResultTypeToString(test_part_result.type()) + << test_part_result.message()).GetString(); +} + +// Prints a TestPartResult. +static void PrintTestPartResult(const TestPartResult& test_part_result) { + const internal::String& result = + PrintTestPartResultToString(test_part_result); + printf("%s\n", result.c_str()); + fflush(stdout); + // If the test program runs in Visual Studio or a debugger, the + // following statements add the test part result message to the Output + // window such that the user can double-click on it to jump to the + // corresponding source code location; otherwise they do nothing. +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + // We don't call OutputDebugString*() on Windows Mobile, as printing + // to stdout is done by OutputDebugString() there already - we don't + // want the same message printed twice. + ::OutputDebugStringA(result.c_str()); + ::OutputDebugStringA("\n"); +#endif +} + +// class PrettyUnitTestResultPrinter + +namespace internal { + +enum GTestColor { + COLOR_DEFAULT, + COLOR_RED, + COLOR_GREEN, + COLOR_YELLOW +}; + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + +// Returns the character attribute for the given color. +WORD GetColorAttribute(GTestColor color) { + switch (color) { + case COLOR_RED: return FOREGROUND_RED; + case COLOR_GREEN: return FOREGROUND_GREEN; + case COLOR_YELLOW: return FOREGROUND_RED | FOREGROUND_GREEN; + default: return 0; + } +} + +#else + +// Returns the ANSI color code for the given color. COLOR_DEFAULT is +// an invalid input. +const char* GetAnsiColorCode(GTestColor color) { + switch (color) { + case COLOR_RED: return "1"; + case COLOR_GREEN: return "2"; + case COLOR_YELLOW: return "3"; + default: return NULL; + }; +} + +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + +// Returns true iff Google Test should use colors in the output. +bool ShouldUseColor(bool stdout_is_tty) { + const char* const gtest_color = GTEST_FLAG(color).c_str(); + + if (String::CaseInsensitiveCStringEquals(gtest_color, "auto")) { +#if GTEST_OS_WINDOWS + // On Windows the TERM variable is usually not set, but the + // console there does support colors. + return stdout_is_tty; +#else + // On non-Windows platforms, we rely on the TERM variable. 
+ const char* const term = posix::GetEnv("TERM"); + const bool term_supports_color = + String::CStringEquals(term, "xterm") || + String::CStringEquals(term, "xterm-color") || + String::CStringEquals(term, "xterm-256color") || + String::CStringEquals(term, "screen") || + String::CStringEquals(term, "linux") || + String::CStringEquals(term, "cygwin"); + return stdout_is_tty && term_supports_color; +#endif // GTEST_OS_WINDOWS + } + + return String::CaseInsensitiveCStringEquals(gtest_color, "yes") || + String::CaseInsensitiveCStringEquals(gtest_color, "true") || + String::CaseInsensitiveCStringEquals(gtest_color, "t") || + String::CStringEquals(gtest_color, "1"); + // We take "yes", "true", "t", and "1" as meaning "yes". If the + // value is neither one of these nor "auto", we treat it as "no" to + // be conservative. +} + +// Helpers for printing colored strings to stdout. Note that on Windows, we +// cannot simply emit special characters and have the terminal change colors. +// This routine must actually emit the characters rather than return a string +// that would be colored when printed, as can be done on Linux. +void ColoredPrintf(GTestColor color, const char* fmt, ...) { + va_list args; + va_start(args, fmt); + +#if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN || GTEST_OS_ZOS + const bool use_color = false; +#else + static const bool in_color_mode = + ShouldUseColor(posix::IsATTY(posix::FileNo(stdout)) != 0); + const bool use_color = in_color_mode && (color != COLOR_DEFAULT); +#endif // GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN || GTEST_OS_ZOS + // The '!= 0' comparison is necessary to satisfy MSVC 7.1. + + if (!use_color) { + vprintf(fmt, args); + va_end(args); + return; + } + +#if GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + const HANDLE stdout_handle = GetStdHandle(STD_OUTPUT_HANDLE); + + // Gets the current text color. + CONSOLE_SCREEN_BUFFER_INFO buffer_info; + GetConsoleScreenBufferInfo(stdout_handle, &buffer_info); + const WORD old_color_attrs = buffer_info.wAttributes; + + // We need to flush the stream buffers into the console before each + // SetConsoleTextAttribute call lest it affect the text that is already + // printed but has not yet reached the console. + fflush(stdout); + SetConsoleTextAttribute(stdout_handle, + GetColorAttribute(color) | FOREGROUND_INTENSITY); + vprintf(fmt, args); + + fflush(stdout); + // Restores the text color. + SetConsoleTextAttribute(stdout_handle, old_color_attrs); +#else + printf("\033[0;3%sm", GetAnsiColorCode(color)); + vprintf(fmt, args); + printf("\033[m"); // Resets the terminal to default. +#endif // GTEST_OS_WINDOWS && !GTEST_OS_WINDOWS_MOBILE + va_end(args); +} + +void PrintFullTestCommentIfPresent(const TestInfo& test_info) { + const char* const type_param = test_info.type_param(); + const char* const value_param = test_info.value_param(); + + if (type_param != NULL || value_param != NULL) { + printf(", where "); + if (type_param != NULL) { + printf("TypeParam = %s", type_param); + if (value_param != NULL) + printf(" and "); + } + if (value_param != NULL) { + printf("GetParam() = %s", value_param); + } + } +} + +// This class implements the TestEventListener interface. +// +// Class PrettyUnitTestResultPrinter is copyable. +class PrettyUnitTestResultPrinter : public TestEventListener { + public: + PrettyUnitTestResultPrinter() {} + static void PrintTestName(const char * test_case, const char * test) { + printf("%s.%s", test_case, test); + } + + // The following methods override what's in the TestEventListener class. 
+ virtual void OnTestProgramStart(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationStart(const UnitTest& unit_test, int iteration); + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test); + virtual void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestCaseStart(const TestCase& test_case); + virtual void OnTestStart(const TestInfo& test_info); + virtual void OnTestPartResult(const TestPartResult& result); + virtual void OnTestEnd(const TestInfo& test_info); + virtual void OnTestCaseEnd(const TestCase& test_case); + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test); + virtual void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + virtual void OnTestProgramEnd(const UnitTest& /*unit_test*/) {} + + private: + static void PrintFailedTests(const UnitTest& unit_test); + + internal::String test_case_name_; +}; + + // Fired before each iteration of tests starts. +void PrettyUnitTestResultPrinter::OnTestIterationStart( + const UnitTest& unit_test, int iteration) { + if (GTEST_FLAG(repeat) != 1) + printf("\nRepeating all tests (iteration %d) . . .\n\n", iteration + 1); + + const char* const filter = GTEST_FLAG(filter).c_str(); + + // Prints the filter if it's not *. This reminds the user that some + // tests may be skipped. + if (!internal::String::CStringEquals(filter, kUniversalFilter)) { + ColoredPrintf(COLOR_YELLOW, + "Note: %s filter = %s\n", GTEST_NAME_, filter); + } + + if (internal::ShouldShard(kTestTotalShards, kTestShardIndex, false)) { + const Int32 shard_index = Int32FromEnvOrDie(kTestShardIndex, -1); + ColoredPrintf(COLOR_YELLOW, + "Note: This is test shard %d of %s.\n", + static_cast(shard_index) + 1, + internal::posix::GetEnv(kTestTotalShards)); + } + + if (GTEST_FLAG(shuffle)) { + ColoredPrintf(COLOR_YELLOW, + "Note: Randomizing tests' orders with a seed of %d .\n", + unit_test.random_seed()); + } + + ColoredPrintf(COLOR_GREEN, "[==========] "); + printf("Running %s from %s.\n", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestCaseCount(unit_test.test_case_to_run_count()).c_str()); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnEnvironmentsSetUpStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("Global test environment set-up.\n"); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestCaseStart(const TestCase& test_case) { + test_case_name_ = test_case.name(); + const internal::String counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("%s from %s", counts.c_str(), test_case_name_.c_str()); + if (test_case.type_param() == NULL) { + printf("\n"); + } else { + printf(", where TypeParam = %s\n", test_case.type_param()); + } + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestStart(const TestInfo& test_info) { + ColoredPrintf(COLOR_GREEN, "[ RUN ] "); + PrintTestName(test_case_name_.c_str(), test_info.name()); + printf("\n"); + fflush(stdout); +} + +// Called after an assertion failure. +void PrettyUnitTestResultPrinter::OnTestPartResult( + const TestPartResult& result) { + // If the test part succeeded, we don't need to do anything. + if (result.type() == TestPartResult::kSuccess) + return; + + // Print failure message from the assertion (e.g. expected this and got that). 
+ PrintTestPartResult(result); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestEnd(const TestInfo& test_info) { + if (test_info.result()->Passed()) { + ColoredPrintf(COLOR_GREEN, "[ OK ] "); + } else { + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + } + PrintTestName(test_case_name_.c_str(), test_info.name()); + if (test_info.result()->Failed()) + PrintFullTestCommentIfPresent(test_info); + + if (GTEST_FLAG(print_time)) { + printf(" (%s ms)\n", internal::StreamableToString( + test_info.result()->elapsed_time()).c_str()); + } else { + printf("\n"); + } + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnTestCaseEnd(const TestCase& test_case) { + if (!GTEST_FLAG(print_time)) return; + + test_case_name_ = test_case.name(); + const internal::String counts = + FormatCountableNoun(test_case.test_to_run_count(), "test", "tests"); + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("%s from %s (%s ms total)\n\n", + counts.c_str(), test_case_name_.c_str(), + internal::StreamableToString(test_case.elapsed_time()).c_str()); + fflush(stdout); +} + +void PrettyUnitTestResultPrinter::OnEnvironmentsTearDownStart( + const UnitTest& /*unit_test*/) { + ColoredPrintf(COLOR_GREEN, "[----------] "); + printf("Global test environment tear-down\n"); + fflush(stdout); +} + +// Internal helper for printing the list of failed tests. +void PrettyUnitTestResultPrinter::PrintFailedTests(const UnitTest& unit_test) { + const int failed_test_count = unit_test.failed_test_count(); + if (failed_test_count == 0) { + return; + } + + for (int i = 0; i < unit_test.total_test_case_count(); ++i) { + const TestCase& test_case = *unit_test.GetTestCase(i); + if (!test_case.should_run() || (test_case.failed_test_count() == 0)) { + continue; + } + for (int j = 0; j < test_case.total_test_count(); ++j) { + const TestInfo& test_info = *test_case.GetTestInfo(j); + if (!test_info.should_run() || test_info.result()->Passed()) { + continue; + } + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + printf("%s.%s", test_case.name(), test_info.name()); + PrintFullTestCommentIfPresent(test_info); + printf("\n"); + } + } +} + +void PrettyUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test, + int /*iteration*/) { + ColoredPrintf(COLOR_GREEN, "[==========] "); + printf("%s from %s ran.", + FormatTestCount(unit_test.test_to_run_count()).c_str(), + FormatTestCaseCount(unit_test.test_case_to_run_count()).c_str()); + if (GTEST_FLAG(print_time)) { + printf(" (%s ms total)", + internal::StreamableToString(unit_test.elapsed_time()).c_str()); + } + printf("\n"); + ColoredPrintf(COLOR_GREEN, "[ PASSED ] "); + printf("%s.\n", FormatTestCount(unit_test.successful_test_count()).c_str()); + + int num_failures = unit_test.failed_test_count(); + if (!unit_test.Passed()) { + const int failed_test_count = unit_test.failed_test_count(); + ColoredPrintf(COLOR_RED, "[ FAILED ] "); + printf("%s, listed below:\n", FormatTestCount(failed_test_count).c_str()); + PrintFailedTests(unit_test); + printf("\n%2d FAILED %s\n", num_failures, + num_failures == 1 ? "TEST" : "TESTS"); + } + + int num_disabled = unit_test.disabled_test_count(); + if (num_disabled && !GTEST_FLAG(also_run_disabled_tests)) { + if (!num_failures) { + printf("\n"); // Add a spacer if no FAILURE banner is displayed. + } + ColoredPrintf(COLOR_YELLOW, + " YOU HAVE %d DISABLED %s\n\n", + num_disabled, + num_disabled == 1 ? "TEST" : "TESTS"); + } + // Ensure that Google Test output is printed before, e.g., heapchecker output. 
+ fflush(stdout); +} + +// End PrettyUnitTestResultPrinter + +// class TestEventRepeater +// +// This class forwards events to other event listeners. +class TestEventRepeater : public TestEventListener { + public: + TestEventRepeater() : forwarding_enabled_(true) {} + virtual ~TestEventRepeater(); + void Append(TestEventListener *listener); + TestEventListener* Release(TestEventListener* listener); + + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled() const { return forwarding_enabled_; } + void set_forwarding_enabled(bool enable) { forwarding_enabled_ = enable; } + + virtual void OnTestProgramStart(const UnitTest& unit_test); + virtual void OnTestIterationStart(const UnitTest& unit_test, int iteration); + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test); + virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test); + virtual void OnTestCaseStart(const TestCase& test_case); + virtual void OnTestStart(const TestInfo& test_info); + virtual void OnTestPartResult(const TestPartResult& result); + virtual void OnTestEnd(const TestInfo& test_info); + virtual void OnTestCaseEnd(const TestCase& test_case); + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test); + virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test); + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + virtual void OnTestProgramEnd(const UnitTest& unit_test); + + private: + // Controls whether events will be forwarded to listeners_. Set to false + // in death test child processes. + bool forwarding_enabled_; + // The list of listeners that receive events. + std::vector listeners_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventRepeater); +}; + +TestEventRepeater::~TestEventRepeater() { + ForEach(listeners_, Delete); +} + +void TestEventRepeater::Append(TestEventListener *listener) { + listeners_.push_back(listener); +} + +// TODO(vladl@google.com): Factor the search functionality into Vector::Find. +TestEventListener* TestEventRepeater::Release(TestEventListener *listener) { + for (size_t i = 0; i < listeners_.size(); ++i) { + if (listeners_[i] == listener) { + listeners_.erase(listeners_.begin() + i); + return listener; + } + } + + return NULL; +} + +// Since most methods are very similar, use macros to reduce boilerplate. +// This defines a member that forwards the call to all listeners. +#define GTEST_REPEATER_METHOD_(Name, Type) \ +void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (size_t i = 0; i < listeners_.size(); i++) { \ + listeners_[i]->Name(parameter); \ + } \ + } \ +} +// This defines a member that forwards the call to all listeners in reverse +// order. 
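The repeater above is also what user-supplied listeners hang off of: anything appended to UnitTest's listener list receives the same event stream as the built-in printers, with the *End events delivered in reverse registration order (see the reverse macro below). A minimal custom listener, as a sketch (class name and output text invented):

#include <cstdio>
#include "gtest/gtest.h"

class MinimalListener : public ::testing::EmptyTestEventListener {
 public:
  virtual void OnTestStart(const ::testing::TestInfo& test_info) {
    std::printf("* starting %s.%s\n",
                test_info.test_case_name(), test_info.name());
  }
  virtual void OnTestEnd(const ::testing::TestInfo& test_info) {
    std::printf("* finished %s.%s\n",
                test_info.test_case_name(), test_info.name());
  }
};

int main(int argc, char** argv) {
  ::testing::InitGoogleTest(&argc, argv);
  // Append() hands ownership to Google Test; the repeater forwards every
  // event to this listener in addition to the default pretty printer.
  ::testing::UnitTest::GetInstance()->listeners().Append(new MinimalListener);
  return RUN_ALL_TESTS();
}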
+#define GTEST_REVERSE_REPEATER_METHOD_(Name, Type) \ +void TestEventRepeater::Name(const Type& parameter) { \ + if (forwarding_enabled_) { \ + for (int i = static_cast(listeners_.size()) - 1; i >= 0; i--) { \ + listeners_[i]->Name(parameter); \ + } \ + } \ +} + +GTEST_REPEATER_METHOD_(OnTestProgramStart, UnitTest) +GTEST_REPEATER_METHOD_(OnEnvironmentsSetUpStart, UnitTest) +GTEST_REPEATER_METHOD_(OnTestCaseStart, TestCase) +GTEST_REPEATER_METHOD_(OnTestStart, TestInfo) +GTEST_REPEATER_METHOD_(OnTestPartResult, TestPartResult) +GTEST_REPEATER_METHOD_(OnEnvironmentsTearDownStart, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsSetUpEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnEnvironmentsTearDownEnd, UnitTest) +GTEST_REVERSE_REPEATER_METHOD_(OnTestEnd, TestInfo) +GTEST_REVERSE_REPEATER_METHOD_(OnTestCaseEnd, TestCase) +GTEST_REVERSE_REPEATER_METHOD_(OnTestProgramEnd, UnitTest) + +#undef GTEST_REPEATER_METHOD_ +#undef GTEST_REVERSE_REPEATER_METHOD_ + +void TestEventRepeater::OnTestIterationStart(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (size_t i = 0; i < listeners_.size(); i++) { + listeners_[i]->OnTestIterationStart(unit_test, iteration); + } + } +} + +void TestEventRepeater::OnTestIterationEnd(const UnitTest& unit_test, + int iteration) { + if (forwarding_enabled_) { + for (int i = static_cast(listeners_.size()) - 1; i >= 0; i--) { + listeners_[i]->OnTestIterationEnd(unit_test, iteration); + } + } +} + +// End TestEventRepeater + +// This class generates an XML output file. +class XmlUnitTestResultPrinter : public EmptyTestEventListener { + public: + explicit XmlUnitTestResultPrinter(const char* output_file); + + virtual void OnTestIterationEnd(const UnitTest& unit_test, int iteration); + + private: + // Is c a whitespace character that is normalized to a space character + // when it appears in an XML attribute value? + static bool IsNormalizableWhitespace(char c) { + return c == 0x9 || c == 0xA || c == 0xD; + } + + // May c appear in a well-formed XML document? + static bool IsValidXmlCharacter(char c) { + return IsNormalizableWhitespace(c) || c >= 0x20; + } + + // Returns an XML-escaped copy of the input string str. If + // is_attribute is true, the text is meant to appear as an attribute + // value, and normalizable whitespace is preserved by replacing it + // with character references. + static String EscapeXml(const char* str, bool is_attribute); + + // Returns the given string with all characters invalid in XML removed. + static string RemoveInvalidXmlCharacters(const string& str); + + // Convenience wrapper around EscapeXml when str is an attribute value. + static String EscapeXmlAttribute(const char* str) { + return EscapeXml(str, true); + } + + // Convenience wrapper around EscapeXml when str is not an attribute value. + static String EscapeXmlText(const char* str) { return EscapeXml(str, false); } + + // Streams an XML CDATA section, escaping invalid CDATA sequences as needed. + static void OutputXmlCDataSection(::std::ostream* stream, const char* data); + + // Streams an XML representation of a TestInfo object. + static void OutputXmlTestInfo(::std::ostream* stream, + const char* test_case_name, + const TestInfo& test_info); + + // Prints an XML representation of a TestCase object + static void PrintXmlTestCase(FILE* out, const TestCase& test_case); + + // Prints an XML summary of unit_test to output stream out. 
+  static void PrintXmlUnitTest(FILE* out, const UnitTest& unit_test);
+
+  // Produces a string representing the test properties in a result as space
+  // delimited XML attributes based on the property key="value" pairs.
+  // When the String is not empty, it includes a space at the beginning,
+  // to delimit this attribute from prior attributes.
+  static String TestPropertiesAsXmlAttributes(const TestResult& result);
+
+  // The output file.
+  const String output_file_;
+
+  GTEST_DISALLOW_COPY_AND_ASSIGN_(XmlUnitTestResultPrinter);
+};
+
+// Creates a new XmlUnitTestResultPrinter.
+XmlUnitTestResultPrinter::XmlUnitTestResultPrinter(const char* output_file)
+    : output_file_(output_file) {
+  if (output_file_.c_str() == NULL || output_file_.empty()) {
+    fprintf(stderr, "XML output file may not be null\n");
+    fflush(stderr);
+    exit(EXIT_FAILURE);
+  }
+}
+
+// Called after the unit test ends.
+void XmlUnitTestResultPrinter::OnTestIterationEnd(const UnitTest& unit_test,
+                                                  int /*iteration*/) {
+  FILE* xmlout = NULL;
+  FilePath output_file(output_file_);
+  FilePath output_dir(output_file.RemoveFileName());
+
+  if (output_dir.CreateDirectoriesRecursively()) {
+    xmlout = posix::FOpen(output_file_.c_str(), "w");
+  }
+  if (xmlout == NULL) {
+    // TODO(wan): report the reason of the failure.
+    //
+    // We don't do it for now as:
+    //
+    //   1. There is no urgent need for it.
+    //   2. It's a bit involved to make the errno variable thread-safe on
+    //      all three operating systems (Linux, Windows, and Mac OS).
+    //   3. To interpret the meaning of errno in a thread-safe way,
+    //      we need the strerror_r() function, which is not available on
+    //      Windows.
+    fprintf(stderr,
+            "Unable to open file \"%s\"\n",
+            output_file_.c_str());
+    fflush(stderr);
+    exit(EXIT_FAILURE);
+  }
+  PrintXmlUnitTest(xmlout, unit_test);
+  fclose(xmlout);
+}
+
+// Returns an XML-escaped copy of the input string str. If is_attribute
+// is true, the text is meant to appear as an attribute value, and
+// normalizable whitespace is preserved by replacing it with character
+// references.
+//
+// Invalid XML characters in str, if any, are stripped from the output.
+// It is expected that most, if not all, of the text processed by this
+// module will consist of ordinary English text.
+// If this module is ever modified to produce version 1.1 XML output,
+// most invalid characters can be retained using character references.
+// TODO(wan): It might be nice to have a minimally invasive, human-readable
+// escaping scheme for invalid characters, rather than dropping them.
+String XmlUnitTestResultPrinter::EscapeXml(const char* str, bool is_attribute) {
+  Message m;
+
+  if (str != NULL) {
+    for (const char* src = str; *src; ++src) {
+      switch (*src) {
+        case '<':
+          m << "&lt;";
+          break;
+        case '>':
+          m << "&gt;";
+          break;
+        case '&':
+          m << "&amp;";
+          break;
+        case '\'':
+          if (is_attribute)
+            m << "&apos;";
+          else
+            m << '\'';
+          break;
+        case '"':
+          if (is_attribute)
+            m << "&quot;";
+          else
+            m << '"';
+          break;
+        default:
+          if (IsValidXmlCharacter(*src)) {
+            if (is_attribute && IsNormalizableWhitespace(*src))
+              m << String::Format("&#x%02X;", unsigned(*src));
+            else
+              m << *src;
+          }
+          break;
+      }
+    }
+  }
+
+  return m.GetString();
+}
+
+// Returns the given string with all characters invalid in XML removed.
+// Currently invalid characters are dropped from the string. An
+// alternative is to replace them with certain characters such as . or ?.
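The escaping above covers the five standard XML entities plus character references for normalizable whitespace inside attribute values. The XML printer itself is usually enabled with --gtest_output=xml:PATH; an equivalent in code is setting the flag before InitGoogleTest, as in this sketch (the output path is just an example):

#include "gtest/gtest.h"

int main(int argc, char** argv) {
  // Same effect as passing --gtest_output=xml:gen/test_report.xml; the
  // command line, parsed by InitGoogleTest, can still override it.
  // XmlUnitTestResultPrinter then writes the report when the run ends.
  ::testing::GTEST_FLAG(output) = "xml:gen/test_report.xml";
  ::testing::InitGoogleTest(&argc, argv);
  return RUN_ALL_TESTS();
}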
+string XmlUnitTestResultPrinter::RemoveInvalidXmlCharacters(const string& str) {
+  string output;
+  output.reserve(str.size());
+  for (string::const_iterator it = str.begin(); it != str.end(); ++it)
+    if (IsValidXmlCharacter(*it))
+      output.push_back(*it);
+
+  return output;
+}
+
+// The following routines generate an XML representation of a UnitTest
+// object.
+//
+// This is how Google Test concepts map to the DTD:
+//
+// <testsuites name="AllTests">        <-- corresponds to a UnitTest object
+//   <testsuite name="testcase-name">  <-- corresponds to a TestCase object
+//     <testcase name="test-name">     <-- corresponds to a TestInfo object
+//       <failure message="...">...</failure>
+//       <failure message="...">...</failure>
+//       <failure message="...">...</failure>
+//                                     <-- individual assertion failures
+//     </testcase>
+//   </testsuite>
+// </testsuites>
+
+// Formats the given time in milliseconds as seconds.
+std::string FormatTimeInMillisAsSeconds(TimeInMillis ms) {
+  ::std::stringstream ss;
+  ss << ms/1000.0;
+  return ss.str();
+}
+
+// Streams an XML CDATA section, escaping invalid CDATA sequences as needed.
+void XmlUnitTestResultPrinter::OutputXmlCDataSection(::std::ostream* stream,
+                                                     const char* data) {
+  const char* segment = data;
+  *stream << "<![CDATA[";
+  for (;;) {
+    const char* const next_segment = strstr(segment, "]]>");
+    if (next_segment != NULL) {
+      stream->write(
+          segment, static_cast<std::streamsize>(next_segment - segment));
+      *stream << "]]>]]&gt;<![CDATA[";
+      segment = next_segment + strlen("]]>");
+    } else {
+      *stream << segment;
+      break;
+    }
+  }
+  *stream << "]]>";
+}
+
+// Prints an XML representation of a TestInfo object.
+// TODO(wan): There is also value in printing properties with the plain
+// printer.
+void XmlUnitTestResultPrinter::OutputXmlTestInfo(::std::ostream* stream,
+                                                 const char* test_case_name,
+                                                 const TestInfo& test_info) {
+  const TestResult& result = *test_info.result();
+  *stream << "    <testcase name=\""
+          << EscapeXmlAttribute(test_info.name()).c_str() << "\"";
+
+  if (test_info.value_param() != NULL) {
+    *stream << " value_param=\"" << EscapeXmlAttribute(test_info.value_param())
+            << "\"";
+  }
+  if (test_info.type_param() != NULL) {
+    *stream << " type_param=\"" << EscapeXmlAttribute(test_info.type_param())
+            << "\"";
+  }
+
+  *stream << " status=\""
+          << (test_info.should_run() ? "run" : "notrun")
+          << "\" time=\""
+          << FormatTimeInMillisAsSeconds(result.elapsed_time())
+          << "\" classname=\"" << EscapeXmlAttribute(test_case_name).c_str()
+          << "\"" << TestPropertiesAsXmlAttributes(result).c_str();
+
+  int failures = 0;
+  for (int i = 0; i < result.total_part_count(); ++i) {
+    const TestPartResult& part = result.GetTestPartResult(i);
+    if (part.failed()) {
+      if (++failures == 1)
+        *stream << ">\n";
+      *stream << "      <failure message=\""
+              << EscapeXmlAttribute(part.summary()).c_str()
+              << "\" type=\"\">";
+      const string location = internal::FormatCompilerIndependentFileLocation(
+          part.file_name(), part.line_number());
+      const string message = location + "\n" + part.message();
+      OutputXmlCDataSection(stream,
+                            RemoveInvalidXmlCharacters(message).c_str());
+      *stream << "</failure>\n";
+    }
+  }
+
+  if (failures == 0)
+    *stream << " />\n";
+  else
+    *stream << "    </testcase>\n";
+}
+
+// Prints an XML representation of a TestCase object
+void XmlUnitTestResultPrinter::PrintXmlTestCase(FILE* out,
+                                                const TestCase& test_case) {
+  fprintf(out,
+          "  <testsuite name=\"%s\" tests=\"%d\" failures=\"%d\" "
+          "disabled=\"%d\" ",
+          EscapeXmlAttribute(test_case.name()).c_str(),
+          test_case.total_test_count(),
+          test_case.failed_test_count(),
+          test_case.disabled_test_count());
+  fprintf(out,
+          "errors=\"0\" time=\"%s\">\n",
+          FormatTimeInMillisAsSeconds(test_case.elapsed_time()).c_str());
+  for (int i = 0; i < test_case.total_test_count(); ++i) {
+    ::std::stringstream stream;
+    OutputXmlTestInfo(&stream, test_case.name(), *test_case.GetTestInfo(i));
+    fprintf(out, "%s", StringStreamToString(&stream).c_str());
+  }
+  fprintf(out, "  </testsuite>\n");
+}
+
+// Prints an XML summary of unit_test to output stream out.
+void XmlUnitTestResultPrinter::PrintXmlUnitTest(FILE* out,
+                                                const UnitTest& unit_test) {
+  fprintf(out, "<?xml version=\"1.0\" encoding=\"UTF-8\"?>\n");
+  fprintf(out,
+          "<testsuites tests=\"%d\" failures=\"%d\" disabled=\"%d\" "
+          "errors=\"0\" time=\"%s\" ",
+          unit_test.total_test_count(),
+          unit_test.failed_test_count(),
+          unit_test.disabled_test_count(),
+          FormatTimeInMillisAsSeconds(unit_test.elapsed_time()).c_str());
+  if (GTEST_FLAG(shuffle)) {
+    fprintf(out, "random_seed=\"%d\" ", unit_test.random_seed());
+  }
+  fprintf(out, "name=\"AllTests\">\n");
+  for (int i = 0; i < unit_test.total_test_case_count(); ++i)
+    PrintXmlTestCase(out, *unit_test.GetTestCase(i));
+  fprintf(out, "</testsuites>\n");
+}
+
+// Produces a string representing the test properties in a result as space
+// delimited XML attributes based on the property key="value" pairs.
+String XmlUnitTestResultPrinter::TestPropertiesAsXmlAttributes(
+    const TestResult& result) {
+  Message attributes;
+  for (int i = 0; i < result.test_property_count(); ++i) {
+    const TestProperty& property = result.GetTestProperty(i);
+    attributes << " " << property.key() << "="
+        << "\"" << EscapeXmlAttribute(property.value()) << "\"";
+  }
+  return attributes.GetString();
+}
+
+// End XmlUnitTestResultPrinter
+
+#if GTEST_CAN_STREAM_RESULTS_
+
+// Streams test results to the given port on the given host machine.
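The streaming support below is switched on with --gtest_stream_result_to=HOST:PORT (when GTEST_CAN_STREAM_RESULTS_ is enabled); the listener opens a TCP connection and writes one key=value&... line per event, following the Send() calls in the class, so any plain TCP sink (for example netcat) can capture the stream. As a rough illustration only, with invented test names and zeroed timings, a passing one-test run streams lines like:

gtest_streaming_protocol_version=1.0
event=TestProgramStart
event=TestIterationStart&iteration=0
event=TestCaseStart&name=WidgetTest
event=TestStart&name=CountsWidgets
event=TestEnd&passed=1&elapsed_time=0ms
event=TestCaseEnd&passed=1&elapsed_time=0ms
event=TestIterationEnd&passed=1&elapsed_time=0ms
event=TestProgramEnd&passed=1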
+#if GTEST_CAN_STREAM_RESULTS_
+
+// Streams test results to the given port on the given host machine.
+class StreamingListener : public EmptyTestEventListener {
+ public:
+  // Escapes '=', '&', '%', and '\n' characters in str as "%xx".
+  static string UrlEncode(const char* str);
+
+  StreamingListener(const string& host, const string& port)
+      : sockfd_(-1), host_name_(host), port_num_(port) {
+    MakeConnection();
+    Send("gtest_streaming_protocol_version=1.0\n");
+  }
+
+  virtual ~StreamingListener() {
+    if (sockfd_ != -1)
+      CloseConnection();
+  }
+
+  void OnTestProgramStart(const UnitTest& /* unit_test */) {
+    Send("event=TestProgramStart\n");
+  }
+
+  void OnTestProgramEnd(const UnitTest& unit_test) {
+    // Note that Google Test currently only reports elapsed time for each
+    // test iteration, not for the entire test program.
+    Send(String::Format("event=TestProgramEnd&passed=%d\n",
+                        unit_test.Passed()));
+
+    // Notify the streaming server to stop.
+    CloseConnection();
+  }
+
+  void OnTestIterationStart(const UnitTest& /* unit_test */, int iteration) {
+    Send(String::Format("event=TestIterationStart&iteration=%d\n",
+                        iteration));
+  }
+
+  void OnTestIterationEnd(const UnitTest& unit_test, int /* iteration */) {
+    Send(String::Format("event=TestIterationEnd&passed=%d&elapsed_time=%sms\n",
+                        unit_test.Passed(),
+                        StreamableToString(unit_test.elapsed_time()).c_str()));
+  }
+
+  void OnTestCaseStart(const TestCase& test_case) {
+    Send(String::Format("event=TestCaseStart&name=%s\n", test_case.name()));
+  }
+
+  void OnTestCaseEnd(const TestCase& test_case) {
+    Send(String::Format("event=TestCaseEnd&passed=%d&elapsed_time=%sms\n",
+                        test_case.Passed(),
+                        StreamableToString(test_case.elapsed_time()).c_str()));
+  }
+
+  void OnTestStart(const TestInfo& test_info) {
+    Send(String::Format("event=TestStart&name=%s\n", test_info.name()));
+  }
+
+  void OnTestEnd(const TestInfo& test_info) {
+    Send(String::Format(
+        "event=TestEnd&passed=%d&elapsed_time=%sms\n",
+        (test_info.result())->Passed(),
+        StreamableToString((test_info.result())->elapsed_time()).c_str()));
+  }
+
+  void OnTestPartResult(const TestPartResult& test_part_result) {
+    const char* file_name = test_part_result.file_name();
+    if (file_name == NULL)
+      file_name = "";
+    Send(String::Format("event=TestPartResult&file=%s&line=%d&message=",
+                        UrlEncode(file_name).c_str(),
+                        test_part_result.line_number()));
+    Send(UrlEncode(test_part_result.message()) + "\n");
+  }
+
+ private:
+  // Creates a client socket and connects to the server.
+  void MakeConnection();
+
+  // Closes the socket.
+  void CloseConnection() {
+    GTEST_CHECK_(sockfd_ != -1)
+        << "CloseConnection() can be called only when there is a connection.";
+
+    close(sockfd_);
+    sockfd_ = -1;
+  }
+
+  // Sends a string to the socket.
+  void Send(const string& message) {
+    GTEST_CHECK_(sockfd_ != -1)
+        << "Send() can be called only when there is a connection.";
+
+    const int len = static_cast<int>(message.length());
+    if (write(sockfd_, message.c_str(), len) != len) {
+      GTEST_LOG_(WARNING)
+          << "stream_result_to: failed to stream to "
+          << host_name_ << ":" << port_num_;
+    }
+  }
+
+  int sockfd_;  // socket file descriptor
+  const string host_name_;
+  const string port_num_;
+
+  GTEST_DISALLOW_COPY_AND_ASSIGN_(StreamingListener);
+};  // class StreamingListener
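+// Editor's note -- an illustrative sketch, not part of the original Google
+// Test source: with --gtest_stream_result_to=HOST:PORT the listener above
+// writes one '\n'-terminated event per line to the socket.  A run with a
+// single passing test produces roughly:
+//
+//   gtest_streaming_protocol_version=1.0
+//   event=TestProgramStart
+//   event=TestIterationStart&iteration=0
+//   event=TestCaseStart&name=FooTest
+//   event=TestStart&name=Bar
+//   event=TestEnd&passed=1&elapsed_time=0ms
+//   event=TestCaseEnd&passed=1&elapsed_time=0ms
+//   event=TestIterationEnd&passed=1&elapsed_time=1ms
+//   event=TestProgramEnd&passed=1
+//
+// Fields are '&'-separated, and failure locations and messages go through
+// UrlEncode() below, so '=', '&', '%' and newlines arrive as %xx escapes.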
+// Checks if str contains '=', '&', '%' or '\n' characters. If yes,
+// replaces them by "%xx" where xx is their hexadecimal value. For
+// example, replaces "=" with "%3D". This algorithm is O(strlen(str))
+// in both time and space -- important as the input str may contain an
+// arbitrarily long test failure message and stack trace.
+string StreamingListener::UrlEncode(const char* str) {
+  string result;
+  result.reserve(strlen(str) + 1);
+  for (char ch = *str; ch != '\0'; ch = *++str) {
+    switch (ch) {
+      case '%':
+      case '=':
+      case '&':
+      case '\n':
+        result.append(String::Format("%%%02x", static_cast<unsigned int>(ch)));
+        break;
+      default:
+        result.push_back(ch);
+        break;
+    }
+  }
+  return result;
+}
+
+void StreamingListener::MakeConnection() {
+  GTEST_CHECK_(sockfd_ == -1)
+      << "MakeConnection() can't be called when there is already a connection.";
+
+  addrinfo hints;
+  memset(&hints, 0, sizeof(hints));
+  hints.ai_family = AF_UNSPEC;    // To allow both IPv4 and IPv6 addresses.
+  hints.ai_socktype = SOCK_STREAM;
+  addrinfo* servinfo = NULL;
+
+  // Use the getaddrinfo() to get a linked list of IP addresses for
+  // the given host name.
+  const int error_num = getaddrinfo(
+      host_name_.c_str(), port_num_.c_str(), &hints, &servinfo);
+  if (error_num != 0) {
+    GTEST_LOG_(WARNING) << "stream_result_to: getaddrinfo() failed: "
+                        << gai_strerror(error_num);
+  }
+
+  // Loop through all the results and connect to the first we can.
+  for (addrinfo* cur_addr = servinfo; sockfd_ == -1 && cur_addr != NULL;
+       cur_addr = cur_addr->ai_next) {
+    sockfd_ = socket(
+        cur_addr->ai_family, cur_addr->ai_socktype, cur_addr->ai_protocol);
+    if (sockfd_ != -1) {
+      // Connect the client socket to the server socket.
+      if (connect(sockfd_, cur_addr->ai_addr, cur_addr->ai_addrlen) == -1) {
+        close(sockfd_);
+        sockfd_ = -1;
+      }
+    }
+  }
+
+  freeaddrinfo(servinfo);  // all done with this structure
+
+  if (sockfd_ == -1) {
+    GTEST_LOG_(WARNING) << "stream_result_to: failed to connect to "
+                        << host_name_ << ":" << port_num_;
+  }
+}
+
+// End of class StreamingListener
+#endif  // GTEST_CAN_STREAM_RESULTS_
+
+// Class ScopedTrace
+
+// Pushes the given source file location and message onto a per-thread
+// trace stack maintained by Google Test.
+// L < UnitTest::mutex_
+ScopedTrace::ScopedTrace(const char* file, int line, const Message& message) {
+  TraceInfo trace;
+  trace.file = file;
+  trace.line = line;
+  trace.message = message.GetString();
+
+  UnitTest::GetInstance()->PushGTestTrace(trace);
+}
+
+// Pops the info pushed by the c'tor.
+// L < UnitTest::mutex_
+ScopedTrace::~ScopedTrace() {
+  UnitTest::GetInstance()->PopGTestTrace();
+}
+
+
+// class OsStackTraceGetter
+
+// Returns the current OS stack trace as a String.  Parameters:
+//
+//   max_depth  - the maximum number of stack frames to be included
+//                in the trace.
+//   skip_count - the number of top frames to be skipped; doesn't count
+//                against max_depth.
+//
+// L < mutex_
+// We use "L < mutex_" to denote that the function may acquire mutex_.
+String OsStackTraceGetter::CurrentStackTrace(int, int) {
+  return String("");
+}
+
+// L < mutex_
+void OsStackTraceGetter::UponLeavingGTest() {
+}
+
+const char* const
+OsStackTraceGetter::kElidedFramesMarker =
+    "... " GTEST_NAME_ " internal frames ...";
+
+}  // namespace internal
+
+// class TestEventListeners
+
+TestEventListeners::TestEventListeners()
+    : repeater_(new internal::TestEventRepeater()),
+      default_result_printer_(NULL),
+      default_xml_generator_(NULL) {
+}
+
+TestEventListeners::~TestEventListeners() { delete repeater_; }
+
+// Returns the standard listener responsible for the default console
+// output.  Can be removed from the listeners list to shut down default
+// console output.  Note that removing this object from the listener list
+// with Release transfers its ownership to the user.
+void TestEventListeners::Append(TestEventListener* listener) { + repeater_->Append(listener); +} + +// Removes the given event listener from the list and returns it. It then +// becomes the caller's responsibility to delete the listener. Returns +// NULL if the listener is not found in the list. +TestEventListener* TestEventListeners::Release(TestEventListener* listener) { + if (listener == default_result_printer_) + default_result_printer_ = NULL; + else if (listener == default_xml_generator_) + default_xml_generator_ = NULL; + return repeater_->Release(listener); +} + +// Returns repeater that broadcasts the TestEventListener events to all +// subscribers. +TestEventListener* TestEventListeners::repeater() { return repeater_; } + +// Sets the default_result_printer attribute to the provided listener. +// The listener is also added to the listener list and previous +// default_result_printer is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultResultPrinter(TestEventListener* listener) { + if (default_result_printer_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. + delete Release(default_result_printer_); + default_result_printer_ = listener; + if (listener != NULL) + Append(listener); + } +} + +// Sets the default_xml_generator attribute to the provided listener. The +// listener is also added to the listener list and previous +// default_xml_generator is removed from it and deleted. The listener can +// also be NULL in which case it will not be added to the list. Does +// nothing if the previous and the current listener objects are the same. +void TestEventListeners::SetDefaultXmlGenerator(TestEventListener* listener) { + if (default_xml_generator_ != listener) { + // It is an error to pass this method a listener that is already in the + // list. + delete Release(default_xml_generator_); + default_xml_generator_ = listener; + if (listener != NULL) + Append(listener); + } +} + +// Controls whether events will be forwarded by the repeater to the +// listeners in the list. +bool TestEventListeners::EventForwardingEnabled() const { + return repeater_->forwarding_enabled(); +} + +void TestEventListeners::SuppressEventForwarding() { + repeater_->set_forwarding_enabled(false); +} + +// class UnitTest + +// Gets the singleton UnitTest object. The first time this method is +// called, a UnitTest object is constructed and returned. Consecutive +// calls will return the same object. +// +// We don't protect this under mutex_ as a user is not supposed to +// call this before main() starts, from which point on the return +// value will never change. +UnitTest * UnitTest::GetInstance() { + // When compiled with MSVC 7.1 in optimized mode, destroying the + // UnitTest object upon exiting the program messes up the exit code, + // causing successful tests to appear failed. We have to use a + // different implementation in this case to bypass the compiler bug. + // This implementation makes the compiler happy, at the cost of + // leaking the UnitTest object. + + // CodeGear C++Builder insists on a public destructor for the + // default implementation. Use this implementation to keep good OO + // design with private destructor. 
+ +#if (_MSC_VER == 1310 && !defined(_DEBUG)) || defined(__BORLANDC__) + static UnitTest* const instance = new UnitTest; + return instance; +#else + static UnitTest instance; + return &instance; +#endif // (_MSC_VER == 1310 && !defined(_DEBUG)) || defined(__BORLANDC__) +} + +// Gets the number of successful test cases. +int UnitTest::successful_test_case_count() const { + return impl()->successful_test_case_count(); +} + +// Gets the number of failed test cases. +int UnitTest::failed_test_case_count() const { + return impl()->failed_test_case_count(); +} + +// Gets the number of all test cases. +int UnitTest::total_test_case_count() const { + return impl()->total_test_case_count(); +} + +// Gets the number of all test cases that contain at least one test +// that should run. +int UnitTest::test_case_to_run_count() const { + return impl()->test_case_to_run_count(); +} + +// Gets the number of successful tests. +int UnitTest::successful_test_count() const { + return impl()->successful_test_count(); +} + +// Gets the number of failed tests. +int UnitTest::failed_test_count() const { return impl()->failed_test_count(); } + +// Gets the number of disabled tests. +int UnitTest::disabled_test_count() const { + return impl()->disabled_test_count(); +} + +// Gets the number of all tests. +int UnitTest::total_test_count() const { return impl()->total_test_count(); } + +// Gets the number of tests that should run. +int UnitTest::test_to_run_count() const { return impl()->test_to_run_count(); } + +// Gets the elapsed time, in milliseconds. +internal::TimeInMillis UnitTest::elapsed_time() const { + return impl()->elapsed_time(); +} + +// Returns true iff the unit test passed (i.e. all test cases passed). +bool UnitTest::Passed() const { return impl()->Passed(); } + +// Returns true iff the unit test failed (i.e. some test case failed +// or something outside of all tests failed). +bool UnitTest::Failed() const { return impl()->Failed(); } + +// Gets the i-th test case among all the test cases. i can range from 0 to +// total_test_case_count() - 1. If i is not in that range, returns NULL. +const TestCase* UnitTest::GetTestCase(int i) const { + return impl()->GetTestCase(i); +} + +// Gets the i-th test case among all the test cases. i can range from 0 to +// total_test_case_count() - 1. If i is not in that range, returns NULL. +TestCase* UnitTest::GetMutableTestCase(int i) { + return impl()->GetMutableTestCase(i); +} + +// Returns the list of event listeners that can be used to track events +// inside Google Test. +TestEventListeners& UnitTest::listeners() { + return *impl()->listeners(); +} + +// Registers and returns a global test environment. When a test +// program is run, all global test environments will be set-up in the +// order they were registered. After all tests in the program have +// finished, all global test environments will be torn-down in the +// *reverse* order they were registered. +// +// The UnitTest object takes ownership of the given environment. +// +// We don't protect this under mutex_, as we only support calling it +// from the main thread. +Environment* UnitTest::AddEnvironment(Environment* env) { + if (env == NULL) { + return NULL; + } + + impl_->environments().push_back(env); + return env; +} + +// Adds a TestPartResult to the current TestResult object. All Google Test +// assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) eventually call +// this to report their results. The user code should use the +// assertion macros instead of calling this directly. 
+// L < mutex_
+void UnitTest::AddTestPartResult(TestPartResult::Type result_type,
+                                 const char* file_name,
+                                 int line_number,
+                                 const internal::String& message,
+                                 const internal::String& os_stack_trace) {
+  Message msg;
+  msg << message;
+
+  internal::MutexLock lock(&mutex_);
+  if (impl_->gtest_trace_stack().size() > 0) {
+    msg << "\n" << GTEST_NAME_ << " trace:";
+
+    for (int i = static_cast<int>(impl_->gtest_trace_stack().size());
+         i > 0; --i) {
+      const internal::TraceInfo& trace = impl_->gtest_trace_stack()[i - 1];
+      msg << "\n" << internal::FormatFileLocation(trace.file, trace.line)
+          << " " << trace.message;
+    }
+  }
+
+  if (os_stack_trace.c_str() != NULL && !os_stack_trace.empty()) {
+    msg << internal::kStackTraceMarker << os_stack_trace;
+  }
+
+  const TestPartResult result =
+    TestPartResult(result_type, file_name, line_number,
+                   msg.GetString().c_str());
+  impl_->GetTestPartResultReporterForCurrentThread()->
+      ReportTestPartResult(result);
+
+  if (result_type != TestPartResult::kSuccess) {
+    // gtest_break_on_failure takes precedence over
+    // gtest_throw_on_failure.  This allows a user to set the latter
+    // in the code (perhaps in order to use Google Test assertions
+    // with another testing framework) and specify the former on the
+    // command line for debugging.
+    if (GTEST_FLAG(break_on_failure)) {
+#if GTEST_OS_WINDOWS
+      // Using DebugBreak on Windows allows gtest to still break into a debugger
+      // when a failure happens and both the --gtest_break_on_failure and
+      // the --gtest_catch_exceptions flags are specified.
+      DebugBreak();
+#else
+      // Dereference NULL through a volatile pointer to prevent the compiler
+      // from removing. We use this rather than abort() or __builtin_trap() for
+      // portability: Symbian doesn't implement abort() well, and some debuggers
+      // don't correctly trap abort().
+      *static_cast<volatile int*>(NULL) = 1;
+#endif  // GTEST_OS_WINDOWS
+    } else if (GTEST_FLAG(throw_on_failure)) {
+#if GTEST_HAS_EXCEPTIONS
+      throw GoogleTestFailureException(result);
+#else
+      // We cannot call abort() as it generates a pop-up in debug mode
+      // that cannot be suppressed in VC 7.1 or below.
+      exit(1);
+#endif
+    }
+  }
+}
+
+// Creates and adds a property to the current TestResult. If a property matching
+// the supplied value already exists, updates its value instead.
+void UnitTest::RecordPropertyForCurrentTest(const char* key,
+                                            const char* value) {
+  const TestProperty test_property(key, value);
+  impl_->current_test_result()->RecordProperty(test_property);
+}
+
+// Runs all tests in this UnitTest object and prints the result.
+// Returns 0 if successful, or 1 otherwise.
+//
+// We don't protect this under mutex_, as we only support calling it
+// from the main thread.
+int UnitTest::Run() {
+  // Captures the value of GTEST_FLAG(catch_exceptions).  This value will be
+  // used for the duration of the program.
+  impl()->set_catch_exceptions(GTEST_FLAG(catch_exceptions));
+
+#if GTEST_HAS_SEH
+  const bool in_death_test_child_process =
+      internal::GTEST_FLAG(internal_run_death_test).length() > 0;
+
+  // Either the user wants Google Test to catch exceptions thrown by the
+  // tests or this is executing in the context of death test child
+  // process. In either case the user does not want to see pop-up dialogs
+  // about crashes - they are expected.
+  if (impl()->catch_exceptions() || in_death_test_child_process) {
+
+# if !GTEST_OS_WINDOWS_MOBILE
+    // SetErrorMode doesn't exist on CE.
+ SetErrorMode(SEM_FAILCRITICALERRORS | SEM_NOALIGNMENTFAULTEXCEPT | + SEM_NOGPFAULTERRORBOX | SEM_NOOPENFILEERRORBOX); +# endif // !GTEST_OS_WINDOWS_MOBILE + +# if (defined(_MSC_VER) || GTEST_OS_WINDOWS_MINGW) && !GTEST_OS_WINDOWS_MOBILE + // Death test children can be terminated with _abort(). On Windows, + // _abort() can show a dialog with a warning message. This forces the + // abort message to go to stderr instead. + _set_error_mode(_OUT_TO_STDERR); +# endif + +# if _MSC_VER >= 1400 && !GTEST_OS_WINDOWS_MOBILE + // In the debug version, Visual Studio pops up a separate dialog + // offering a choice to debug the aborted program. We need to suppress + // this dialog or it will pop up for every EXPECT/ASSERT_DEATH statement + // executed. Google Test will notify the user of any unexpected + // failure via stderr. + // + // VC++ doesn't define _set_abort_behavior() prior to the version 8.0. + // Users of prior VC versions shall suffer the agony and pain of + // clicking through the countless debug dialogs. + // TODO(vladl@google.com): find a way to suppress the abort dialog() in the + // debug mode when compiled with VC 7.1 or lower. + if (!GTEST_FLAG(break_on_failure)) + _set_abort_behavior( + 0x0, // Clear the following flags: + _WRITE_ABORT_MSG | _CALL_REPORTFAULT); // pop-up window, core dump. +# endif + + } +#endif // GTEST_HAS_SEH + + return internal::HandleExceptionsInMethodIfSupported( + impl(), + &internal::UnitTestImpl::RunAllTests, + "auxiliary test code (environments or event listeners)") ? 0 : 1; +} + +// Returns the working directory when the first TEST() or TEST_F() was +// executed. +const char* UnitTest::original_working_dir() const { + return impl_->original_working_dir_.c_str(); +} + +// Returns the TestCase object for the test that's currently running, +// or NULL if no test is running. +// L < mutex_ +const TestCase* UnitTest::current_test_case() const { + internal::MutexLock lock(&mutex_); + return impl_->current_test_case(); +} + +// Returns the TestInfo object for the test that's currently running, +// or NULL if no test is running. +// L < mutex_ +const TestInfo* UnitTest::current_test_info() const { + internal::MutexLock lock(&mutex_); + return impl_->current_test_info(); +} + +// Returns the random seed used at the start of the current test run. +int UnitTest::random_seed() const { return impl_->random_seed(); } + +#if GTEST_HAS_PARAM_TEST +// Returns ParameterizedTestCaseRegistry object used to keep track of +// value-parameterized tests and instantiate and register them. +// L < mutex_ +internal::ParameterizedTestCaseRegistry& + UnitTest::parameterized_test_registry() { + return impl_->parameterized_test_registry(); +} +#endif // GTEST_HAS_PARAM_TEST + +// Creates an empty UnitTest. +UnitTest::UnitTest() { + impl_ = new internal::UnitTestImpl(this); +} + +// Destructor of UnitTest. +UnitTest::~UnitTest() { + delete impl_; +} + +// Pushes a trace defined by SCOPED_TRACE() on to the per-thread +// Google Test trace stack. +// L < mutex_ +void UnitTest::PushGTestTrace(const internal::TraceInfo& trace) { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().push_back(trace); +} + +// Pops a trace from the per-thread Google Test trace stack. +// L < mutex_ +void UnitTest::PopGTestTrace() { + internal::MutexLock lock(&mutex_); + impl_->gtest_trace_stack().pop_back(); +} + +namespace internal { + +UnitTestImpl::UnitTestImpl(UnitTest* parent) + : parent_(parent), +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. 
+# pragma warning(disable:4355)          // Temporarily disables warning 4355
+                                        // (using this in initializer).
+      default_global_test_part_result_reporter_(this),
+      default_per_thread_test_part_result_reporter_(this),
+# pragma warning(pop)                   // Restores the warning state again.
+#else
+      default_global_test_part_result_reporter_(this),
+      default_per_thread_test_part_result_reporter_(this),
+#endif  // _MSC_VER
+      global_test_part_result_repoter_(
+          &default_global_test_part_result_reporter_),
+      per_thread_test_part_result_reporter_(
+          &default_per_thread_test_part_result_reporter_),
+#if GTEST_HAS_PARAM_TEST
+      parameterized_test_registry_(),
+      parameterized_tests_registered_(false),
+#endif  // GTEST_HAS_PARAM_TEST
+      last_death_test_case_(-1),
+      current_test_case_(NULL),
+      current_test_info_(NULL),
+      ad_hoc_test_result_(),
+      os_stack_trace_getter_(NULL),
+      post_flag_parse_init_performed_(false),
+      random_seed_(0),  // Will be overridden by the flag before first use.
+      random_(0),  // Will be reseeded before first use.
+      elapsed_time_(0),
+#if GTEST_HAS_DEATH_TEST
+      internal_run_death_test_flag_(NULL),
+      death_test_factory_(new DefaultDeathTestFactory),
+#endif
+      // Will be overridden by the flag before first use.
+      catch_exceptions_(false) {
+  listeners()->SetDefaultResultPrinter(new PrettyUnitTestResultPrinter);
+}
+
+UnitTestImpl::~UnitTestImpl() {
+  // Deletes every TestCase.
+  ForEach(test_cases_, internal::Delete<TestCase>);
+
+  // Deletes every Environment.
+  ForEach(environments_, internal::Delete<Environment>);
+
+  delete os_stack_trace_getter_;
+}
+
+#if GTEST_HAS_DEATH_TEST
+// Disables event forwarding if the control is currently in a death test
+// subprocess. Must not be called before InitGoogleTest.
+void UnitTestImpl::SuppressTestEventsIfInSubprocess() {
+  if (internal_run_death_test_flag_.get() != NULL)
+    listeners()->SuppressEventForwarding();
+}
+#endif  // GTEST_HAS_DEATH_TEST
+
+// Initializes event listeners performing XML output as specified by
+// UnitTestOptions. Must not be called before InitGoogleTest.
+void UnitTestImpl::ConfigureXmlOutput() {
+  const String& output_format = UnitTestOptions::GetOutputFormat();
+  if (output_format == "xml") {
+    listeners()->SetDefaultXmlGenerator(new XmlUnitTestResultPrinter(
+        UnitTestOptions::GetAbsolutePathToOutputFile().c_str()));
+  } else if (output_format != "") {
+    printf("WARNING: unrecognized output format \"%s\" ignored.\n",
+           output_format.c_str());
+    fflush(stdout);
+  }
+}
+
+#if GTEST_CAN_STREAM_RESULTS_
+// Initializes event listeners for streaming test results in String form.
+// Must not be called before InitGoogleTest.
+void UnitTestImpl::ConfigureStreamingOutput() {
+  const string& target = GTEST_FLAG(stream_result_to);
+  if (!target.empty()) {
+    const size_t pos = target.find(':');
+    if (pos != string::npos) {
+      listeners()->Append(new StreamingListener(target.substr(0, pos),
+                                                target.substr(pos+1)));
+    } else {
+      printf("WARNING: unrecognized streaming target \"%s\" ignored.\n",
+             target.c_str());
+      fflush(stdout);
+    }
+  }
+}
+#endif  // GTEST_CAN_STREAM_RESULTS_
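+// Editor's note -- an illustrative sketch, not part of the original Google
+// Test source: the two Configure*Output() helpers above are driven purely by
+// flags (or the matching GTEST_OUTPUT / GTEST_STREAM_RESULT_TO environment
+// variables), e.g.:
+//
+//   ./my_test --gtest_output=xml:gen/my_test_report.xml
+//       installs an XmlUnitTestResultPrinter writing gen/my_test_report.xml
+//   ./my_test --gtest_stream_result_to=localhost:9090
+//       appends a StreamingListener connected to localhost:9090
+//
+// (the file and host values here are made-up examples); an unrecognized
+// output format, or a target without a ':', only produces the warnings above.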
+// Performs initialization dependent upon flag values obtained in
+// ParseGoogleTestFlagsOnly.  Is called from InitGoogleTest after the call to
+// ParseGoogleTestFlagsOnly.  In case a user neglects to call InitGoogleTest
+// this function is also called from RunAllTests.  Since this function can be
+// called more than once, it has to be idempotent.
+void UnitTestImpl::PostFlagParsingInit() {
+  // Ensures that this function does not execute more than once.
+  if (!post_flag_parse_init_performed_) {
+    post_flag_parse_init_performed_ = true;
+
+#if GTEST_HAS_DEATH_TEST
+    InitDeathTestSubprocessControlInfo();
+    SuppressTestEventsIfInSubprocess();
+#endif  // GTEST_HAS_DEATH_TEST
+
+    // Registers parameterized tests. This makes parameterized tests
+    // available to the UnitTest reflection API without running
+    // RUN_ALL_TESTS.
+    RegisterParameterizedTests();
+
+    // Configures listeners for XML output. This makes it possible for users
+    // to shut down the default XML output before invoking RUN_ALL_TESTS.
+    ConfigureXmlOutput();
+
+#if GTEST_CAN_STREAM_RESULTS_
+    // Configures listeners for streaming test results to the specified server.
+    ConfigureStreamingOutput();
+#endif  // GTEST_CAN_STREAM_RESULTS_
+  }
+}
+
+// A predicate that checks the name of a TestCase against a known
+// value.
+//
+// This is used for implementation of the UnitTest class only. We put
+// it in the anonymous namespace to prevent polluting the outer
+// namespace.
+//
+// TestCaseNameIs is copyable.
+class TestCaseNameIs {
+ public:
+  // Constructor.
+  explicit TestCaseNameIs(const String& name)
+      : name_(name) {}
+
+  // Returns true iff the name of test_case matches name_.
+  bool operator()(const TestCase* test_case) const {
+    return test_case != NULL && strcmp(test_case->name(), name_.c_str()) == 0;
+  }
+
+ private:
+  String name_;
+};
+
+// Finds and returns a TestCase with the given name.  If one doesn't
+// exist, creates one and returns it.  It's the CALLER'S
+// RESPONSIBILITY to ensure that this function is only called WHEN THE
+// TESTS ARE NOT SHUFFLED.
+//
+// Arguments:
+//
+//   test_case_name: name of the test case
+//   type_param:     the name of the test case's type parameter, or NULL if
+//                   this is not a typed or a type-parameterized test case.
+//   set_up_tc:      pointer to the function that sets up the test case
+//   tear_down_tc:   pointer to the function that tears down the test case
+TestCase* UnitTestImpl::GetTestCase(const char* test_case_name,
+                                    const char* type_param,
+                                    Test::SetUpTestCaseFunc set_up_tc,
+                                    Test::TearDownTestCaseFunc tear_down_tc) {
+  // Can we find a TestCase with the given name?
+  const std::vector<TestCase*>::const_iterator test_case =
+      std::find_if(test_cases_.begin(), test_cases_.end(),
+                   TestCaseNameIs(test_case_name));
+
+  if (test_case != test_cases_.end())
+    return *test_case;
+
+  // No.  Let's create one.
+  TestCase* const new_test_case =
+      new TestCase(test_case_name, type_param, set_up_tc, tear_down_tc);
+
+  // Is this a death test case?
+  if (internal::UnitTestOptions::MatchesFilter(String(test_case_name),
+                                               kDeathTestCaseFilter)) {
+    // Yes.  Inserts the test case after the last death test case
+    // defined so far.  This only works when the test cases haven't
+    // been shuffled.  Otherwise we may end up running a death test
+    // after a non-death test.
+    ++last_death_test_case_;
+    test_cases_.insert(test_cases_.begin() + last_death_test_case_,
+                       new_test_case);
+  } else {
+    // No.  Appends to the end of the list.
+    test_cases_.push_back(new_test_case);
+  }
+
+  test_case_indices_.push_back(static_cast<int>(test_case_indices_.size()));
+  return new_test_case;
+}
+
+// Helpers for setting up / tearing down the given environment.  They
+// are for use in the ForEach() function.
+static void SetUpEnvironment(Environment* env) { env->SetUp(); }
+static void TearDownEnvironment(Environment* env) { env->TearDown(); }
+
+// Runs all tests in this UnitTest object, prints the result, and
+// returns true if all tests are successful.
If any exception is +// thrown during a test, the test is considered to be failed, but the +// rest of the tests will still be run. +// +// When parameterized tests are enabled, it expands and registers +// parameterized tests first in RegisterParameterizedTests(). +// All other functions called from RunAllTests() may safely assume that +// parameterized tests are ready to be counted and run. +bool UnitTestImpl::RunAllTests() { + // Makes sure InitGoogleTest() was called. + if (!GTestIsInitialized()) { + printf("%s", + "\nThis test program did NOT call ::testing::InitGoogleTest " + "before calling RUN_ALL_TESTS(). Please fix it.\n"); + return false; + } + + // Do not run any test if the --help flag was specified. + if (g_help_flag) + return true; + + // Repeats the call to the post-flag parsing initialization in case the + // user didn't call InitGoogleTest. + PostFlagParsingInit(); + + // Even if sharding is not on, test runners may want to use the + // GTEST_SHARD_STATUS_FILE to query whether the test supports the sharding + // protocol. + internal::WriteToShardStatusFileIfNeeded(); + + // True iff we are in a subprocess for running a thread-safe-style + // death test. + bool in_subprocess_for_death_test = false; + +#if GTEST_HAS_DEATH_TEST + in_subprocess_for_death_test = (internal_run_death_test_flag_.get() != NULL); +#endif // GTEST_HAS_DEATH_TEST + + const bool should_shard = ShouldShard(kTestTotalShards, kTestShardIndex, + in_subprocess_for_death_test); + + // Compares the full test names with the filter to decide which + // tests to run. + const bool has_tests_to_run = FilterTests(should_shard + ? HONOR_SHARDING_PROTOCOL + : IGNORE_SHARDING_PROTOCOL) > 0; + + // Lists the tests and exits if the --gtest_list_tests flag was specified. + if (GTEST_FLAG(list_tests)) { + // This must be called *after* FilterTests() has been called. + ListTestsMatchingFilter(); + return true; + } + + random_seed_ = GTEST_FLAG(shuffle) ? + GetRandomSeedFromFlag(GTEST_FLAG(random_seed)) : 0; + + // True iff at least one test has failed. + bool failed = false; + + TestEventListener* repeater = listeners()->repeater(); + + repeater->OnTestProgramStart(*parent_); + + // How many times to repeat the tests? We don't want to repeat them + // when we are inside the subprocess of a death test. + const int repeat = in_subprocess_for_death_test ? 1 : GTEST_FLAG(repeat); + // Repeats forever if the repeat count is negative. + const bool forever = repeat < 0; + for (int i = 0; forever || i != repeat; i++) { + // We want to preserve failures generated by ad-hoc test + // assertions executed before RUN_ALL_TESTS(). + ClearNonAdHocTestResult(); + + const TimeInMillis start = GetTimeInMillis(); + + // Shuffles test cases and tests if requested. + if (has_tests_to_run && GTEST_FLAG(shuffle)) { + random()->Reseed(random_seed_); + // This should be done before calling OnTestIterationStart(), + // such that a test event listener can see the actual test order + // in the event. + ShuffleTests(); + } + + // Tells the unit test event listeners that the tests are about to start. + repeater->OnTestIterationStart(*parent_, i); + + // Runs each test case if there is at least one test to run. + if (has_tests_to_run) { + // Sets up all environments beforehand. + repeater->OnEnvironmentsSetUpStart(*parent_); + ForEach(environments_, SetUpEnvironment); + repeater->OnEnvironmentsSetUpEnd(*parent_); + + // Runs the tests only if there was no fatal failure during global + // set-up. 
+ if (!Test::HasFatalFailure()) { + for (int test_index = 0; test_index < total_test_case_count(); + test_index++) { + GetMutableTestCase(test_index)->Run(); + } + } + + // Tears down all environments in reverse order afterwards. + repeater->OnEnvironmentsTearDownStart(*parent_); + std::for_each(environments_.rbegin(), environments_.rend(), + TearDownEnvironment); + repeater->OnEnvironmentsTearDownEnd(*parent_); + } + + elapsed_time_ = GetTimeInMillis() - start; + + // Tells the unit test event listener that the tests have just finished. + repeater->OnTestIterationEnd(*parent_, i); + + // Gets the result and clears it. + if (!Passed()) { + failed = true; + } + + // Restores the original test order after the iteration. This + // allows the user to quickly repro a failure that happens in the + // N-th iteration without repeating the first (N - 1) iterations. + // This is not enclosed in "if (GTEST_FLAG(shuffle)) { ... }", in + // case the user somehow changes the value of the flag somewhere + // (it's always safe to unshuffle the tests). + UnshuffleTests(); + + if (GTEST_FLAG(shuffle)) { + // Picks a new random seed for each iteration. + random_seed_ = GetNextRandomSeed(random_seed_); + } + } + + repeater->OnTestProgramEnd(*parent_); + + return !failed; +} + +// Reads the GTEST_SHARD_STATUS_FILE environment variable, and creates the file +// if the variable is present. If a file already exists at this location, this +// function will write over it. If the variable is present, but the file cannot +// be created, prints an error and exits. +void WriteToShardStatusFileIfNeeded() { + const char* const test_shard_file = posix::GetEnv(kTestShardStatusFile); + if (test_shard_file != NULL) { + FILE* const file = posix::FOpen(test_shard_file, "w"); + if (file == NULL) { + ColoredPrintf(COLOR_RED, + "Could not write to the test shard status file \"%s\" " + "specified by the %s environment variable.\n", + test_shard_file, kTestShardStatusFile); + fflush(stdout); + exit(EXIT_FAILURE); + } + fclose(file); + } +} + +// Checks whether sharding is enabled by examining the relevant +// environment variable values. If the variables are present, +// but inconsistent (i.e., shard_index >= total_shards), prints +// an error and exits. If in_subprocess_for_death_test, sharding is +// disabled because it must only be applied to the original test +// process. Otherwise, we could filter out death tests we intended to execute. 
+bool ShouldShard(const char* total_shards_env, + const char* shard_index_env, + bool in_subprocess_for_death_test) { + if (in_subprocess_for_death_test) { + return false; + } + + const Int32 total_shards = Int32FromEnvOrDie(total_shards_env, -1); + const Int32 shard_index = Int32FromEnvOrDie(shard_index_env, -1); + + if (total_shards == -1 && shard_index == -1) { + return false; + } else if (total_shards == -1 && shard_index != -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestShardIndex << " = " << shard_index + << ", but have left " << kTestTotalShards << " unset.\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (total_shards != -1 && shard_index == -1) { + const Message msg = Message() + << "Invalid environment variables: you have " + << kTestTotalShards << " = " << total_shards + << ", but have left " << kTestShardIndex << " unset.\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } else if (shard_index < 0 || shard_index >= total_shards) { + const Message msg = Message() + << "Invalid environment variables: we require 0 <= " + << kTestShardIndex << " < " << kTestTotalShards + << ", but you have " << kTestShardIndex << "=" << shard_index + << ", " << kTestTotalShards << "=" << total_shards << ".\n"; + ColoredPrintf(COLOR_RED, msg.GetString().c_str()); + fflush(stdout); + exit(EXIT_FAILURE); + } + + return total_shards > 1; +} + +// Parses the environment variable var as an Int32. If it is unset, +// returns default_val. If it is not an Int32, prints an error +// and aborts. +Int32 Int32FromEnvOrDie(const char* var, Int32 default_val) { + const char* str_val = posix::GetEnv(var); + if (str_val == NULL) { + return default_val; + } + + Int32 result; + if (!ParseInt32(Message() << "The value of environment variable " << var, + str_val, &result)) { + exit(EXIT_FAILURE); + } + return result; +} + +// Given the total number of shards, the shard index, and the test id, +// returns true iff the test should be run on this shard. The test id is +// some arbitrary but unique non-negative integer assigned to each test +// method. Assumes that 0 <= shard_index < total_shards. +bool ShouldRunTestOnShard(int total_shards, int shard_index, int test_id) { + return (test_id % total_shards) == shard_index; +} + +// Compares the name of each test with the user-specified filter to +// decide whether the test should be run, then records the result in +// each TestCase and TestInfo object. +// If shard_tests == true, further filters tests based on sharding +// variables in the environment - see +// http://code.google.com/p/googletest/wiki/GoogleTestAdvancedGuide. +// Returns the number of tests that should run. +int UnitTestImpl::FilterTests(ReactionToSharding shard_tests) { + const Int32 total_shards = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestTotalShards, -1) : -1; + const Int32 shard_index = shard_tests == HONOR_SHARDING_PROTOCOL ? + Int32FromEnvOrDie(kTestShardIndex, -1) : -1; + + // num_runnable_tests are the number of tests that will + // run across all shards (i.e., match filter and are not disabled). + // num_selected_tests are the number of tests to be run on + // this shard. 
+  int num_runnable_tests = 0;
+  int num_selected_tests = 0;
+  for (size_t i = 0; i < test_cases_.size(); i++) {
+    TestCase* const test_case = test_cases_[i];
+    const String &test_case_name = test_case->name();
+    test_case->set_should_run(false);
+
+    for (size_t j = 0; j < test_case->test_info_list().size(); j++) {
+      TestInfo* const test_info = test_case->test_info_list()[j];
+      const String test_name(test_info->name());
+      // A test is disabled if test case name or test name matches
+      // kDisableTestFilter.
+      const bool is_disabled =
+          internal::UnitTestOptions::MatchesFilter(test_case_name,
+                                                   kDisableTestFilter) ||
+          internal::UnitTestOptions::MatchesFilter(test_name,
+                                                   kDisableTestFilter);
+      test_info->is_disabled_ = is_disabled;
+
+      const bool matches_filter =
+          internal::UnitTestOptions::FilterMatchesTest(test_case_name,
+                                                       test_name);
+      test_info->matches_filter_ = matches_filter;
+
+      const bool is_runnable =
+          (GTEST_FLAG(also_run_disabled_tests) || !is_disabled) &&
+          matches_filter;
+
+      const bool is_selected = is_runnable &&
+          (shard_tests == IGNORE_SHARDING_PROTOCOL ||
+           ShouldRunTestOnShard(total_shards, shard_index,
+                                num_runnable_tests));
+
+      num_runnable_tests += is_runnable;
+      num_selected_tests += is_selected;
+
+      test_info->should_run_ = is_selected;
+      test_case->set_should_run(test_case->should_run() || is_selected);
+    }
+  }
+  return num_selected_tests;
+}
+
+// Prints the names of the tests matching the user-specified filter flag.
+void UnitTestImpl::ListTestsMatchingFilter() {
+  for (size_t i = 0; i < test_cases_.size(); i++) {
+    const TestCase* const test_case = test_cases_[i];
+    bool printed_test_case_name = false;
+
+    for (size_t j = 0; j < test_case->test_info_list().size(); j++) {
+      const TestInfo* const test_info =
+          test_case->test_info_list()[j];
+      if (test_info->matches_filter_) {
+        if (!printed_test_case_name) {
+          printed_test_case_name = true;
+          printf("%s.\n", test_case->name());
+        }
+        printf("  %s\n", test_info->name());
+      }
+    }
+  }
+  fflush(stdout);
+}
+
+// Sets the OS stack trace getter.
+//
+// Does nothing if the input and the current OS stack trace getter are
+// the same; otherwise, deletes the old getter and makes the input the
+// current getter.
+void UnitTestImpl::set_os_stack_trace_getter(
+    OsStackTraceGetterInterface* getter) {
+  if (os_stack_trace_getter_ != getter) {
+    delete os_stack_trace_getter_;
+    os_stack_trace_getter_ = getter;
+  }
+}
+
+// Returns the current OS stack trace getter if it is not NULL;
+// otherwise, creates an OsStackTraceGetter, makes it the current
+// getter, and returns it.
+OsStackTraceGetterInterface* UnitTestImpl::os_stack_trace_getter() {
+  if (os_stack_trace_getter_ == NULL) {
+    os_stack_trace_getter_ = new OsStackTraceGetter;
+  }
+
+  return os_stack_trace_getter_;
+}
+
+// Returns the TestResult for the test that's currently running, or
+// the TestResult for the ad hoc test if no test is running.
+TestResult* UnitTestImpl::current_test_result() {
+  return current_test_info_ ?
+      &(current_test_info_->result_) : &ad_hoc_test_result_;
+}
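+// Editor's note -- an illustrative sketch, not part of the original Google
+// Test source: FilterTests() above combines --gtest_filter with the sharding
+// environment variables read by ShouldShard().  A suite can be split across
+// three machines like this (hypothetical invocations):
+//
+//   GTEST_TOTAL_SHARDS=3 GTEST_SHARD_INDEX=0 ./my_test   # shard 0
+//   GTEST_TOTAL_SHARDS=3 GTEST_SHARD_INDEX=1 ./my_test   # shard 1
+//   GTEST_TOTAL_SHARDS=3 GTEST_SHARD_INDEX=2 ./my_test   # shard 2
+//
+// Each runnable test gets a sequential id, and ShouldRunTestOnShard() keeps
+// it only when id % total_shards == shard_index, so every selected test runs
+// on exactly one shard.  A pattern such as
+// --gtest_filter=FooTest.*-FooTest.Flaky is applied first, before shard
+// assignment.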
+// Shuffles all test cases, and the tests within each test case,
+// making sure that death tests are still run first.
+void UnitTestImpl::ShuffleTests() {
+  // Shuffles the death test cases.
+  ShuffleRange(random(), 0, last_death_test_case_ + 1, &test_case_indices_);
+
+  // Shuffles the non-death test cases.
+  ShuffleRange(random(), last_death_test_case_ + 1,
+               static_cast<int>(test_cases_.size()), &test_case_indices_);
+
+  // Shuffles the tests inside each test case.
+  for (size_t i = 0; i < test_cases_.size(); i++) {
+    test_cases_[i]->ShuffleTests(random());
+  }
+}
+
+// Restores the test cases and tests to their order before the first shuffle.
+void UnitTestImpl::UnshuffleTests() {
+  for (size_t i = 0; i < test_cases_.size(); i++) {
+    // Unshuffles the tests in each test case.
+    test_cases_[i]->UnshuffleTests();
+    // Resets the index of each test case.
+    test_case_indices_[i] = static_cast<int>(i);
+  }
+}
+
+// Returns the current OS stack trace as a String.
+//
+// The maximum number of stack frames to be included is specified by
+// the gtest_stack_trace_depth flag.  The skip_count parameter
+// specifies the number of top frames to be skipped, which doesn't
+// count against the number of frames to be included.
+//
+// For example, if Foo() calls Bar(), which in turn calls
+// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in
+// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't.
+String GetCurrentOsStackTraceExceptTop(UnitTest* /*unit_test*/,
+                                       int skip_count) {
+  // We pass skip_count + 1 to skip this wrapper function in addition
+  // to what the user really wants to skip.
+  return GetUnitTestImpl()->CurrentOsStackTraceExceptTop(skip_count + 1);
+}
+
+// Used by the GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_ macro to
+// suppress unreachable code warnings.
+namespace {
+class ClassUniqueToAlwaysTrue {};
+}
+
+bool IsTrue(bool condition) { return condition; }
+
+bool AlwaysTrue() {
+#if GTEST_HAS_EXCEPTIONS
+  // This condition is always false so AlwaysTrue() never actually throws,
+  // but it makes the compiler think that it may throw.
+  if (IsTrue(false))
+    throw ClassUniqueToAlwaysTrue();
+#endif  // GTEST_HAS_EXCEPTIONS
+  return true;
+}
+
+// If *pstr starts with the given prefix, modifies *pstr to be right
+// past the prefix and returns true; otherwise leaves *pstr unchanged
+// and returns false.  None of pstr, *pstr, and prefix can be NULL.
+bool SkipPrefix(const char* prefix, const char** pstr) {
+  const size_t prefix_len = strlen(prefix);
+  if (strncmp(*pstr, prefix, prefix_len) == 0) {
+    *pstr += prefix_len;
+    return true;
+  }
+  return false;
+}
+
+// Parses a string as a command line flag.  The string should have
+// the format "--flag=value".  When def_optional is true, the "=value"
+// part can be omitted.
+//
+// Returns the value of the flag, or NULL if the parsing failed.
+const char* ParseFlagValue(const char* str,
+                           const char* flag,
+                           bool def_optional) {
+  // str and flag must not be NULL.
+  if (str == NULL || flag == NULL) return NULL;
+
+  // The flag must start with "--" followed by GTEST_FLAG_PREFIX_.
+  const String flag_str = String::Format("--%s%s", GTEST_FLAG_PREFIX_, flag);
+  const size_t flag_len = flag_str.length();
+  if (strncmp(str, flag_str.c_str(), flag_len) != 0) return NULL;
+
+  // Skips the flag name.
+  const char* flag_end = str + flag_len;
+
+  // When def_optional is true, it's OK to not have a "=value" part.
+  if (def_optional && (flag_end[0] == '\0')) {
+    return flag_end;
+  }
+
+  // If def_optional is true and there are more characters after the
+  // flag name, or if def_optional is false, there must be a '=' after
+  // the flag name.
+  if (flag_end[0] != '=') return NULL;
+
+  // Returns the string after "=".
+  return flag_end + 1;
+}
+
+// Parses a string for a bool flag, in the form of either
+// "--flag=value" or "--flag".
+//
+// In the former case, the value is taken as true as long as it does
+// not start with '0', 'f', or 'F'.
+// +// In the latter case, the value is taken as true. +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseBoolFlag(const char* str, const char* flag, bool* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, true); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Converts the string value to a bool. + *value = !(*value_str == '0' || *value_str == 'f' || *value_str == 'F'); + return true; +} + +// Parses a string for an Int32 flag, in the form of +// "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseInt32Flag(const char* str, const char* flag, Int32* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Sets *value to the value of the flag. + return ParseInt32(Message() << "The value of flag --" << flag, + value_str, value); +} + +// Parses a string for a string flag, in the form of +// "--flag=value". +// +// On success, stores the value of the flag in *value, and returns +// true. On failure, returns false without changing *value. +bool ParseStringFlag(const char* str, const char* flag, String* value) { + // Gets the value of the flag as a string. + const char* const value_str = ParseFlagValue(str, flag, false); + + // Aborts if the parsing failed. + if (value_str == NULL) return false; + + // Sets *value to the value of the flag. + *value = value_str; + return true; +} + +// Determines whether a string has a prefix that Google Test uses for its +// flags, i.e., starts with GTEST_FLAG_PREFIX_ or GTEST_FLAG_PREFIX_DASH_. +// If Google Test detects that a command line flag has its prefix but is not +// recognized, it will print its help message. Flags starting with +// GTEST_INTERNAL_PREFIX_ followed by "internal_" are considered Google Test +// internal flags and do not trigger the help message. +static bool HasGoogleTestFlagPrefix(const char* str) { + return (SkipPrefix("--", &str) || + SkipPrefix("-", &str) || + SkipPrefix("/", &str)) && + !SkipPrefix(GTEST_FLAG_PREFIX_ "internal_", &str) && + (SkipPrefix(GTEST_FLAG_PREFIX_, &str) || + SkipPrefix(GTEST_FLAG_PREFIX_DASH_, &str)); +} + +// Prints a string containing code-encoded text. The following escape +// sequences can be used in the string to control the text color: +// +// @@ prints a single '@' character. +// @R changes the color to red. +// @G changes the color to green. +// @Y changes the color to yellow. +// @D changes to the default terminal text color. +// +// TODO(wan@google.com): Write tests for this once we add stdout +// capturing to Google Test. +static void PrintColorEncoded(const char* str) { + GTestColor color = COLOR_DEFAULT; // The current color. + + // Conceptually, we split the string into segments divided by escape + // sequences. Then we print one segment at a time. At the end of + // each iteration, the str pointer advances to the beginning of the + // next segment. 
+ for (;;) { + const char* p = strchr(str, '@'); + if (p == NULL) { + ColoredPrintf(color, "%s", str); + return; + } + + ColoredPrintf(color, "%s", String(str, p - str).c_str()); + + const char ch = p[1]; + str = p + 2; + if (ch == '@') { + ColoredPrintf(color, "@"); + } else if (ch == 'D') { + color = COLOR_DEFAULT; + } else if (ch == 'R') { + color = COLOR_RED; + } else if (ch == 'G') { + color = COLOR_GREEN; + } else if (ch == 'Y') { + color = COLOR_YELLOW; + } else { + --str; + } + } +} + +static const char kColorEncodedHelpMessage[] = +"This program contains tests written using " GTEST_NAME_ ". You can use the\n" +"following command line flags to control its behavior:\n" +"\n" +"Test Selection:\n" +" @G--" GTEST_FLAG_PREFIX_ "list_tests@D\n" +" List the names of all tests instead of running them. The name of\n" +" TEST(Foo, Bar) is \"Foo.Bar\".\n" +" @G--" GTEST_FLAG_PREFIX_ "filter=@YPOSTIVE_PATTERNS" + "[@G-@YNEGATIVE_PATTERNS]@D\n" +" Run only the tests whose name matches one of the positive patterns but\n" +" none of the negative patterns. '?' matches any single character; '*'\n" +" matches any substring; ':' separates two patterns.\n" +" @G--" GTEST_FLAG_PREFIX_ "also_run_disabled_tests@D\n" +" Run all disabled tests too.\n" +"\n" +"Test Execution:\n" +" @G--" GTEST_FLAG_PREFIX_ "repeat=@Y[COUNT]@D\n" +" Run the tests repeatedly; use a negative count to repeat forever.\n" +" @G--" GTEST_FLAG_PREFIX_ "shuffle@D\n" +" Randomize tests' orders on every iteration.\n" +" @G--" GTEST_FLAG_PREFIX_ "random_seed=@Y[NUMBER]@D\n" +" Random number seed to use for shuffling test orders (between 1 and\n" +" 99999, or 0 to use a seed based on the current time).\n" +"\n" +"Test Output:\n" +" @G--" GTEST_FLAG_PREFIX_ "color=@Y(@Gyes@Y|@Gno@Y|@Gauto@Y)@D\n" +" Enable/disable colored output. The default is @Gauto@D.\n" +" -@G-" GTEST_FLAG_PREFIX_ "print_time=0@D\n" +" Don't print the elapsed time of each test.\n" +" @G--" GTEST_FLAG_PREFIX_ "output=xml@Y[@G:@YDIRECTORY_PATH@G" + GTEST_PATH_SEP_ "@Y|@G:@YFILE_PATH]@D\n" +" Generate an XML report in the given directory or with the given file\n" +" name. @YFILE_PATH@D defaults to @Gtest_details.xml@D.\n" +#if GTEST_CAN_STREAM_RESULTS_ +" @G--" GTEST_FLAG_PREFIX_ "stream_result_to=@YHOST@G:@YPORT@D\n" +" Stream test results to the given server.\n" +#endif // GTEST_CAN_STREAM_RESULTS_ +"\n" +"Assertion Behavior:\n" +#if GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS +" @G--" GTEST_FLAG_PREFIX_ "death_test_style=@Y(@Gfast@Y|@Gthreadsafe@Y)@D\n" +" Set the default death test style.\n" +#endif // GTEST_HAS_DEATH_TEST && !GTEST_OS_WINDOWS +" @G--" GTEST_FLAG_PREFIX_ "break_on_failure@D\n" +" Turn assertion failures into debugger break-points.\n" +" @G--" GTEST_FLAG_PREFIX_ "throw_on_failure@D\n" +" Turn assertion failures into C++ exceptions.\n" +" @G--" GTEST_FLAG_PREFIX_ "catch_exceptions=0@D\n" +" Do not report exceptions as test failures. Instead, allow them\n" +" to crash the program or throw a pop-up (on Windows).\n" +"\n" +"Except for @G--" GTEST_FLAG_PREFIX_ "list_tests@D, you can alternatively set " + "the corresponding\n" +"environment variable of a flag (all letters in upper-case). For example, to\n" +"disable colored text output, you can either specify @G--" GTEST_FLAG_PREFIX_ + "color=no@D or set\n" +"the @G" GTEST_FLAG_PREFIX_UPPER_ "COLOR@D environment variable to @Gno@D.\n" +"\n" +"For more information, please read the " GTEST_NAME_ " documentation at\n" +"@G" GTEST_PROJECT_URL_ "@D. 
If you find a bug in " GTEST_NAME_ "\n"
+"(not one in your own code or tests), please report it to\n"
+"@G<" GTEST_DEV_EMAIL_ ">@D.\n";
+
+// Parses the command line for Google Test flags, without initializing
+// other parts of Google Test.  The type parameter CharType can be
+// instantiated to either char or wchar_t.
+template <typename CharType>
+void ParseGoogleTestFlagsOnlyImpl(int* argc, CharType** argv) {
+  for (int i = 1; i < *argc; i++) {
+    const String arg_string = StreamableToString(argv[i]);
+    const char* const arg = arg_string.c_str();
+
+    using internal::ParseBoolFlag;
+    using internal::ParseInt32Flag;
+    using internal::ParseStringFlag;
+
+    // Do we see a Google Test flag?
+    if (ParseBoolFlag(arg, kAlsoRunDisabledTestsFlag,
+                      &GTEST_FLAG(also_run_disabled_tests)) ||
+        ParseBoolFlag(arg, kBreakOnFailureFlag,
+                      &GTEST_FLAG(break_on_failure)) ||
+        ParseBoolFlag(arg, kCatchExceptionsFlag,
+                      &GTEST_FLAG(catch_exceptions)) ||
+        ParseStringFlag(arg, kColorFlag, &GTEST_FLAG(color)) ||
+        ParseStringFlag(arg, kDeathTestStyleFlag,
+                        &GTEST_FLAG(death_test_style)) ||
+        ParseBoolFlag(arg, kDeathTestUseFork,
+                      &GTEST_FLAG(death_test_use_fork)) ||
+        ParseStringFlag(arg, kFilterFlag, &GTEST_FLAG(filter)) ||
+        ParseStringFlag(arg, kInternalRunDeathTestFlag,
+                        &GTEST_FLAG(internal_run_death_test)) ||
+        ParseBoolFlag(arg, kListTestsFlag, &GTEST_FLAG(list_tests)) ||
+        ParseStringFlag(arg, kOutputFlag, &GTEST_FLAG(output)) ||
+        ParseBoolFlag(arg, kPrintTimeFlag, &GTEST_FLAG(print_time)) ||
+        ParseInt32Flag(arg, kRandomSeedFlag, &GTEST_FLAG(random_seed)) ||
+        ParseInt32Flag(arg, kRepeatFlag, &GTEST_FLAG(repeat)) ||
+        ParseBoolFlag(arg, kShuffleFlag, &GTEST_FLAG(shuffle)) ||
+        ParseInt32Flag(arg, kStackTraceDepthFlag,
+                       &GTEST_FLAG(stack_trace_depth)) ||
+        ParseStringFlag(arg, kStreamResultToFlag,
+                        &GTEST_FLAG(stream_result_to)) ||
+        ParseBoolFlag(arg, kThrowOnFailureFlag,
+                      &GTEST_FLAG(throw_on_failure))
+        ) {
+      // Yes.  Shift the remainder of the argv list left by one.  Note
+      // that argv has (*argc + 1) elements, the last one always being
+      // NULL.  The following loop moves the trailing NULL element as
+      // well.
+      for (int j = i; j != *argc; j++) {
+        argv[j] = argv[j + 1];
+      }
+
+      // Decrements the argument count.
+      (*argc)--;
+
+      // We also need to decrement the iterator as we just removed
+      // an element.
+      i--;
+    } else if (arg_string == "--help" || arg_string == "-h" ||
+               arg_string == "-?" || arg_string == "/?" ||
+               HasGoogleTestFlagPrefix(arg)) {
+      // Both help flag and unrecognized Google Test flags (excluding
+      // internal ones) trigger help display.
+      g_help_flag = true;
+    }
+  }
+
+  if (g_help_flag) {
+    // We print the help here instead of in RUN_ALL_TESTS(), as the
+    // latter may not be called at all if the user is using Google
+    // Test with another testing framework.
+    PrintColorEncoded(kColorEncodedHelpMessage);
+  }
+}
+
+// Parses the command line for Google Test flags, without initializing
+// other parts of Google Test.
+void ParseGoogleTestFlagsOnly(int* argc, char** argv) {
+  ParseGoogleTestFlagsOnlyImpl<char>(argc, argv);
+}
+void ParseGoogleTestFlagsOnly(int* argc, wchar_t** argv) {
+  ParseGoogleTestFlagsOnlyImpl<wchar_t>(argc, argv);
+}
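+// Editor's note -- an illustrative sketch, not part of the original Google
+// Test source: ParseGoogleTestFlagsOnly() consumes the flags it recognizes
+// and shifts argv down, so a wrapper can parse its own flags afterwards.
+// For a hypothetical invocation
+//
+//   my_test --gtest_repeat=2 --verbose --gtest_shuffle
+//
+// the call leaves argc == 2 and argv == {"my_test", "--verbose", NULL}, with
+// GTEST_FLAG(repeat) == 2 and GTEST_FLAG(shuffle) == true.  Unrecognized
+// flags that still carry the --gtest_ prefix (and --help, -h, -?, /?) set
+// g_help_flag, which prints the help text above and later makes
+// RunAllTests() return without running any tests.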
+// The internal implementation of InitGoogleTest().
+//
+// The type parameter CharType can be instantiated to either char or
+// wchar_t.
+template <typename CharType>
+void InitGoogleTestImpl(int* argc, CharType** argv) {
+  g_init_gtest_count++;
+
+  // We don't want to run the initialization code twice.
+  if (g_init_gtest_count != 1) return;
+
+  if (*argc <= 0) return;
+
+  internal::g_executable_path = internal::StreamableToString(argv[0]);
+
+#if GTEST_HAS_DEATH_TEST
+
+  g_argvs.clear();
+  for (int i = 0; i != *argc; i++) {
+    g_argvs.push_back(StreamableToString(argv[i]));
+  }
+
+#endif  // GTEST_HAS_DEATH_TEST
+
+  ParseGoogleTestFlagsOnly(argc, argv);
+  GetUnitTestImpl()->PostFlagParsingInit();
+}
+
+}  // namespace internal
+
+// Initializes Google Test.  This must be called before calling
+// RUN_ALL_TESTS().  In particular, it parses a command line for the
+// flags that Google Test recognizes.  Whenever a Google Test flag is
+// seen, it is removed from argv, and *argc is decremented.
+//
+// No value is returned.  Instead, the Google Test flag variables are
+// updated.
+//
+// Calling the function for the second time has no user-visible effect.
+void InitGoogleTest(int* argc, char** argv) {
+  internal::InitGoogleTestImpl<char>(argc, argv);
+}
+
+// This overloaded version can be used in Windows programs compiled in
+// UNICODE mode.
+void InitGoogleTest(int* argc, wchar_t** argv) {
+  internal::InitGoogleTestImpl<wchar_t>(argc, argv);
+}
+
+}  // namespace testing
+// Copyright 2005, Google Inc.
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: wan@google.com (Zhanyong Wan), vladl@google.com (Vlad Losev)
+//
+// This file implements death tests.
+
+
+#if GTEST_HAS_DEATH_TEST
+
+# if GTEST_OS_MAC
+#  include <crt_externs.h>
+# endif  // GTEST_OS_MAC
+
+# include <errno.h>
+# include <fcntl.h>
+# include <limits.h>
+# include <stdarg.h>
+
+# if GTEST_OS_WINDOWS
+#  include <windows.h>
+# else
+#  include <sys/mman.h>
+#  include <sys/wait.h>
+# endif  // GTEST_OS_WINDOWS
+
+#endif  // GTEST_HAS_DEATH_TEST
+
+// Indicates that this translation unit is part of Google Test's
+// implementation.  It must come before gtest-internal-inl.h is
+// included, or there will be a compiler error.  This trick is to
+// prevent a user from accidentally including gtest-internal-inl.h in
+// his code.
+#define GTEST_IMPLEMENTATION_ 1
+#undef GTEST_IMPLEMENTATION_
+
+namespace testing {
+
+// Constants.
+
+// The default death test style.
+static const char kDefaultDeathTestStyle[] = "fast"; + +GTEST_DEFINE_string_( + death_test_style, + internal::StringFromGTestEnv("death_test_style", kDefaultDeathTestStyle), + "Indicates how to run a death test in a forked child process: " + "\"threadsafe\" (child process re-executes the test binary " + "from the beginning, running only the specific death test) or " + "\"fast\" (child process runs the death test immediately " + "after forking)."); + +GTEST_DEFINE_bool_( + death_test_use_fork, + internal::BoolFromGTestEnv("death_test_use_fork", false), + "Instructs to use fork()/_exit() instead of clone() in death tests. " + "Ignored and always uses fork() on POSIX systems where clone() is not " + "implemented. Useful when running under valgrind or similar tools if " + "those do not support clone(). Valgrind 3.3.1 will just fail if " + "it sees an unsupported combination of clone() flags. " + "It is not recommended to use this flag w/o valgrind though it will " + "work in 99% of the cases. Once valgrind is fixed, this flag will " + "most likely be removed."); + +namespace internal { +GTEST_DEFINE_string_( + internal_run_death_test, "", + "Indicates the file, line number, temporal index of " + "the single death test to run, and a file descriptor to " + "which a success code may be sent, all separated by " + "colons. This flag is specified if and only if the current " + "process is a sub-process launched for running a thread-safe " + "death test. FOR INTERNAL USE ONLY."); +} // namespace internal + +#if GTEST_HAS_DEATH_TEST + +// ExitedWithCode constructor. +ExitedWithCode::ExitedWithCode(int exit_code) : exit_code_(exit_code) { +} + +// ExitedWithCode function-call operator. +bool ExitedWithCode::operator()(int exit_status) const { +# if GTEST_OS_WINDOWS + + return exit_status == exit_code_; + +# else + + return WIFEXITED(exit_status) && WEXITSTATUS(exit_status) == exit_code_; + +# endif // GTEST_OS_WINDOWS +} + +# if !GTEST_OS_WINDOWS +// KilledBySignal constructor. +KilledBySignal::KilledBySignal(int signum) : signum_(signum) { +} + +// KilledBySignal function-call operator. +bool KilledBySignal::operator()(int exit_status) const { + return WIFSIGNALED(exit_status) && WTERMSIG(exit_status) == signum_; +} +# endif // !GTEST_OS_WINDOWS + +namespace internal { + +// Utilities needed for death tests. + +// Generates a textual description of a given exit code, in the format +// specified by wait(2). +static String ExitSummary(int exit_code) { + Message m; + +# if GTEST_OS_WINDOWS + + m << "Exited with exit status " << exit_code; + +# else + + if (WIFEXITED(exit_code)) { + m << "Exited with exit status " << WEXITSTATUS(exit_code); + } else if (WIFSIGNALED(exit_code)) { + m << "Terminated by signal " << WTERMSIG(exit_code); + } +# ifdef WCOREDUMP + if (WCOREDUMP(exit_code)) { + m << " (core dumped)"; + } +# endif +# endif // GTEST_OS_WINDOWS + + return m.GetString(); +} + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +bool ExitedUnsuccessfully(int exit_status) { + return !ExitedWithCode(0)(exit_status); +} + +# if !GTEST_OS_WINDOWS +// Generates a textual failure message when a death test finds more than +// one thread running, or cannot determine the number of threads, prior +// to executing the given statement. It is the responsibility of the +// caller not to pass a thread_count of 1. 
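// The ExitedWithCode and KilledBySignal predicates defined above are meant
// for use with EXPECT_EXIT; a short sketch (NormalShutdown() and Segfault()
// are hypothetical functions, and KilledBySignal is POSIX-only):
//
//   EXPECT_EXIT(NormalShutdown(), testing::ExitedWithCode(0), "bye");
//   EXPECT_EXIT(Segfault(), testing::KilledBySignal(SIGSEGV), ".*");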
+static String DeathTestThreadWarning(size_t thread_count) { + Message msg; + msg << "Death tests use fork(), which is unsafe particularly" + << " in a threaded context. For this test, " << GTEST_NAME_ << " "; + if (thread_count == 0) + msg << "couldn't detect the number of threads."; + else + msg << "detected " << thread_count << " threads."; + return msg.GetString(); +} +# endif // !GTEST_OS_WINDOWS + +// Flag characters for reporting a death test that did not die. +static const char kDeathTestLived = 'L'; +static const char kDeathTestReturned = 'R'; +static const char kDeathTestThrew = 'T'; +static const char kDeathTestInternalError = 'I'; + +// An enumeration describing all of the possible ways that a death test can +// conclude. DIED means that the process died while executing the test +// code; LIVED means that process lived beyond the end of the test code; +// RETURNED means that the test statement attempted to execute a return +// statement, which is not allowed; THREW means that the test statement +// returned control by throwing an exception. IN_PROGRESS means the test +// has not yet concluded. +// TODO(vladl@google.com): Unify names and possibly values for +// AbortReason, DeathTestOutcome, and flag characters above. +enum DeathTestOutcome { IN_PROGRESS, DIED, LIVED, RETURNED, THREW }; + +// Routine for aborting the program which is safe to call from an +// exec-style death test child process, in which case the error +// message is propagated back to the parent process. Otherwise, the +// message is simply printed to stderr. In either case, the program +// then exits with status 1. +void DeathTestAbort(const String& message) { + // On a POSIX system, this function may be called from a threadsafe-style + // death test child process, which operates on a very small stack. Use + // the heap for any additional non-minuscule memory requirements. + const InternalRunDeathTestFlag* const flag = + GetUnitTestImpl()->internal_run_death_test_flag(); + if (flag != NULL) { + FILE* parent = posix::FDOpen(flag->write_fd(), "w"); + fputc(kDeathTestInternalError, parent); + fprintf(parent, "%s", message.c_str()); + fflush(parent); + _exit(1); + } else { + fprintf(stderr, "%s", message.c_str()); + fflush(stderr); + posix::Abort(); + } +} + +// A replacement for CHECK that calls DeathTestAbort if the assertion +// fails. +# define GTEST_DEATH_TEST_CHECK_(expression) \ + do { \ + if (!::testing::internal::IsTrue(expression)) { \ + DeathTestAbort(::testing::internal::String::Format( \ + "CHECK failed: File %s, line %d: %s", \ + __FILE__, __LINE__, #expression)); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// This macro is similar to GTEST_DEATH_TEST_CHECK_, but it is meant for +// evaluating any system call that fulfills two conditions: it must return +// -1 on failure, and set errno to EINTR when it is interrupted and +// should be tried again. The macro expands to a loop that repeatedly +// evaluates the expression as long as it evaluates to -1 and sets +// errno to EINTR. If the expression evaluates to -1 but errno is +// something other than EINTR, DeathTestAbort is called. 
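// A standalone sketch of the EINTR-retry idiom that the macro below expands
// to (plain POSIX, independent of Google Test; RetryWrite is a hypothetical
// helper):
//
//   #include <errno.h>
//   #include <unistd.h>
//
//   ssize_t RetryWrite(int fd, const void* buf, size_t n) {
//     ssize_t ret;
//     do {
//       ret = write(fd, buf, n);     // may fail with EINTR on a signal
//     } while (ret == -1 && errno == EINTR);
//     return ret;                    // -1 here is a genuine error
//   }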
+# define GTEST_DEATH_TEST_CHECK_SYSCALL_(expression) \ + do { \ + int gtest_retval; \ + do { \ + gtest_retval = (expression); \ + } while (gtest_retval == -1 && errno == EINTR); \ + if (gtest_retval == -1) { \ + DeathTestAbort(::testing::internal::String::Format( \ + "CHECK failed: File %s, line %d: %s != -1", \ + __FILE__, __LINE__, #expression)); \ + } \ + } while (::testing::internal::AlwaysFalse()) + +// Returns the message describing the last system error in errno. +String GetLastErrnoDescription() { + return String(errno == 0 ? "" : posix::StrError(errno)); +} + +// This is called from a death test parent process to read a failure +// message from the death test child process and log it with the FATAL +// severity. On Windows, the message is read from a pipe handle. On other +// platforms, it is read from a file descriptor. +static void FailFromInternalError(int fd) { + Message error; + char buffer[256]; + int num_read; + + do { + while ((num_read = posix::Read(fd, buffer, 255)) > 0) { + buffer[num_read] = '\0'; + error << buffer; + } + } while (num_read == -1 && errno == EINTR); + + if (num_read == 0) { + GTEST_LOG_(FATAL) << error.GetString(); + } else { + const int last_error = errno; + GTEST_LOG_(FATAL) << "Error while reading death test internal: " + << GetLastErrnoDescription() << " [" << last_error << "]"; + } +} + +// Death test constructor. Increments the running death test count +// for the current test. +DeathTest::DeathTest() { + TestInfo* const info = GetUnitTestImpl()->current_test_info(); + if (info == NULL) { + DeathTestAbort("Cannot run a death test outside of a TEST or " + "TEST_F construct"); + } +} + +// Creates and returns a death test by dispatching to the current +// death test factory. +bool DeathTest::Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test) { + return GetUnitTestImpl()->death_test_factory()->Create( + statement, regex, file, line, test); +} + +const char* DeathTest::LastMessage() { + return last_death_test_message_.c_str(); +} + +void DeathTest::set_last_death_test_message(const String& message) { + last_death_test_message_ = message; +} + +String DeathTest::last_death_test_message_; + +// Provides cross platform implementation for some death functionality. +class DeathTestImpl : public DeathTest { + protected: + DeathTestImpl(const char* a_statement, const RE* a_regex) + : statement_(a_statement), + regex_(a_regex), + spawned_(false), + status_(-1), + outcome_(IN_PROGRESS), + read_fd_(-1), + write_fd_(-1) {} + + // read_fd_ is expected to be closed and cleared by a derived class. + ~DeathTestImpl() { GTEST_DEATH_TEST_CHECK_(read_fd_ == -1); } + + void Abort(AbortReason reason); + virtual bool Passed(bool status_ok); + + const char* statement() const { return statement_; } + const RE* regex() const { return regex_; } + bool spawned() const { return spawned_; } + void set_spawned(bool is_spawned) { spawned_ = is_spawned; } + int status() const { return status_; } + void set_status(int a_status) { status_ = a_status; } + DeathTestOutcome outcome() const { return outcome_; } + void set_outcome(DeathTestOutcome an_outcome) { outcome_ = an_outcome; } + int read_fd() const { return read_fd_; } + void set_read_fd(int fd) { read_fd_ = fd; } + int write_fd() const { return write_fd_; } + void set_write_fd(int fd) { write_fd_ = fd; } + + // Called in the parent process only. Reads the result code of the death + // test child process via a pipe, interprets it to set the outcome_ + // member, and closes read_fd_. 
Outputs diagnostics and terminates in + // case of unexpected codes. + void ReadAndInterpretStatusByte(); + + private: + // The textual content of the code this object is testing. This class + // doesn't own this string and should not attempt to delete it. + const char* const statement_; + // The regular expression which test output must match. DeathTestImpl + // doesn't own this object and should not attempt to delete it. + const RE* const regex_; + // True if the death test child process has been successfully spawned. + bool spawned_; + // The exit status of the child process. + int status_; + // How the death test concluded. + DeathTestOutcome outcome_; + // Descriptor to the read end of the pipe to the child process. It is + // always -1 in the child process. The child keeps its write end of the + // pipe in write_fd_. + int read_fd_; + // Descriptor to the child's write end of the pipe to the parent process. + // It is always -1 in the parent process. The parent keeps its end of the + // pipe in read_fd_. + int write_fd_; +}; + +// Called in the parent process only. Reads the result code of the death +// test child process via a pipe, interprets it to set the outcome_ +// member, and closes read_fd_. Outputs diagnostics and terminates in +// case of unexpected codes. +void DeathTestImpl::ReadAndInterpretStatusByte() { + char flag; + int bytes_read; + + // The read() here blocks until data is available (signifying the + // failure of the death test) or until the pipe is closed (signifying + // its success), so it's okay to call this in the parent before + // the child process has exited. + do { + bytes_read = posix::Read(read_fd(), &flag, 1); + } while (bytes_read == -1 && errno == EINTR); + + if (bytes_read == 0) { + set_outcome(DIED); + } else if (bytes_read == 1) { + switch (flag) { + case kDeathTestReturned: + set_outcome(RETURNED); + break; + case kDeathTestThrew: + set_outcome(THREW); + break; + case kDeathTestLived: + set_outcome(LIVED); + break; + case kDeathTestInternalError: + FailFromInternalError(read_fd()); // Does not return. + break; + default: + GTEST_LOG_(FATAL) << "Death test child process reported " + << "unexpected status byte (" + << static_cast(flag) << ")"; + } + } else { + GTEST_LOG_(FATAL) << "Read from death test child process failed: " + << GetLastErrnoDescription(); + } + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Close(read_fd())); + set_read_fd(-1); +} + +// Signals that the death test code which should have exited, didn't. +// Should be called only in a death test child process. +// Writes a status byte to the child's status file descriptor, then +// calls _exit(1). +void DeathTestImpl::Abort(AbortReason reason) { + // The parent process considers the death test to be a failure if + // it finds any data in our pipe. So, here we write a single flag byte + // to the pipe, then exit. + const char status_ch = + reason == TEST_DID_NOT_DIE ? kDeathTestLived : + reason == TEST_THREW_EXCEPTION ? kDeathTestThrew : kDeathTestReturned; + + GTEST_DEATH_TEST_CHECK_SYSCALL_(posix::Write(write_fd(), &status_ch, 1)); + // We are leaking the descriptor here because on some platforms (i.e., + // when built as Windows DLL), destructors of global objects will still + // run after calling _exit(). On such systems, write_fd_ will be + // indirectly closed from the destructor of UnitTestImpl, causing double + // close if it is also closed here. On debug configurations, double close + // may assert. 
As there are no in-process buffers to flush here, we are + // relying on the OS to close the descriptor after the process terminates + // when the destructors are not run. + _exit(1); // Exits w/o any normal exit hooks (we were supposed to crash) +} + +// Returns an indented copy of stderr output for a death test. +// This makes distinguishing death test output lines from regular log lines +// much easier. +static ::std::string FormatDeathTestOutput(const ::std::string& output) { + ::std::string ret; + for (size_t at = 0; ; ) { + const size_t line_end = output.find('\n', at); + ret += "[ DEATH ] "; + if (line_end == ::std::string::npos) { + ret += output.substr(at); + break; + } + ret += output.substr(at, line_end + 1 - at); + at = line_end + 1; + } + return ret; +} + +// Assesses the success or failure of a death test, using both private +// members which have previously been set, and one argument: +// +// Private data members: +// outcome: An enumeration describing how the death test +// concluded: DIED, LIVED, THREW, or RETURNED. The death test +// fails in the latter three cases. +// status: The exit status of the child process. On *nix, it is in the +// in the format specified by wait(2). On Windows, this is the +// value supplied to the ExitProcess() API or a numeric code +// of the exception that terminated the program. +// regex: A regular expression object to be applied to +// the test's captured standard error output; the death test +// fails if it does not match. +// +// Argument: +// status_ok: true if exit_status is acceptable in the context of +// this particular death test, which fails if it is false +// +// Returns true iff all of the above conditions are met. Otherwise, the +// first failing condition, in the order given above, is the one that is +// reported. Also sets the last death test message string. +bool DeathTestImpl::Passed(bool status_ok) { + if (!spawned()) + return false; + + const String error_message = GetCapturedStderr(); + + bool success = false; + Message buffer; + + buffer << "Death test: " << statement() << "\n"; + switch (outcome()) { + case LIVED: + buffer << " Result: failed to die.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case THREW: + buffer << " Result: threw an exception.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case RETURNED: + buffer << " Result: illegal return in test statement.\n" + << " Error msg:\n" << FormatDeathTestOutput(error_message); + break; + case DIED: + if (status_ok) { + const bool matched = RE::PartialMatch(error_message.c_str(), *regex()); + if (matched) { + success = true; + } else { + buffer << " Result: died but not with expected error.\n" + << " Expected: " << regex()->pattern() << "\n" + << "Actual msg:\n" << FormatDeathTestOutput(error_message); + } + } else { + buffer << " Result: died but not with expected exit code:\n" + << " " << ExitSummary(status()) << "\n" + << "Actual msg:\n" << FormatDeathTestOutput(error_message); + } + break; + case IN_PROGRESS: + default: + GTEST_LOG_(FATAL) + << "DeathTest::Passed somehow called before conclusion of test"; + } + + DeathTest::set_last_death_test_message(buffer.GetString()); + return success; +} + +# if GTEST_OS_WINDOWS +// WindowsDeathTest implements death tests on Windows. 
Due to the +// specifics of starting new processes on Windows, death tests there are +// always threadsafe, and Google Test considers the +// --gtest_death_test_style=fast setting to be equivalent to +// --gtest_death_test_style=threadsafe there. +// +// A few implementation notes: Like the Linux version, the Windows +// implementation uses pipes for child-to-parent communication. But due to +// the specifics of pipes on Windows, some extra steps are required: +// +// 1. The parent creates a communication pipe and stores handles to both +// ends of it. +// 2. The parent starts the child and provides it with the information +// necessary to acquire the handle to the write end of the pipe. +// 3. The child acquires the write end of the pipe and signals the parent +// using a Windows event. +// 4. Now the parent can release the write end of the pipe on its side. If +// this is done before step 3, the object's reference count goes down to +// 0 and it is destroyed, preventing the child from acquiring it. The +// parent now has to release it, or read operations on the read end of +// the pipe will not return when the child terminates. +// 5. The parent reads child's output through the pipe (outcome code and +// any possible error messages) from the pipe, and its stderr and then +// determines whether to fail the test. +// +// Note: to distinguish Win32 API calls from the local method and function +// calls, the former are explicitly resolved in the global namespace. +// +class WindowsDeathTest : public DeathTestImpl { + public: + WindowsDeathTest(const char* a_statement, + const RE* a_regex, + const char* file, + int line) + : DeathTestImpl(a_statement, a_regex), file_(file), line_(line) {} + + // All of these virtual functions are inherited from DeathTest. + virtual int Wait(); + virtual TestRole AssumeRole(); + + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; + // Handle to the write end of the pipe to the child process. + AutoHandle write_handle_; + // Child process handle. + AutoHandle child_handle_; + // Event the child process uses to signal the parent that it has + // acquired the handle to the write end of the pipe. After seeing this + // event the parent can release its own handles to make sure its + // ReadFile() calls return when the child terminates. + AutoHandle event_handle_; +}; + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int WindowsDeathTest::Wait() { + if (!spawned()) + return 0; + + // Wait until the child either signals that it has acquired the write end + // of the pipe or it dies. + const HANDLE wait_handles[2] = { child_handle_.Get(), event_handle_.Get() }; + switch (::WaitForMultipleObjects(2, + wait_handles, + FALSE, // Waits for any of the handles. + INFINITE)) { + case WAIT_OBJECT_0: + case WAIT_OBJECT_0 + 1: + break; + default: + GTEST_DEATH_TEST_CHECK_(false); // Should not get here. + } + + // The child has acquired the write end of the pipe or exited. + // We release the handle on our side and continue. + write_handle_.Reset(); + event_handle_.Reset(); + + ReadAndInterpretStatusByte(); + + // Waits for the child process to exit if it haven't already. This + // returns immediately if the child has already exited, regardless of + // whether previous calls to WaitForMultipleObjects synchronized on this + // handle or not. 
+ GTEST_DEATH_TEST_CHECK_( + WAIT_OBJECT_0 == ::WaitForSingleObject(child_handle_.Get(), + INFINITE)); + DWORD status_code; + GTEST_DEATH_TEST_CHECK_( + ::GetExitCodeProcess(child_handle_.Get(), &status_code) != FALSE); + child_handle_.Reset(); + set_status(static_cast(status_code)); + return status(); +} + +// The AssumeRole process for a Windows death test. It creates a child +// process with the same executable as the current process to run the +// death test. The child process is given the --gtest_filter and +// --gtest_internal_run_death_test flags such that it knows to run the +// current death test only. +DeathTest::TestRole WindowsDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != NULL) { + // ParseInternalRunDeathTestFlag() has performed all the necessary + // processing. + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + // WindowsDeathTest uses an anonymous pipe to communicate results of + // a death test. + SECURITY_ATTRIBUTES handles_are_inheritable = { + sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + HANDLE read_handle, write_handle; + GTEST_DEATH_TEST_CHECK_( + ::CreatePipe(&read_handle, &write_handle, &handles_are_inheritable, + 0) // Default buffer size. + != FALSE); + set_read_fd(::_open_osfhandle(reinterpret_cast(read_handle), + O_RDONLY)); + write_handle_.Reset(write_handle); + event_handle_.Reset(::CreateEvent( + &handles_are_inheritable, + TRUE, // The event will automatically reset to non-signaled state. + FALSE, // The initial state is non-signalled. + NULL)); // The even is unnamed. + GTEST_DEATH_TEST_CHECK_(event_handle_.Get() != NULL); + const String filter_flag = String::Format("--%s%s=%s.%s", + GTEST_FLAG_PREFIX_, kFilterFlag, + info->test_case_name(), + info->name()); + const String internal_flag = String::Format( + "--%s%s=%s|%d|%d|%u|%Iu|%Iu", + GTEST_FLAG_PREFIX_, + kInternalRunDeathTestFlag, + file_, line_, + death_test_index, + static_cast(::GetCurrentProcessId()), + // size_t has the same with as pointers on both 32-bit and 64-bit + // Windows platforms. + // See http://msdn.microsoft.com/en-us/library/tcxf1dw6.aspx. + reinterpret_cast(write_handle), + reinterpret_cast(event_handle_.Get())); + + char executable_path[_MAX_PATH + 1]; // NOLINT + GTEST_DEATH_TEST_CHECK_( + _MAX_PATH + 1 != ::GetModuleFileNameA(NULL, + executable_path, + _MAX_PATH)); + + String command_line = String::Format("%s %s \"%s\"", + ::GetCommandLineA(), + filter_flag.c_str(), + internal_flag.c_str()); + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // Flush the log buffers since the log streams are shared with the child. + FlushInfoLog(); + + // The child process will share the standard handles with the parent. + STARTUPINFOA startup_info; + memset(&startup_info, 0, sizeof(STARTUPINFO)); + startup_info.dwFlags = STARTF_USESTDHANDLES; + startup_info.hStdInput = ::GetStdHandle(STD_INPUT_HANDLE); + startup_info.hStdOutput = ::GetStdHandle(STD_OUTPUT_HANDLE); + startup_info.hStdError = ::GetStdHandle(STD_ERROR_HANDLE); + + PROCESS_INFORMATION process_info; + GTEST_DEATH_TEST_CHECK_(::CreateProcessA( + executable_path, + const_cast(command_line.c_str()), + NULL, // Retuned process handle is not inheritable. + NULL, // Retuned thread handle is not inheritable. 
+ TRUE, // Child inherits all inheritable handles (for write_handle_). + 0x0, // Default creation flags. + NULL, // Inherit the parent's environment. + UnitTest::GetInstance()->original_working_dir(), + &startup_info, + &process_info) != FALSE); + child_handle_.Reset(process_info.hProcess); + ::CloseHandle(process_info.hThread); + set_spawned(true); + return OVERSEE_TEST; +} +# else // We are not on Windows. + +// ForkingDeathTest provides implementations for most of the abstract +// methods of the DeathTest interface. Only the AssumeRole method is +// left undefined. +class ForkingDeathTest : public DeathTestImpl { + public: + ForkingDeathTest(const char* statement, const RE* regex); + + // All of these virtual functions are inherited from DeathTest. + virtual int Wait(); + + protected: + void set_child_pid(pid_t child_pid) { child_pid_ = child_pid; } + + private: + // PID of child process during death test; 0 in the child process itself. + pid_t child_pid_; +}; + +// Constructs a ForkingDeathTest. +ForkingDeathTest::ForkingDeathTest(const char* a_statement, const RE* a_regex) + : DeathTestImpl(a_statement, a_regex), + child_pid_(-1) {} + +// Waits for the child in a death test to exit, returning its exit +// status, or 0 if no child process exists. As a side effect, sets the +// outcome data member. +int ForkingDeathTest::Wait() { + if (!spawned()) + return 0; + + ReadAndInterpretStatusByte(); + + int status_value; + GTEST_DEATH_TEST_CHECK_SYSCALL_(waitpid(child_pid_, &status_value, 0)); + set_status(status_value); + return status_value; +} + +// A concrete death test class that forks, then immediately runs the test +// in the child process. +class NoExecDeathTest : public ForkingDeathTest { + public: + NoExecDeathTest(const char* a_statement, const RE* a_regex) : + ForkingDeathTest(a_statement, a_regex) { } + virtual TestRole AssumeRole(); +}; + +// The AssumeRole process for a fork-and-run death test. It implements a +// straightforward fork, with a simple pipe to transmit the status byte. +DeathTest::TestRole NoExecDeathTest::AssumeRole() { + const size_t thread_count = GetThreadCount(); + if (thread_count != 1) { + GTEST_LOG_(WARNING) << DeathTestThreadWarning(thread_count); + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + + DeathTest::set_last_death_test_message(""); + CaptureStderr(); + // When we fork the process below, the log file buffers are copied, but the + // file descriptors are shared. We flush all log files here so that closing + // the file descriptors in the child process doesn't throw off the + // synchronization between descriptors and buffers in the parent process. + // This is as close to the fork as possible to avoid a race condition in case + // there are multiple threads running before the death test, and another + // thread writes to the log file. + FlushInfoLog(); + + const pid_t child_pid = fork(); + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + set_child_pid(child_pid); + if (child_pid == 0) { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[0])); + set_write_fd(pipe_fd[1]); + // Redirects all logging to stderr in the child process to prevent + // concurrent writes to the log files. We capture stderr in the parent + // process and append the child process' output to a log. + LogToStderr(); + // Event forwarding to the listeners of event listener API mush be shut + // down in death test subprocesses. 
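// For orientation, the fork-and-pipe protocol used by this fast style boils
// down to the following standalone sketch (plain POSIX; ChildDied is a
// hypothetical helper, error handling omitted):
//
//   #include <sys/wait.h>
//   #include <unistd.h>
//
//   bool ChildDied(void (*statement)()) {
//     int fds[2];
//     pipe(fds);
//     const pid_t pid = fork();
//     if (pid == 0) {            // Child: run the statement.
//       close(fds[0]);
//       statement();             // If this kills the process, nothing is
//       char lived = 'L';        // ever written to the pipe.
//       write(fds[1], &lived, 1);
//       _exit(1);
//     }
//     close(fds[1]);             // Parent.
//     char flag;
//     const ssize_t n = read(fds[0], &flag, 1);  // 0 bytes read == child died.
//     int status;
//     waitpid(pid, &status, 0);
//     return n == 0;
//   }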
+ GetUnitTestImpl()->listeners()->SuppressEventForwarding(); + return EXECUTE_TEST; + } else { + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; + } +} + +// A concrete death test class that forks and re-executes the main +// program from the beginning, with command-line flags set that cause +// only this specific death test to be run. +class ExecDeathTest : public ForkingDeathTest { + public: + ExecDeathTest(const char* a_statement, const RE* a_regex, + const char* file, int line) : + ForkingDeathTest(a_statement, a_regex), file_(file), line_(line) { } + virtual TestRole AssumeRole(); + private: + // The name of the file in which the death test is located. + const char* const file_; + // The line number on which the death test is located. + const int line_; +}; + +// Utility class for accumulating command-line arguments. +class Arguments { + public: + Arguments() { + args_.push_back(NULL); + } + + ~Arguments() { + for (std::vector::iterator i = args_.begin(); i != args_.end(); + ++i) { + free(*i); + } + } + void AddArgument(const char* argument) { + args_.insert(args_.end() - 1, posix::StrDup(argument)); + } + + template + void AddArguments(const ::std::vector& arguments) { + for (typename ::std::vector::const_iterator i = arguments.begin(); + i != arguments.end(); + ++i) { + args_.insert(args_.end() - 1, posix::StrDup(i->c_str())); + } + } + char* const* Argv() { + return &args_[0]; + } + private: + std::vector args_; +}; + +// A struct that encompasses the arguments to the child process of a +// threadsafe-style death test process. +struct ExecDeathTestArgs { + char* const* argv; // Command-line arguments for the child's call to exec + int close_fd; // File descriptor to close; the read end of a pipe +}; + +# if GTEST_OS_MAC +inline char** GetEnviron() { + // When Google Test is built as a framework on MacOS X, the environ variable + // is unavailable. Apple's documentation (man environ) recommends using + // _NSGetEnviron() instead. + return *_NSGetEnviron(); +} +# else +// Some POSIX platforms expect you to declare environ. extern "C" makes +// it reside in the global namespace. +extern "C" char** environ; +inline char** GetEnviron() { return environ; } +# endif // GTEST_OS_MAC + +// The main function for a threadsafe-style death test child process. +// This function is called in a clone()-ed process and thus must avoid +// any potentially unsafe operations like malloc or libc functions. +static int ExecDeathTestChildMain(void* child_arg) { + ExecDeathTestArgs* const args = static_cast(child_arg); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(args->close_fd)); + + // We need to execute the test program in the same environment where + // it was originally invoked. Therefore we change to the original + // working directory first. + const char* const original_dir = + UnitTest::GetInstance()->original_working_dir(); + // We can safely call chdir() as it's a direct system call. + if (chdir(original_dir) != 0) { + DeathTestAbort(String::Format("chdir(\"%s\") failed: %s", + original_dir, + GetLastErrnoDescription().c_str())); + return EXIT_FAILURE; + } + + // We can safely call execve() as it's a direct system call. We + // cannot use execvp() as it's a libc function and thus potentially + // unsafe. Since execve() doesn't search the PATH, the user must + // invoke the test program via a valid path that contains at least + // one path separator. 
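// Concretely: invoking the test binary as "./foo_test" (so that argv[0]
// contains a path separator) satisfies the requirement above, whereas a bare
// "foo_test" that relies on PATH lookup would not, because execve() below
// performs no PATH search.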
+ execve(args->argv[0], args->argv, GetEnviron()); + DeathTestAbort(String::Format("execve(%s, ...) in %s failed: %s", + args->argv[0], + original_dir, + GetLastErrnoDescription().c_str())); + return EXIT_FAILURE; +} + +// Two utility routines that together determine the direction the stack +// grows. +// This could be accomplished more elegantly by a single recursive +// function, but we want to guard against the unlikely possibility of +// a smart compiler optimizing the recursion away. +// +// GTEST_NO_INLINE_ is required to prevent GCC 4.6 from inlining +// StackLowerThanAddress into StackGrowsDown, which then doesn't give +// correct answer. +bool StackLowerThanAddress(const void* ptr) GTEST_NO_INLINE_; +bool StackLowerThanAddress(const void* ptr) { + int dummy; + return &dummy < ptr; +} + +bool StackGrowsDown() { + int dummy; + return StackLowerThanAddress(&dummy); +} + +// A threadsafe implementation of fork(2) for threadsafe-style death tests +// that uses clone(2). It dies with an error message if anything goes +// wrong. +static pid_t ExecDeathTestFork(char* const* argv, int close_fd) { + ExecDeathTestArgs args = { argv, close_fd }; + pid_t child_pid = -1; + +# if GTEST_HAS_CLONE + const bool use_fork = GTEST_FLAG(death_test_use_fork); + + if (!use_fork) { + static const bool stack_grows_down = StackGrowsDown(); + const size_t stack_size = getpagesize(); + // MMAP_ANONYMOUS is not defined on Mac, so we use MAP_ANON instead. + void* const stack = mmap(NULL, stack_size, PROT_READ | PROT_WRITE, + MAP_ANON | MAP_PRIVATE, -1, 0); + GTEST_DEATH_TEST_CHECK_(stack != MAP_FAILED); + void* const stack_top = + static_cast(stack) + (stack_grows_down ? stack_size : 0); + + child_pid = clone(&ExecDeathTestChildMain, stack_top, SIGCHLD, &args); + + GTEST_DEATH_TEST_CHECK_(munmap(stack, stack_size) != -1); + } +# else + const bool use_fork = true; +# endif // GTEST_HAS_CLONE + + if (use_fork && (child_pid = fork()) == 0) { + ExecDeathTestChildMain(&args); + _exit(0); + } + + GTEST_DEATH_TEST_CHECK_(child_pid != -1); + return child_pid; +} + +// The AssumeRole process for a fork-and-exec death test. It re-executes the +// main program from the beginning, setting the --gtest_filter +// and --gtest_internal_run_death_test flags to cause only the current +// death test to be re-run. 
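// For illustration, after AssumeRole() below runs in the parent, the child
// ends up re-invoked with a command line shaped roughly like this (all
// values made up):
//
//   /path/to/foo_test --gtest_filter=FooDeathTest.DiesOnNull \
//       --gtest_internal_run_death_test=foo_test.cc|123|1|5
//
// i.e. run only this one death test, and report status back over fd 5.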
+DeathTest::TestRole ExecDeathTest::AssumeRole() { + const UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const TestInfo* const info = impl->current_test_info(); + const int death_test_index = info->result()->death_test_count(); + + if (flag != NULL) { + set_write_fd(flag->write_fd()); + return EXECUTE_TEST; + } + + int pipe_fd[2]; + GTEST_DEATH_TEST_CHECK_(pipe(pipe_fd) != -1); + // Clear the close-on-exec flag on the write end of the pipe, lest + // it be closed when the child process does an exec: + GTEST_DEATH_TEST_CHECK_(fcntl(pipe_fd[1], F_SETFD, 0) != -1); + + const String filter_flag = + String::Format("--%s%s=%s.%s", + GTEST_FLAG_PREFIX_, kFilterFlag, + info->test_case_name(), info->name()); + const String internal_flag = + String::Format("--%s%s=%s|%d|%d|%d", + GTEST_FLAG_PREFIX_, kInternalRunDeathTestFlag, + file_, line_, death_test_index, pipe_fd[1]); + Arguments args; + args.AddArguments(GetArgvs()); + args.AddArgument(filter_flag.c_str()); + args.AddArgument(internal_flag.c_str()); + + DeathTest::set_last_death_test_message(""); + + CaptureStderr(); + // See the comment in NoExecDeathTest::AssumeRole for why the next line + // is necessary. + FlushInfoLog(); + + const pid_t child_pid = ExecDeathTestFork(args.Argv(), pipe_fd[0]); + GTEST_DEATH_TEST_CHECK_SYSCALL_(close(pipe_fd[1])); + set_child_pid(child_pid); + set_read_fd(pipe_fd[0]); + set_spawned(true); + return OVERSEE_TEST; +} + +# endif // !GTEST_OS_WINDOWS + +// Creates a concrete DeathTest-derived class that depends on the +// --gtest_death_test_style flag, and sets the pointer pointed to +// by the "test" argument to its address. If the test should be +// skipped, sets that pointer to NULL. Returns true, unless the +// flag is set to an invalid value. +bool DefaultDeathTestFactory::Create(const char* statement, const RE* regex, + const char* file, int line, + DeathTest** test) { + UnitTestImpl* const impl = GetUnitTestImpl(); + const InternalRunDeathTestFlag* const flag = + impl->internal_run_death_test_flag(); + const int death_test_index = impl->current_test_info() + ->increment_death_test_count(); + + if (flag != NULL) { + if (death_test_index > flag->index()) { + DeathTest::set_last_death_test_message(String::Format( + "Death test count (%d) somehow exceeded expected maximum (%d)", + death_test_index, flag->index())); + return false; + } + + if (!(flag->file() == file && flag->line() == line && + flag->index() == death_test_index)) { + *test = NULL; + return true; + } + } + +# if GTEST_OS_WINDOWS + + if (GTEST_FLAG(death_test_style) == "threadsafe" || + GTEST_FLAG(death_test_style) == "fast") { + *test = new WindowsDeathTest(statement, regex, file, line); + } + +# else + + if (GTEST_FLAG(death_test_style) == "threadsafe") { + *test = new ExecDeathTest(statement, regex, file, line); + } else if (GTEST_FLAG(death_test_style) == "fast") { + *test = new NoExecDeathTest(statement, regex); + } + +# endif // GTEST_OS_WINDOWS + + else { // NOLINT - this is more readable than unbalanced brackets inside #if. + DeathTest::set_last_death_test_message(String::Format( + "Unknown death test style \"%s\" encountered", + GTEST_FLAG(death_test_style).c_str())); + return false; + } + + return true; +} + +// Splits a given string on a given delimiter, populating a given +// vector with the fields. GTEST_HAS_DEATH_TEST implies that we have +// ::std::string, so we can use it here. 
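// A quick example of the helper defined below:
//
//   ::std::vector< ::std::string> fields;
//   SplitString("foo_test.cc|123|1|4", '|', &fields);
//   // fields == {"foo_test.cc", "123", "1", "4"}
//
// which is exactly the shape of the flag parsed later in
// ParseInternalRunDeathTestFlag().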
+static void SplitString(const ::std::string& str, char delimiter, + ::std::vector< ::std::string>* dest) { + ::std::vector< ::std::string> parsed; + ::std::string::size_type pos = 0; + while (::testing::internal::AlwaysTrue()) { + const ::std::string::size_type colon = str.find(delimiter, pos); + if (colon == ::std::string::npos) { + parsed.push_back(str.substr(pos)); + break; + } else { + parsed.push_back(str.substr(pos, colon - pos)); + pos = colon + 1; + } + } + dest->swap(parsed); +} + +# if GTEST_OS_WINDOWS +// Recreates the pipe and event handles from the provided parameters, +// signals the event, and returns a file descriptor wrapped around the pipe +// handle. This function is called in the child process only. +int GetStatusFileDescriptor(unsigned int parent_process_id, + size_t write_handle_as_size_t, + size_t event_handle_as_size_t) { + AutoHandle parent_process_handle(::OpenProcess(PROCESS_DUP_HANDLE, + FALSE, // Non-inheritable. + parent_process_id)); + if (parent_process_handle.Get() == INVALID_HANDLE_VALUE) { + DeathTestAbort(String::Format("Unable to open parent process %u", + parent_process_id)); + } + + // TODO(vladl@google.com): Replace the following check with a + // compile-time assertion when available. + GTEST_CHECK_(sizeof(HANDLE) <= sizeof(size_t)); + + const HANDLE write_handle = + reinterpret_cast(write_handle_as_size_t); + HANDLE dup_write_handle; + + // The newly initialized handle is accessible only in in the parent + // process. To obtain one accessible within the child, we need to use + // DuplicateHandle. + if (!::DuplicateHandle(parent_process_handle.Get(), write_handle, + ::GetCurrentProcess(), &dup_write_handle, + 0x0, // Requested privileges ignored since + // DUPLICATE_SAME_ACCESS is used. + FALSE, // Request non-inheritable handler. + DUPLICATE_SAME_ACCESS)) { + DeathTestAbort(String::Format( + "Unable to duplicate the pipe handle %Iu from the parent process %u", + write_handle_as_size_t, parent_process_id)); + } + + const HANDLE event_handle = reinterpret_cast(event_handle_as_size_t); + HANDLE dup_event_handle; + + if (!::DuplicateHandle(parent_process_handle.Get(), event_handle, + ::GetCurrentProcess(), &dup_event_handle, + 0x0, + FALSE, + DUPLICATE_SAME_ACCESS)) { + DeathTestAbort(String::Format( + "Unable to duplicate the event handle %Iu from the parent process %u", + event_handle_as_size_t, parent_process_id)); + } + + const int write_fd = + ::_open_osfhandle(reinterpret_cast(dup_write_handle), O_APPEND); + if (write_fd == -1) { + DeathTestAbort(String::Format( + "Unable to convert pipe handle %Iu to a file descriptor", + write_handle_as_size_t)); + } + + // Signals the parent that the write end of the pipe has been acquired + // so the parent can release its own write end. + ::SetEvent(dup_event_handle); + + return write_fd; +} +# endif // GTEST_OS_WINDOWS + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag() { + if (GTEST_FLAG(internal_run_death_test) == "") return NULL; + + // GTEST_HAS_DEATH_TEST implies that we have ::std::string, so we + // can use it here. 
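// The field layout parsed below differs by platform (illustrative only):
//   POSIX:    file|line|index|write_fd                              (4 fields)
//   Windows:  file|line|index|parent_pid|write_handle|event_handle  (6 fields)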
+ int line = -1; + int index = -1; + ::std::vector< ::std::string> fields; + SplitString(GTEST_FLAG(internal_run_death_test).c_str(), '|', &fields); + int write_fd = -1; + +# if GTEST_OS_WINDOWS + + unsigned int parent_process_id = 0; + size_t write_handle_as_size_t = 0; + size_t event_handle_as_size_t = 0; + + if (fields.size() != 6 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &parent_process_id) + || !ParseNaturalNumber(fields[4], &write_handle_as_size_t) + || !ParseNaturalNumber(fields[5], &event_handle_as_size_t)) { + DeathTestAbort(String::Format( + "Bad --gtest_internal_run_death_test flag: %s", + GTEST_FLAG(internal_run_death_test).c_str())); + } + write_fd = GetStatusFileDescriptor(parent_process_id, + write_handle_as_size_t, + event_handle_as_size_t); +# else + + if (fields.size() != 4 + || !ParseNaturalNumber(fields[1], &line) + || !ParseNaturalNumber(fields[2], &index) + || !ParseNaturalNumber(fields[3], &write_fd)) { + DeathTestAbort(String::Format( + "Bad --gtest_internal_run_death_test flag: %s", + GTEST_FLAG(internal_run_death_test).c_str())); + } + +# endif // GTEST_OS_WINDOWS + + return new InternalRunDeathTestFlag(fields[0], line, index, write_fd); +} + +} // namespace internal + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: keith.ray@gmail.com (Keith Ray) + + +#include + +#if GTEST_OS_WINDOWS_MOBILE +# include +#elif GTEST_OS_WINDOWS +# include +# include +#elif GTEST_OS_SYMBIAN || GTEST_OS_NACL +// Symbian OpenC and NaCl have PATH_MAX in sys/syslimits.h +# include +#else +# include +# include // Some Linux distributions define PATH_MAX here. 
+#endif // GTEST_OS_WINDOWS_MOBILE + +#if GTEST_OS_WINDOWS +# define GTEST_PATH_MAX_ _MAX_PATH +#elif defined(PATH_MAX) +# define GTEST_PATH_MAX_ PATH_MAX +#elif defined(_XOPEN_PATH_MAX) +# define GTEST_PATH_MAX_ _XOPEN_PATH_MAX +#else +# define GTEST_PATH_MAX_ _POSIX_PATH_MAX +#endif // GTEST_OS_WINDOWS + + +namespace testing { +namespace internal { + +#if GTEST_OS_WINDOWS +// On Windows, '\\' is the standard path separator, but many tools and the +// Windows API also accept '/' as an alternate path separator. Unless otherwise +// noted, a file path can contain either kind of path separators, or a mixture +// of them. +const char kPathSeparator = '\\'; +const char kAlternatePathSeparator = '/'; +const char kPathSeparatorString[] = "\\"; +const char kAlternatePathSeparatorString[] = "/"; +# if GTEST_OS_WINDOWS_MOBILE +// Windows CE doesn't have a current directory. You should not use +// the current directory in tests on Windows CE, but this at least +// provides a reasonable fallback. +const char kCurrentDirectoryString[] = "\\"; +// Windows CE doesn't define INVALID_FILE_ATTRIBUTES +const DWORD kInvalidFileAttributes = 0xffffffff; +# else +const char kCurrentDirectoryString[] = ".\\"; +# endif // GTEST_OS_WINDOWS_MOBILE +#else +const char kPathSeparator = '/'; +const char kPathSeparatorString[] = "/"; +const char kCurrentDirectoryString[] = "./"; +#endif // GTEST_OS_WINDOWS + +// Returns whether the given character is a valid path separator. +static bool IsPathSeparator(char c) { +#if GTEST_HAS_ALT_PATH_SEP_ + return (c == kPathSeparator) || (c == kAlternatePathSeparator); +#else + return c == kPathSeparator; +#endif +} + +// Returns the current working directory, or "" if unsuccessful. +FilePath FilePath::GetCurrentDir() { +#if GTEST_OS_WINDOWS_MOBILE + // Windows CE doesn't have a current directory, so we just return + // something reasonable. + return FilePath(kCurrentDirectoryString); +#elif GTEST_OS_WINDOWS + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + return FilePath(_getcwd(cwd, sizeof(cwd)) == NULL ? "" : cwd); +#else + char cwd[GTEST_PATH_MAX_ + 1] = { '\0' }; + return FilePath(getcwd(cwd, sizeof(cwd)) == NULL ? "" : cwd); +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns a copy of the FilePath with the case-insensitive extension removed. +// Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns +// FilePath("dir/file"). If a case-insensitive extension is not +// found, returns a copy of the original FilePath. +FilePath FilePath::RemoveExtension(const char* extension) const { + String dot_extension(String::Format(".%s", extension)); + if (pathname_.EndsWithCaseInsensitive(dot_extension.c_str())) { + return FilePath(String(pathname_.c_str(), pathname_.length() - 4)); + } + return *this; +} + +// Returns a pointer to the last occurence of a valid path separator in +// the FilePath. On Windows, for example, both '/' and '\' are valid path +// separators. Returns NULL if no path separator was found. +const char* FilePath::FindLastPathSeparator() const { + const char* const last_sep = strrchr(c_str(), kPathSeparator); +#if GTEST_HAS_ALT_PATH_SEP_ + const char* const last_alt_sep = strrchr(c_str(), kAlternatePathSeparator); + // Comparing two pointers of which only one is NULL is undefined. + if (last_alt_sep != NULL && + (last_sep == NULL || last_alt_sep > last_sep)) { + return last_alt_sep; + } +#endif + return last_sep; +} + +// Returns a copy of the FilePath with the directory part removed. 
+// Example: FilePath("path/to/file").RemoveDirectoryName() returns +// FilePath("file"). If there is no directory part ("just_a_file"), it returns +// the FilePath unmodified. If there is no file part ("just_a_dir/") it +// returns an empty FilePath (""). +// On Windows platform, '\' is the path separator, otherwise it is '/'. +FilePath FilePath::RemoveDirectoryName() const { + const char* const last_sep = FindLastPathSeparator(); + return last_sep ? FilePath(String(last_sep + 1)) : *this; +} + +// RemoveFileName returns the directory path with the filename removed. +// Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". +// If the FilePath is "a_file" or "/a_file", RemoveFileName returns +// FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does +// not have a file, like "just/a/dir/", it returns the FilePath unmodified. +// On Windows platform, '\' is the path separator, otherwise it is '/'. +FilePath FilePath::RemoveFileName() const { + const char* const last_sep = FindLastPathSeparator(); + String dir; + if (last_sep) { + dir = String(c_str(), last_sep + 1 - c_str()); + } else { + dir = kCurrentDirectoryString; + } + return FilePath(dir); +} + +// Helper functions for naming files in a directory for xml output. + +// Given directory = "dir", base_name = "test", number = 0, +// extension = "xml", returns "dir/test.xml". If number is greater +// than zero (e.g., 12), returns "dir/test_12.xml". +// On Windows platform, uses \ as the separator rather than /. +FilePath FilePath::MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension) { + String file; + if (number == 0) { + file = String::Format("%s.%s", base_name.c_str(), extension); + } else { + file = String::Format("%s_%d.%s", base_name.c_str(), number, extension); + } + return ConcatPaths(directory, FilePath(file)); +} + +// Given directory = "dir", relative_path = "test.xml", returns "dir/test.xml". +// On Windows, uses \ as the separator rather than /. +FilePath FilePath::ConcatPaths(const FilePath& directory, + const FilePath& relative_path) { + if (directory.IsEmpty()) + return relative_path; + const FilePath dir(directory.RemoveTrailingPathSeparator()); + return FilePath(String::Format("%s%c%s", dir.c_str(), kPathSeparator, + relative_path.c_str())); +} + +// Returns true if pathname describes something findable in the file-system, +// either a file, directory, or whatever. +bool FilePath::FileOrDirectoryExists() const { +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(pathname_.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + return attributes != kInvalidFileAttributes; +#else + posix::StatStruct file_stat; + return posix::Stat(pathname_.c_str(), &file_stat) == 0; +#endif // GTEST_OS_WINDOWS_MOBILE +} + +// Returns true if pathname describes a directory in the file-system +// that exists. +bool FilePath::DirectoryExists() const { + bool result = false; +#if GTEST_OS_WINDOWS + // Don't strip off trailing separator if path is a root directory on + // Windows (like "C:\\"). + const FilePath& path(IsRootDirectory() ? 
*this : + RemoveTrailingPathSeparator()); +#else + const FilePath& path(*this); +#endif + +#if GTEST_OS_WINDOWS_MOBILE + LPCWSTR unicode = String::AnsiToUtf16(path.c_str()); + const DWORD attributes = GetFileAttributes(unicode); + delete [] unicode; + if ((attributes != kInvalidFileAttributes) && + (attributes & FILE_ATTRIBUTE_DIRECTORY)) { + result = true; + } +#else + posix::StatStruct file_stat; + result = posix::Stat(path.c_str(), &file_stat) == 0 && + posix::IsDir(file_stat); +#endif // GTEST_OS_WINDOWS_MOBILE + + return result; +} + +// Returns true if pathname describes a root directory. (Windows has one +// root directory per disk drive.) +bool FilePath::IsRootDirectory() const { +#if GTEST_OS_WINDOWS + // TODO(wan@google.com): on Windows a network share like + // \\server\share can be a root directory, although it cannot be the + // current directory. Handle this properly. + return pathname_.length() == 3 && IsAbsolutePath(); +#else + return pathname_.length() == 1 && IsPathSeparator(pathname_.c_str()[0]); +#endif +} + +// Returns true if pathname describes an absolute path. +bool FilePath::IsAbsolutePath() const { + const char* const name = pathname_.c_str(); +#if GTEST_OS_WINDOWS + return pathname_.length() >= 3 && + ((name[0] >= 'a' && name[0] <= 'z') || + (name[0] >= 'A' && name[0] <= 'Z')) && + name[1] == ':' && + IsPathSeparator(name[2]); +#else + return IsPathSeparator(name[0]); +#endif +} + +// Returns a pathname for a file that does not currently exist. The pathname +// will be directory/base_name.extension or +// directory/base_name_.extension if directory/base_name.extension +// already exists. The number will be incremented until a pathname is found +// that does not already exist. +// Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. +// There could be a race condition if two or more processes are calling this +// function at the same time -- they could both pick the same filename. +FilePath FilePath::GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension) { + FilePath full_pathname; + int number = 0; + do { + full_pathname.Set(MakeFileName(directory, base_name, number++, extension)); + } while (full_pathname.FileOrDirectoryExists()); + return full_pathname; +} + +// Returns true if FilePath ends with a path separator, which indicates that +// it is intended to represent a directory. Returns false otherwise. +// This does NOT check that a directory (or file) actually exists. +bool FilePath::IsDirectory() const { + return !pathname_.empty() && + IsPathSeparator(pathname_.c_str()[pathname_.length() - 1]); +} + +// Create directories so that path exists. Returns true if successful or if +// the directories already exist; returns false if unable to create directories +// for any reason. +bool FilePath::CreateDirectoriesRecursively() const { + if (!this->IsDirectory()) { + return false; + } + + if (pathname_.length() == 0 || this->DirectoryExists()) { + return true; + } + + const FilePath parent(this->RemoveTrailingPathSeparator().RemoveFileName()); + return parent.CreateDirectoriesRecursively() && this->CreateFolder(); +} + +// Create the directory so that path exists. Returns true if successful or +// if the directory already exists; returns false if unable to create the +// directory for any reason, including if the parent directory does not +// exist. Not named "CreateDirectory" because that's a macro on Windows. 
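// Illustrative sketch of how these FilePath helpers compose (file names are
// made up; only APIs defined in this file are used):
//
//   using ::testing::internal::FilePath;
//   FilePath dir("test_results/");
//   dir.CreateDirectoriesRecursively();       // needs the trailing separator
//   FilePath report = FilePath::GenerateUniqueFileName(
//       dir, FilePath("foo_test"), "xml");    // test_results/foo_test.xml, or
//                                             // foo_test_1.xml, ... if taken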
+bool FilePath::CreateFolder() const { +#if GTEST_OS_WINDOWS_MOBILE + FilePath removed_sep(this->RemoveTrailingPathSeparator()); + LPCWSTR unicode = String::AnsiToUtf16(removed_sep.c_str()); + int result = CreateDirectory(unicode, NULL) ? 0 : -1; + delete [] unicode; +#elif GTEST_OS_WINDOWS + int result = _mkdir(pathname_.c_str()); +#else + int result = mkdir(pathname_.c_str(), 0777); +#endif // GTEST_OS_WINDOWS_MOBILE + + if (result == -1) { + return this->DirectoryExists(); // An error is OK if the directory exists. + } + return true; // No error. +} + +// If input name has a trailing separator character, remove it and return the +// name, otherwise return the name string unmodified. +// On Windows platform, uses \ as the separator, other platforms use /. +FilePath FilePath::RemoveTrailingPathSeparator() const { + return IsDirectory() + ? FilePath(String(pathname_.c_str(), pathname_.length() - 1)) + : *this; +} + +// Removes any redundant separators that might be in the pathname. +// For example, "bar///foo" becomes "bar/foo". Does not eliminate other +// redundancies that might be in a pathname involving "." or "..". +// TODO(wan@google.com): handle Windows network shares (e.g. \\server\share). +void FilePath::Normalize() { + if (pathname_.c_str() == NULL) { + pathname_ = ""; + return; + } + const char* src = pathname_.c_str(); + char* const dest = new char[pathname_.length() + 1]; + char* dest_ptr = dest; + memset(dest_ptr, 0, pathname_.length() + 1); + + while (*src != '\0') { + *dest_ptr = *src; + if (!IsPathSeparator(*src)) { + src++; + } else { +#if GTEST_HAS_ALT_PATH_SEP_ + if (*dest_ptr == kAlternatePathSeparator) { + *dest_ptr = kPathSeparator; + } +#endif + while (IsPathSeparator(*src)) + src++; + } + dest_ptr++; + } + *dest_ptr = '\0'; + pathname_ = dest; + delete[] dest; +} + +} // namespace internal +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+// +// Author: wan@google.com (Zhanyong Wan) + + +#include +#include +#include +#include + +#if GTEST_OS_WINDOWS_MOBILE +# include // For TerminateProcess() +#elif GTEST_OS_WINDOWS +# include +# include +#else +# include +#endif // GTEST_OS_WINDOWS_MOBILE + +#if GTEST_OS_MAC +# include +# include +# include +#endif // GTEST_OS_MAC + + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +#undef GTEST_IMPLEMENTATION_ + +namespace testing { +namespace internal { + +#if defined(_MSC_VER) || defined(__BORLANDC__) +// MSVC and C++Builder do not provide a definition of STDERR_FILENO. +const int kStdOutFileno = 1; +const int kStdErrFileno = 2; +#else +const int kStdOutFileno = STDOUT_FILENO; +const int kStdErrFileno = STDERR_FILENO; +#endif // _MSC_VER + +#if GTEST_OS_MAC + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +size_t GetThreadCount() { + const task_t task = mach_task_self(); + mach_msg_type_number_t thread_count; + thread_act_array_t thread_list; + const kern_return_t status = task_threads(task, &thread_list, &thread_count); + if (status == KERN_SUCCESS) { + // task_threads allocates resources in thread_list and we need to free them + // to avoid leaks. + vm_deallocate(task, + reinterpret_cast(thread_list), + sizeof(thread_t) * thread_count); + return static_cast(thread_count); + } else { + return 0; + } +} + +#else + +size_t GetThreadCount() { + // There's no portable way to detect the number of threads, so we just + // return 0 to indicate that we cannot detect it. + return 0; +} + +#endif // GTEST_OS_MAC + +#if GTEST_USES_POSIX_RE + +// Implements RE. Currently only needed for death tests. + +RE::~RE() { + if (is_valid_) { + // regfree'ing an invalid regex might crash because the content + // of the regex is undefined. Since the regex's are essentially + // the same, one cannot be valid (or invalid) without the other + // being so too. + regfree(&partial_regex_); + regfree(&full_regex_); + } + free(const_cast(pattern_)); +} + +// Returns true iff regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.full_regex_, str, 1, &match, 0) == 0; +} + +// Returns true iff regular expression re matches a substring of str +// (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + if (!re.is_valid_) return false; + + regmatch_t match; + return regexec(&re.partial_regex_, str, 1, &match, 0) == 0; +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = posix::StrDup(regex); + + // Reserves enough bytes to hold the regular expression used for a + // full match. + const size_t full_regex_len = strlen(regex) + 10; + char* const full_pattern = new char[full_regex_len]; + + snprintf(full_pattern, full_regex_len, "^(%s)$", regex); + is_valid_ = regcomp(&full_regex_, full_pattern, REG_EXTENDED) == 0; + // We want to call regcomp(&partial_regex_, ...) even if the + // previous expression returns false. Otherwise partial_regex_ may + // not be properly initialized can may cause trouble when it's + // freed. + // + // Some implementation of POSIX regex (e.g. 
on at least some + // versions of Cygwin) doesn't accept the empty string as a valid + // regex. We change it to an equivalent form "()" to be safe. + if (is_valid_) { + const char* const partial_regex = (*regex == '\0') ? "()" : regex; + is_valid_ = regcomp(&partial_regex_, partial_regex, REG_EXTENDED) == 0; + } + EXPECT_TRUE(is_valid_) + << "Regular expression \"" << regex + << "\" is not a valid POSIX Extended regular expression."; + + delete[] full_pattern; +} + +#elif GTEST_USES_SIMPLE_RE + +// Returns true iff ch appears anywhere in str (excluding the +// terminating '\0' character). +bool IsInSet(char ch, const char* str) { + return ch != '\0' && strchr(str, ch) != NULL; +} + +// Returns true iff ch belongs to the given classification. Unlike +// similar functions in , these aren't affected by the +// current locale. +bool IsAsciiDigit(char ch) { return '0' <= ch && ch <= '9'; } +bool IsAsciiPunct(char ch) { + return IsInSet(ch, "^-!\"#$%&'()*+,./:;<=>?@[\\]_`{|}~"); +} +bool IsRepeat(char ch) { return IsInSet(ch, "?*+"); } +bool IsAsciiWhiteSpace(char ch) { return IsInSet(ch, " \f\n\r\t\v"); } +bool IsAsciiWordChar(char ch) { + return ('a' <= ch && ch <= 'z') || ('A' <= ch && ch <= 'Z') || + ('0' <= ch && ch <= '9') || ch == '_'; +} + +// Returns true iff "\\c" is a supported escape sequence. +bool IsValidEscape(char c) { + return (IsAsciiPunct(c) || IsInSet(c, "dDfnrsStvwW")); +} + +// Returns true iff the given atom (specified by escaped and pattern) +// matches ch. The result is undefined if the atom is invalid. +bool AtomMatchesChar(bool escaped, char pattern_char, char ch) { + if (escaped) { // "\\p" where p is pattern_char. + switch (pattern_char) { + case 'd': return IsAsciiDigit(ch); + case 'D': return !IsAsciiDigit(ch); + case 'f': return ch == '\f'; + case 'n': return ch == '\n'; + case 'r': return ch == '\r'; + case 's': return IsAsciiWhiteSpace(ch); + case 'S': return !IsAsciiWhiteSpace(ch); + case 't': return ch == '\t'; + case 'v': return ch == '\v'; + case 'w': return IsAsciiWordChar(ch); + case 'W': return !IsAsciiWordChar(ch); + } + return IsAsciiPunct(pattern_char) && pattern_char == ch; + } + + return (pattern_char == '.' && ch != '\n') || pattern_char == ch; +} + +// Helper function used by ValidateRegex() to format error messages. +String FormatRegexSyntaxError(const char* regex, int index) { + return (Message() << "Syntax error at index " << index + << " in simple regular expression \"" << regex << "\": ").GetString(); +} + +// Generates non-fatal failures and returns false if regex is invalid; +// otherwise returns true. +bool ValidateRegex(const char* regex) { + if (regex == NULL) { + // TODO(wan@google.com): fix the source file location in the + // assertion failures to match where the regex is used in user + // code. + ADD_FAILURE() << "NULL is not a valid simple regular expression."; + return false; + } + + bool is_valid = true; + + // True iff ?, *, or + can follow the previous atom. + bool prev_repeatable = false; + for (int i = 0; regex[i]; i++) { + if (regex[i] == '\\') { // An escape sequence + i++; + if (regex[i] == '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "'\\' cannot appear at the end."; + return false; + } + + if (!IsValidEscape(regex[i])) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i - 1) + << "invalid escape sequence \"\\" << regex[i] << "\"."; + is_valid = false; + } + prev_repeatable = true; + } else { // Not an escape sequence. 
+ const char ch = regex[i]; + + if (ch == '^' && i > 0) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'^' can only appear at the beginning."; + is_valid = false; + } else if (ch == '$' && regex[i + 1] != '\0') { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'$' can only appear at the end."; + is_valid = false; + } else if (IsInSet(ch, "()[]{}|")) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' is unsupported."; + is_valid = false; + } else if (IsRepeat(ch) && !prev_repeatable) { + ADD_FAILURE() << FormatRegexSyntaxError(regex, i) + << "'" << ch << "' can only follow a repeatable token."; + is_valid = false; + } + + prev_repeatable = !IsInSet(ch, "^$?*+"); + } + } + + return is_valid; +} + +// Matches a repeated regex atom followed by a valid simple regular +// expression. The regex atom is defined as c if escaped is false, +// or \c otherwise. repeat is the repetition meta character (?, *, +// or +). The behavior is undefined if str contains too many +// characters to be indexable by size_t, in which case the test will +// probably time out anyway. We are fine with this limitation as +// std::string has it too. +bool MatchRepetitionAndRegexAtHead( + bool escaped, char c, char repeat, const char* regex, + const char* str) { + const size_t min_count = (repeat == '+') ? 1 : 0; + const size_t max_count = (repeat == '?') ? 1 : + static_cast(-1) - 1; + // We cannot call numeric_limits::max() as it conflicts with the + // max() macro on Windows. + + for (size_t i = 0; i <= max_count; ++i) { + // We know that the atom matches each of the first i characters in str. + if (i >= min_count && MatchRegexAtHead(regex, str + i)) { + // We have enough matches at the head, and the tail matches too. + // Since we only care about *whether* the pattern matches str + // (as opposed to *how* it matches), there is no need to find a + // greedy match. + return true; + } + if (str[i] == '\0' || !AtomMatchesChar(escaped, c, str[i])) + return false; + } + return false; +} + +// Returns true iff regex matches a prefix of str. regex must be a +// valid simple regular expression and not start with "^", or the +// result is undefined. +bool MatchRegexAtHead(const char* regex, const char* str) { + if (*regex == '\0') // An empty regex matches a prefix of anything. + return true; + + // "$" only matches the end of a string. Note that regex being + // valid guarantees that there's nothing after "$" in it. + if (*regex == '$') + return *str == '\0'; + + // Is the first thing in regex an escape sequence? + const bool escaped = *regex == '\\'; + if (escaped) + ++regex; + if (IsRepeat(regex[1])) { + // MatchRepetitionAndRegexAtHead() calls MatchRegexAtHead(), so + // here's an indirect recursion. It terminates as the regex gets + // shorter in each recursion. + return MatchRepetitionAndRegexAtHead( + escaped, regex[0], regex[1], regex + 2, str); + } else { + // regex isn't empty, isn't "$", and doesn't start with a + // repetition. We match the first atom of regex with the first + // character of str and recurse. + return (*str != '\0') && AtomMatchesChar(escaped, *regex, *str) && + MatchRegexAtHead(regex + 1, str + 1); + } +} + +// Returns true iff regex matches any substring of str. regex must be +// a valid simple regular expression, or the result is undefined. +// +// The algorithm is recursive, but the recursion depth doesn't exceed +// the regex length, so we won't need to worry about running out of +// stack space normally. 
In rare cases the time complexity can be +// exponential with respect to the regex length + the string length, +// but usually it's must faster (often close to linear). +bool MatchRegexAnywhere(const char* regex, const char* str) { + if (regex == NULL || str == NULL) + return false; + + if (*regex == '^') + return MatchRegexAtHead(regex + 1, str); + + // A successful match can be anywhere in str. + do { + if (MatchRegexAtHead(regex, str)) + return true; + } while (*str++ != '\0'); + return false; +} + +// Implements the RE class. + +RE::~RE() { + free(const_cast(pattern_)); + free(const_cast(full_pattern_)); +} + +// Returns true iff regular expression re matches the entire str. +bool RE::FullMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.full_pattern_, str); +} + +// Returns true iff regular expression re matches a substring of str +// (including str itself). +bool RE::PartialMatch(const char* str, const RE& re) { + return re.is_valid_ && MatchRegexAnywhere(re.pattern_, str); +} + +// Initializes an RE from its string representation. +void RE::Init(const char* regex) { + pattern_ = full_pattern_ = NULL; + if (regex != NULL) { + pattern_ = posix::StrDup(regex); + } + + is_valid_ = ValidateRegex(regex); + if (!is_valid_) { + // No need to calculate the full pattern when the regex is invalid. + return; + } + + const size_t len = strlen(regex); + // Reserves enough bytes to hold the regular expression used for a + // full match: we need space to prepend a '^', append a '$', and + // terminate the string with '\0'. + char* buffer = static_cast(malloc(len + 3)); + full_pattern_ = buffer; + + if (*regex != '^') + *buffer++ = '^'; // Makes sure full_pattern_ starts with '^'. + + // We don't use snprintf or strncpy, as they trigger a warning when + // compiled with VC++ 8.0. + memcpy(buffer, regex, len); + buffer += len; + + if (len == 0 || regex[len - 1] != '$') + *buffer++ = '$'; // Makes sure full_pattern_ ends with '$'. + + *buffer = '\0'; +} + +#endif // GTEST_USES_POSIX_RE + +const char kUnknownFile[] = "unknown file"; + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line) { + const char* const file_name = file == NULL ? kUnknownFile : file; + + if (line < 0) { + return String::Format("%s:", file_name).c_str(); + } +#ifdef _MSC_VER + return String::Format("%s(%d):", file_name, line).c_str(); +#else + return String::Format("%s:%d:", file_name, line).c_str(); +#endif // _MSC_VER +} + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +// Note that FormatCompilerIndependentFileLocation() does NOT append colon +// to the file location it produces, unlike FormatFileLocation(). +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation( + const char* file, int line) { + const char* const file_name = file == NULL ? kUnknownFile : file; + + if (line < 0) + return file_name; + else + return String::Format("%s:%d", file_name, line).c_str(); +} + + +GTestLog::GTestLog(GTestLogSeverity severity, const char* file, int line) + : severity_(severity) { + const char* const marker = + severity == GTEST_INFO ? "[ INFO ]" : + severity == GTEST_WARNING ? "[WARNING]" : + severity == GTEST_ERROR ? 
"[ ERROR ]" : "[ FATAL ]"; + GetStream() << ::std::endl << marker << " " + << FormatFileLocation(file, line).c_str() << ": "; +} + +// Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. +GTestLog::~GTestLog() { + GetStream() << ::std::endl; + if (severity_ == GTEST_FATAL) { + fflush(stderr); + posix::Abort(); + } +} +// Disable Microsoft deprecation warnings for POSIX functions called from +// this class (creat, dup, dup2, and close) +#ifdef _MSC_VER +# pragma warning(push) +# pragma warning(disable: 4996) +#endif // _MSC_VER + +#if GTEST_HAS_STREAM_REDIRECTION + +// Object that captures an output stream (stdout/stderr). +class CapturedStream { + public: + // The ctor redirects the stream to a temporary file. + CapturedStream(int fd) : fd_(fd), uncaptured_fd_(dup(fd)) { + +# if GTEST_OS_WINDOWS + char temp_dir_path[MAX_PATH + 1] = { '\0' }; // NOLINT + char temp_file_path[MAX_PATH + 1] = { '\0' }; // NOLINT + + ::GetTempPathA(sizeof(temp_dir_path), temp_dir_path); + const UINT success = ::GetTempFileNameA(temp_dir_path, + "gtest_redir", + 0, // Generate unique file name. + temp_file_path); + GTEST_CHECK_(success != 0) + << "Unable to create a temporary file in " << temp_dir_path; + const int captured_fd = creat(temp_file_path, _S_IREAD | _S_IWRITE); + GTEST_CHECK_(captured_fd != -1) << "Unable to open temporary file " + << temp_file_path; + filename_ = temp_file_path; +# else + // There's no guarantee that a test has write access to the + // current directory, so we create the temporary file in the /tmp + // directory instead. + char name_template[] = "/tmp/captured_stream.XXXXXX"; + const int captured_fd = mkstemp(name_template); + filename_ = name_template; +# endif // GTEST_OS_WINDOWS + fflush(NULL); + dup2(captured_fd, fd_); + close(captured_fd); + } + + ~CapturedStream() { + remove(filename_.c_str()); + } + + String GetCapturedString() { + if (uncaptured_fd_ != -1) { + // Restores the original stream. + fflush(NULL); + dup2(uncaptured_fd_, fd_); + close(uncaptured_fd_); + uncaptured_fd_ = -1; + } + + FILE* const file = posix::FOpen(filename_.c_str(), "r"); + const String content = ReadEntireFile(file); + posix::FClose(file); + return content; + } + + private: + // Reads the entire content of a file as a String. + static String ReadEntireFile(FILE* file); + + // Returns the size (in bytes) of a file. + static size_t GetFileSize(FILE* file); + + const int fd_; // A stream to capture. + int uncaptured_fd_; + // Name of the temporary file holding the stderr output. + ::std::string filename_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(CapturedStream); +}; + +// Returns the size (in bytes) of a file. +size_t CapturedStream::GetFileSize(FILE* file) { + fseek(file, 0, SEEK_END); + return static_cast(ftell(file)); +} + +// Reads the entire content of a file as a string. +String CapturedStream::ReadEntireFile(FILE* file) { + const size_t file_size = GetFileSize(file); + char* const buffer = new char[file_size]; + + size_t bytes_last_read = 0; // # of bytes read in the last fread() + size_t bytes_read = 0; // # of bytes read so far + + fseek(file, 0, SEEK_SET); + + // Keeps reading the file until we cannot read further or the + // pre-determined file size is reached. 
+ do { + bytes_last_read = fread(buffer+bytes_read, 1, file_size-bytes_read, file); + bytes_read += bytes_last_read; + } while (bytes_last_read > 0 && bytes_read < file_size); + + const String content(buffer, bytes_read); + delete[] buffer; + + return content; +} + +# ifdef _MSC_VER +# pragma warning(pop) +# endif // _MSC_VER + +static CapturedStream* g_captured_stderr = NULL; +static CapturedStream* g_captured_stdout = NULL; + +// Starts capturing an output stream (stdout/stderr). +void CaptureStream(int fd, const char* stream_name, CapturedStream** stream) { + if (*stream != NULL) { + GTEST_LOG_(FATAL) << "Only one " << stream_name + << " capturer can exist at a time."; + } + *stream = new CapturedStream(fd); +} + +// Stops capturing the output stream and returns the captured string. +String GetCapturedStream(CapturedStream** captured_stream) { + const String content = (*captured_stream)->GetCapturedString(); + + delete *captured_stream; + *captured_stream = NULL; + + return content; +} + +// Starts capturing stdout. +void CaptureStdout() { + CaptureStream(kStdOutFileno, "stdout", &g_captured_stdout); +} + +// Starts capturing stderr. +void CaptureStderr() { + CaptureStream(kStdErrFileno, "stderr", &g_captured_stderr); +} + +// Stops capturing stdout and returns the captured string. +String GetCapturedStdout() { return GetCapturedStream(&g_captured_stdout); } + +// Stops capturing stderr and returns the captured string. +String GetCapturedStderr() { return GetCapturedStream(&g_captured_stderr); } + +#endif // GTEST_HAS_STREAM_REDIRECTION + +#if GTEST_HAS_DEATH_TEST + +// A copy of all command line arguments. Set by InitGoogleTest(). +::std::vector g_argvs; + +// Returns the command line as a vector of strings. +const ::std::vector& GetArgvs() { return g_argvs; } + +#endif // GTEST_HAS_DEATH_TEST + +#if GTEST_OS_WINDOWS_MOBILE +namespace posix { +void Abort() { + DebugBreak(); + TerminateProcess(GetCurrentProcess(), 1); +} +} // namespace posix +#endif // GTEST_OS_WINDOWS_MOBILE + +// Returns the name of the environment variable corresponding to the +// given flag. For example, FlagToEnvVar("foo") will return +// "GTEST_FOO" in the open-source version. +static String FlagToEnvVar(const char* flag) { + const String full_flag = + (Message() << GTEST_FLAG_PREFIX_ << flag).GetString(); + + Message env_var; + for (size_t i = 0; i != full_flag.length(); i++) { + env_var << ToUpper(full_flag.c_str()[i]); + } + + return env_var.GetString(); +} + +// Parses 'str' for a 32-bit signed integer. If successful, writes +// the result to *value and returns true; otherwise leaves *value +// unchanged and returns false. +bool ParseInt32(const Message& src_text, const char* str, Int32* value) { + // Parses the environment variable as a decimal integer. + char* end = NULL; + const long long_value = strtol(str, &end, 10); // NOLINT + + // Has strtol() consumed all characters in the string? + if (*end != '\0') { + // No - an invalid character was encountered. + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value \"" << str << "\".\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + // Is the parsed value in the range of an Int32? + const Int32 result = static_cast(long_value); + if (long_value == LONG_MAX || long_value == LONG_MIN || + // The parsed value overflows as a long. (strtol() returns + // LONG_MAX or LONG_MIN when the input overflows.) 
+ result != long_value + // The parsed value overflows as an Int32. + ) { + Message msg; + msg << "WARNING: " << src_text + << " is expected to be a 32-bit integer, but actually" + << " has value " << str << ", which overflows.\n"; + printf("%s", msg.GetString().c_str()); + fflush(stdout); + return false; + } + + *value = result; + return true; +} + +// Reads and returns the Boolean environment variable corresponding to +// the given flag; if it's not set, returns default_value. +// +// The value is considered true iff it's not "0". +bool BoolFromGTestEnv(const char* flag, bool default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + return string_value == NULL ? + default_value : strcmp(string_value, "0") != 0; +} + +// Reads and returns a 32-bit integer stored in the environment +// variable corresponding to the given flag; if it isn't set or +// doesn't represent a valid 32-bit integer, returns default_value. +Int32 Int32FromGTestEnv(const char* flag, Int32 default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const string_value = posix::GetEnv(env_var.c_str()); + if (string_value == NULL) { + // The environment variable is not set. + return default_value; + } + + Int32 result = default_value; + if (!ParseInt32(Message() << "Environment variable " << env_var, + string_value, &result)) { + printf("The default value %s is used.\n", + (Message() << default_value).GetString().c_str()); + fflush(stdout); + return default_value; + } + + return result; +} + +// Reads and returns the string environment variable corresponding to +// the given flag; if it's not set, returns default_value. +const char* StringFromGTestEnv(const char* flag, const char* default_value) { + const String env_var = FlagToEnvVar(flag); + const char* const value = posix::GetEnv(env_var.c_str()); + return value == NULL ? default_value : value; +} + +} // namespace internal +} // namespace testing +// Copyright 2007, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Google Test - The Google C++ Testing Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); +// +// It uses the << operator when possible, and prints the bytes in the +// object otherwise. A user can override its behavior for a class +// type Foo by defining either operator<<(::std::ostream&, const Foo&) +// or void PrintTo(const Foo&, ::std::ostream*) in the namespace that +// defines Foo. + +#include +#include +#include // NOLINT +#include + +namespace testing { + +namespace { + +using ::std::ostream; + +#if GTEST_OS_WINDOWS_MOBILE // Windows CE does not define _snprintf_s. +# define snprintf _snprintf +#elif _MSC_VER >= 1400 // VC 8.0 and later deprecate snprintf and _snprintf. +# define snprintf _snprintf_s +#elif _MSC_VER +# define snprintf _snprintf +#endif // GTEST_OS_WINDOWS_MOBILE + +// Prints a segment of bytes in the given object. +void PrintByteSegmentInObjectTo(const unsigned char* obj_bytes, size_t start, + size_t count, ostream* os) { + char text[5] = ""; + for (size_t i = 0; i != count; i++) { + const size_t j = start + i; + if (i != 0) { + // Organizes the bytes into groups of 2 for easy parsing by + // human. + if ((j % 2) == 0) + *os << ' '; + else + *os << '-'; + } + snprintf(text, sizeof(text), "%02X", obj_bytes[j]); + *os << text; + } +} + +// Prints the bytes in the given value to the given ostream. +void PrintBytesInObjectToImpl(const unsigned char* obj_bytes, size_t count, + ostream* os) { + // Tells the user how big the object is. + *os << count << "-byte object <"; + + const size_t kThreshold = 132; + const size_t kChunkSize = 64; + // If the object size is bigger than kThreshold, we'll have to omit + // some details by printing only the first and the last kChunkSize + // bytes. + // TODO(wan): let the user control the threshold using a flag. + if (count < kThreshold) { + PrintByteSegmentInObjectTo(obj_bytes, 0, count, os); + } else { + PrintByteSegmentInObjectTo(obj_bytes, 0, kChunkSize, os); + *os << " ... "; + // Rounds up to 2-byte boundary. + const size_t resume_pos = (count - kChunkSize + 1)/2*2; + PrintByteSegmentInObjectTo(obj_bytes, resume_pos, count - resume_pos, os); + } + *os << ">"; +} + +} // namespace + +namespace internal2 { + +// Delegates to PrintBytesInObjectToImpl() to print the bytes in the +// given object. The delegation simplifies the implementation, which +// uses the << operator and thus is easier done outside of the +// ::testing::internal namespace, which contains a << operator that +// sometimes conflicts with the one in STL. 
+void PrintBytesInObjectTo(const unsigned char* obj_bytes, size_t count, + ostream* os) { + PrintBytesInObjectToImpl(obj_bytes, count, os); +} + +} // namespace internal2 + +namespace internal { + +// Depending on the value of a char (or wchar_t), we print it in one +// of three formats: +// - as is if it's a printable ASCII (e.g. 'a', '2', ' '), +// - as a hexidecimal escape sequence (e.g. '\x7F'), or +// - as a special escape sequence (e.g. '\r', '\n'). +enum CharFormat { + kAsIs, + kHexEscape, + kSpecialEscape +}; + +// Returns true if c is a printable ASCII character. We test the +// value of c directly instead of calling isprint(), which is buggy on +// Windows Mobile. +inline bool IsPrintableAscii(wchar_t c) { + return 0x20 <= c && c <= 0x7E; +} + +// Prints a wide or narrow char c as a character literal without the +// quotes, escaping it when necessary; returns how c was formatted. +// The template argument UnsignedChar is the unsigned version of Char, +// which is the type of c. +template +static CharFormat PrintAsCharLiteralTo(Char c, ostream* os) { + switch (static_cast(c)) { + case L'\0': + *os << "\\0"; + break; + case L'\'': + *os << "\\'"; + break; + case L'\\': + *os << "\\\\"; + break; + case L'\a': + *os << "\\a"; + break; + case L'\b': + *os << "\\b"; + break; + case L'\f': + *os << "\\f"; + break; + case L'\n': + *os << "\\n"; + break; + case L'\r': + *os << "\\r"; + break; + case L'\t': + *os << "\\t"; + break; + case L'\v': + *os << "\\v"; + break; + default: + if (IsPrintableAscii(c)) { + *os << static_cast(c); + return kAsIs; + } else { + *os << String::Format("\\x%X", static_cast(c)); + return kHexEscape; + } + } + return kSpecialEscape; +} + +// Prints a char c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsWideStringLiteralTo(wchar_t c, ostream* os) { + switch (c) { + case L'\'': + *os << "'"; + return kAsIs; + case L'"': + *os << "\\\""; + return kSpecialEscape; + default: + return PrintAsCharLiteralTo(c, os); + } +} + +// Prints a char c as if it's part of a string literal, escaping it when +// necessary; returns how c was formatted. +static CharFormat PrintAsNarrowStringLiteralTo(char c, ostream* os) { + return PrintAsWideStringLiteralTo(static_cast(c), os); +} + +// Prints a wide or narrow character c and its code. '\0' is printed +// as "'\\0'", other unprintable characters are also properly escaped +// using the standard C++ escape sequence. The template argument +// UnsignedChar is the unsigned version of Char, which is the type of c. +template +void PrintCharAndCodeTo(Char c, ostream* os) { + // First, print c as a literal in the most readable form we can find. + *os << ((sizeof(c) > 1) ? "L'" : "'"); + const CharFormat format = PrintAsCharLiteralTo(c, os); + *os << "'"; + + // To aid user debugging, we also print c's code in decimal, unless + // it's 0 (in which case c was printed as '\\0', making the code + // obvious). + if (c == 0) + return; + *os << " (" << String::Format("%d", c).c_str(); + + // For more convenience, we print c's code again in hexidecimal, + // unless c was already printed in the form '\x##' or the code is in + // [1, 9]. + if (format == kHexEscape || (1 <= c && c <= 9)) { + // Do nothing. 
+ } else { + *os << String::Format(", 0x%X", + static_cast(c)).c_str(); + } + *os << ")"; +} + +void PrintTo(unsigned char c, ::std::ostream* os) { + PrintCharAndCodeTo(c, os); +} +void PrintTo(signed char c, ::std::ostream* os) { + PrintCharAndCodeTo(c, os); +} + +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its code. L'\0' is printed as "L'\\0'". +void PrintTo(wchar_t wc, ostream* os) { + PrintCharAndCodeTo(wc, os); +} + +// Prints the given array of characters to the ostream. +// The array starts at *begin, the length is len, it may include '\0' characters +// and may not be null-terminated. +static void PrintCharsAsStringTo(const char* begin, size_t len, ostream* os) { + *os << "\""; + bool is_previous_hex = false; + for (size_t index = 0; index < len; ++index) { + const char cur = begin[index]; + if (is_previous_hex && IsXDigit(cur)) { + // Previous character is of '\x..' form and this character can be + // interpreted as another hexadecimal digit in its number. Break string to + // disambiguate. + *os << "\" \""; + } + is_previous_hex = PrintAsNarrowStringLiteralTo(cur, os) == kHexEscape; + } + *os << "\""; +} + +// Prints a (const) char array of 'len' elements, starting at address 'begin'. +void UniversalPrintArray(const char* begin, size_t len, ostream* os) { + PrintCharsAsStringTo(begin, len, os); +} + +// Prints the given array of wide characters to the ostream. +// The array starts at *begin, the length is len, it may include L'\0' +// characters and may not be null-terminated. +static void PrintWideCharsAsStringTo(const wchar_t* begin, size_t len, + ostream* os) { + *os << "L\""; + bool is_previous_hex = false; + for (size_t index = 0; index < len; ++index) { + const wchar_t cur = begin[index]; + if (is_previous_hex && isascii(cur) && IsXDigit(static_cast(cur))) { + // Previous character is of '\x..' form and this character can be + // interpreted as another hexadecimal digit in its number. Break string to + // disambiguate. + *os << "\" L\""; + } + is_previous_hex = PrintAsWideStringLiteralTo(cur, os) == kHexEscape; + } + *os << "\""; +} + +// Prints the given C string to the ostream. +void PrintTo(const char* s, ostream* os) { + if (s == NULL) { + *os << "NULL"; + } else { + *os << ImplicitCast_(s) << " pointing to "; + PrintCharsAsStringTo(s, strlen(s), os); + } +} + +// MSVC compiler can be configured to define whar_t as a typedef +// of unsigned short. Defining an overload for const wchar_t* in that case +// would cause pointers to unsigned shorts be printed as wide strings, +// possibly accessing more memory than intended and causing invalid +// memory accesses. MSVC defines _NATIVE_WCHAR_T_DEFINED symbol when +// wchar_t is implemented as a native type. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Prints the given wide C string to the ostream. +void PrintTo(const wchar_t* s, ostream* os) { + if (s == NULL) { + *os << "NULL"; + } else { + *os << ImplicitCast_(s) << " pointing to "; + PrintWideCharsAsStringTo(s, wcslen(s), os); + } +} +#endif // wchar_t is native + +// Prints a ::string object. +#if GTEST_HAS_GLOBAL_STRING +void PrintStringTo(const ::string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_GLOBAL_STRING + +void PrintStringTo(const ::std::string& s, ostream* os) { + PrintCharsAsStringTo(s.data(), s.size(), os); +} + +// Prints a ::wstring object. 
+#if GTEST_HAS_GLOBAL_WSTRING +void PrintWideStringTo(const ::wstring& s, ostream* os) { + PrintWideCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +#if GTEST_HAS_STD_WSTRING +void PrintWideStringTo(const ::std::wstring& s, ostream* os) { + PrintWideCharsAsStringTo(s.data(), s.size(), os); +} +#endif // GTEST_HAS_STD_WSTRING + +} // namespace internal + +} // namespace testing +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: mheule@google.com (Markus Heule) +// +// The Google C++ Testing Framework (Google Test) + + +// Indicates that this translation unit is part of Google Test's +// implementation. It must come before gtest-internal-inl.h is +// included, or there will be a compiler error. This trick is to +// prevent a user from accidentally including gtest-internal-inl.h in +// his code. +#define GTEST_IMPLEMENTATION_ 1 +#undef GTEST_IMPLEMENTATION_ + +namespace testing { + +using internal::GetUnitTestImpl; + +// Gets the summary of the failure message by omitting the stack trace +// in it. +internal::String TestPartResult::ExtractSummary(const char* message) { + const char* const stack_trace = strstr(message, internal::kStackTraceMarker); + return stack_trace == NULL ? internal::String(message) : + internal::String(message, stack_trace - message); +} + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result) { + return os + << result.file_name() << ":" << result.line_number() << ": " + << (result.type() == TestPartResult::kSuccess ? "Success" : + result.type() == TestPartResult::kFatalFailure ? "Fatal failure" : + "Non-fatal failure") << ":\n" + << result.message() << std::endl; +} + +// Appends a TestPartResult to the array. +void TestPartResultArray::Append(const TestPartResult& result) { + array_.push_back(result); +} + +// Returns the TestPartResult at the given index (0-based). 
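The universal value printer implemented above falls back to a raw byte dump unless, as its file header notes, the user defines operator<< or PrintTo for their type in that type's own namespace. A minimal sketch of that extension point, using a hypothetical myapp::Point type invented for this example; it assumes only that the bundled gtest/gtest.h exposes TEST and ::testing::PrintToString as in stock Google Test.

#include <ostream>
#include "gtest/gtest.h"

namespace myapp {

struct Point {
  int x;
  int y;
};

// Found via argument-dependent lookup, so assertion failures involving
// Point print "Point(x, y)" instead of an opaque "8-byte object <...>".
void PrintTo(const Point& p, ::std::ostream* os) {
  *os << "Point(" << p.x << ", " << p.y << ")";
}

}  // namespace myapp

TEST(PrintToSketch, UsesCustomPrinter) {
  const myapp::Point p = {3, 4};
  EXPECT_EQ("Point(3, 4)", ::testing::PrintToString(p));
}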
+const TestPartResult& TestPartResultArray::GetTestPartResult(int index) const { + if (index < 0 || index >= size()) { + printf("\nInvalid index (%d) into TestPartResultArray.\n", index); + internal::posix::Abort(); + } + + return array_[index]; +} + +// Returns the number of TestPartResult objects in the array. +int TestPartResultArray::size() const { + return static_cast(array_.size()); +} + +namespace internal { + +HasNewFatalFailureHelper::HasNewFatalFailureHelper() + : has_new_fatal_failure_(false), + original_reporter_(GetUnitTestImpl()-> + GetTestPartResultReporterForCurrentThread()) { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread(this); +} + +HasNewFatalFailureHelper::~HasNewFatalFailureHelper() { + GetUnitTestImpl()->SetTestPartResultReporterForCurrentThread( + original_reporter_); +} + +void HasNewFatalFailureHelper::ReportTestPartResult( + const TestPartResult& result) { + if (result.fatally_failed()) + has_new_fatal_failure_ = true; + original_reporter_->ReportTestPartResult(result); +} + +} // namespace internal + +} // namespace testing +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + + +namespace testing { +namespace internal { + +#if GTEST_HAS_TYPED_TEST_P + +// Skips to the first non-space char in str. Returns an empty string if str +// contains only whitespace characters. +static const char* SkipSpaces(const char* str) { + while (IsSpace(*str)) + str++; + return str; +} + +// Verifies that registered_tests match the test names in +// defined_test_names_; returns registered_tests if successful, or +// aborts the program otherwise. +const char* TypedTestCasePState::VerifyRegisteredTestNames( + const char* file, int line, const char* registered_tests) { + typedef ::std::set::const_iterator DefinedTestIter; + registered_ = true; + + // Skip initial whitespace in registered_tests since some + // preprocessors prefix stringizied literals with whitespace. 
+ registered_tests = SkipSpaces(registered_tests); + + Message errors; + ::std::set tests; + for (const char* names = registered_tests; names != NULL; + names = SkipComma(names)) { + const String name = GetPrefixUntilComma(names); + if (tests.count(name) != 0) { + errors << "Test " << name << " is listed more than once.\n"; + continue; + } + + bool found = false; + for (DefinedTestIter it = defined_test_names_.begin(); + it != defined_test_names_.end(); + ++it) { + if (name == *it) { + found = true; + break; + } + } + + if (found) { + tests.insert(name); + } else { + errors << "No test named " << name + << " can be found in this test case.\n"; + } + } + + for (DefinedTestIter it = defined_test_names_.begin(); + it != defined_test_names_.end(); + ++it) { + if (tests.count(*it) == 0) { + errors << "You forgot to list test " << *it << ".\n"; + } + } + + const String& errors_str = errors.GetString(); + if (errors_str != "") { + fprintf(stderr, "%s %s", FormatFileLocation(file, line).c_str(), + errors_str.c_str()); + fflush(stderr); + posix::Abort(); + } + + return registered_tests; +} + +#endif // GTEST_HAS_TYPED_TEST_P + +} // namespace internal +} // namespace testing diff --git a/gtest/gtest.h b/gtest/gtest.h new file mode 100644 index 00000000000..3143bd67996 --- /dev/null +++ b/gtest/gtest.h @@ -0,0 +1,19537 @@ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the public API for Google Test. It should be +// included by any test program that uses Google Test. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. 
Therefore DO NOT DEPEND ON IT in a user +// program! +// +// Acknowledgment: Google Test borrowed the idea of automatic test +// registration from Barthelemy Dagenais' (barthelemy@prologique.com) +// easyUnit framework. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_H_ + +#include +#include + +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file declares functions and macros used internally by +// Google Test. They are subject to change without notice. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ + +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan) +// +// Low-level types and utilities for porting Google Test to various +// platforms. They are subject to change without notice. DO NOT USE +// THEM IN USER CODE. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ + +// The user can define the following macros in the build script to +// control Google Test's behavior. If the user doesn't define a macro +// in this list, Google Test will define it. +// +// GTEST_HAS_CLONE - Define it to 1/0 to indicate that clone(2) +// is/isn't available. +// GTEST_HAS_EXCEPTIONS - Define it to 1/0 to indicate that exceptions +// are enabled. +// GTEST_HAS_GLOBAL_STRING - Define it to 1/0 to indicate that ::string +// is/isn't available (some systems define +// ::string, which is different to std::string). +// GTEST_HAS_GLOBAL_WSTRING - Define it to 1/0 to indicate that ::string +// is/isn't available (some systems define +// ::wstring, which is different to std::wstring). +// GTEST_HAS_POSIX_RE - Define it to 1/0 to indicate that POSIX regular +// expressions are/aren't available. +// GTEST_HAS_PTHREAD - Define it to 1/0 to indicate that +// is/isn't available. +// GTEST_HAS_RTTI - Define it to 1/0 to indicate that RTTI is/isn't +// enabled. +// GTEST_HAS_STD_WSTRING - Define it to 1/0 to indicate that +// std::wstring does/doesn't work (Google Test can +// be used where std::wstring is unavailable). +// GTEST_HAS_TR1_TUPLE - Define it to 1/0 to indicate tr1::tuple +// is/isn't available. +// GTEST_HAS_SEH - Define it to 1/0 to indicate whether the +// compiler supports Microsoft's "Structured +// Exception Handling". +// GTEST_HAS_STREAM_REDIRECTION +// - Define it to 1/0 to indicate whether the +// platform supports I/O stream redirection using +// dup() and dup2(). +// GTEST_USE_OWN_TR1_TUPLE - Define it to 1/0 to indicate whether Google +// Test's own tr1 tuple implementation should be +// used. Unused when the user sets +// GTEST_HAS_TR1_TUPLE to 0. +// GTEST_LINKED_AS_SHARED_LIBRARY +// - Define to 1 when compiling tests that use +// Google Test as a shared library (known as +// DLL on Windows). +// GTEST_CREATE_SHARED_LIBRARY +// - Define to 1 when compiling Google Test itself +// as a shared library. 
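The macros listed above are meant to be set from the build script or compiler command line rather than edited in this header. A minimal sketch of overriding one of them from a hypothetical test file (the file name sample_test.cc is made up; passing -DGTEST_HAS_PTHREAD=0 to the compiler would have the same effect):

// sample_test.cc -- illustrative only.
// Opt out of pthread-based thread safety before the bundled header is
// included; gtest-port.h only supplies its own default when the user
// has not defined the macro already.
#define GTEST_HAS_PTHREAD 0
#include "gtest/gtest.h"

TEST(PortConfigSketch, OverrideIsVisible) {
  // The #ifndef guards in gtest-port.h honor the definition above.
  EXPECT_EQ(0, GTEST_HAS_PTHREAD);
}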
+ +// This header defines the following utilities: +// +// Macros indicating the current platform (defined to 1 if compiled on +// the given platform; otherwise undefined): +// GTEST_OS_AIX - IBM AIX +// GTEST_OS_CYGWIN - Cygwin +// GTEST_OS_HPUX - HP-UX +// GTEST_OS_LINUX - Linux +// GTEST_OS_LINUX_ANDROID - Google Android +// GTEST_OS_MAC - Mac OS X +// GTEST_OS_NACL - Google Native Client (NaCl) +// GTEST_OS_SOLARIS - Sun Solaris +// GTEST_OS_SYMBIAN - Symbian +// GTEST_OS_WINDOWS - Windows (Desktop, MinGW, or Mobile) +// GTEST_OS_WINDOWS_DESKTOP - Windows Desktop +// GTEST_OS_WINDOWS_MINGW - MinGW +// GTEST_OS_WINDOWS_MOBILE - Windows Mobile +// GTEST_OS_ZOS - z/OS +// +// Among the platforms, Cygwin, Linux, Max OS X, and Windows have the +// most stable support. Since core members of the Google Test project +// don't have access to other platforms, support for them may be less +// stable. If you notice any problems on your platform, please notify +// googletestframework@googlegroups.com (patches for fixing them are +// even more welcome!). +// +// Note that it is possible that none of the GTEST_OS_* macros are defined. +// +// Macros indicating available Google Test features (defined to 1 if +// the corresponding feature is supported; otherwise undefined): +// GTEST_HAS_COMBINE - the Combine() function (for value-parameterized +// tests) +// GTEST_HAS_DEATH_TEST - death tests +// GTEST_HAS_PARAM_TEST - value-parameterized tests +// GTEST_HAS_TYPED_TEST - typed tests +// GTEST_HAS_TYPED_TEST_P - type-parameterized tests +// GTEST_USES_POSIX_RE - enhanced POSIX regex is used. Do not confuse with +// GTEST_HAS_POSIX_RE (see above) which users can +// define themselves. +// GTEST_USES_SIMPLE_RE - our own simple regex is used; +// the above two are mutually exclusive. +// GTEST_CAN_COMPARE_NULL - accepts untyped NULL in EXPECT_EQ(). +// +// Macros for basic C++ coding: +// GTEST_AMBIGUOUS_ELSE_BLOCKER_ - for disabling a gcc warning. +// GTEST_ATTRIBUTE_UNUSED_ - declares that a class' instances or a +// variable don't have to be used. +// GTEST_DISALLOW_ASSIGN_ - disables operator=. +// GTEST_DISALLOW_COPY_AND_ASSIGN_ - disables copy ctor and operator=. +// GTEST_MUST_USE_RESULT_ - declares that a function's result must be used. +// +// Synchronization: +// Mutex, MutexLock, ThreadLocal, GetThreadCount() +// - synchronization primitives. +// GTEST_IS_THREADSAFE - defined to 1 to indicate that the above +// synchronization primitives have real implementations +// and Google Test is thread-safe; or 0 otherwise. +// +// Template meta programming: +// is_pointer - as in TR1; needed on Symbian and IBM XL C/C++ only. +// IteratorTraits - partial implementation of std::iterator_traits, which +// is not available in libCstd when compiled with Sun C++. +// +// Smart pointers: +// scoped_ptr - as in TR2. +// +// Regular expressions: +// RE - a simple regular expression class using the POSIX +// Extended Regular Expression syntax on UNIX-like +// platforms, or a reduced regular exception syntax on +// other platforms, including Windows. +// +// Logging: +// GTEST_LOG_() - logs messages at the specified severity level. +// LogToStderr() - directs all log messages to stderr. +// FlushInfoLog() - flushes informational log messages. +// +// Stdout and stderr capturing: +// CaptureStdout() - starts capturing stdout. +// GetCapturedStdout() - stops capturing stdout and returns the captured +// string. +// CaptureStderr() - starts capturing stderr. 
+// GetCapturedStderr() - stops capturing stderr and returns the captured +// string. +// +// Integer types: +// TypeWithSize - maps an integer to a int type. +// Int32, UInt32, Int64, UInt64, TimeInMillis +// - integers of known sizes. +// BiggestInt - the biggest signed integer type. +// +// Command-line utilities: +// GTEST_FLAG() - references a flag. +// GTEST_DECLARE_*() - declares a flag. +// GTEST_DEFINE_*() - defines a flag. +// GetArgvs() - returns the command line as a vector of strings. +// +// Environment variable utilities: +// GetEnv() - gets the value of an environment variable. +// BoolFromGTestEnv() - parses a bool environment variable. +// Int32FromGTestEnv() - parses an Int32 environment variable. +// StringFromGTestEnv() - parses a string environment variable. + +#include // for isspace, etc +#include // for ptrdiff_t +#include +#include +#include +#ifndef _WIN32_WCE +# include +# include +#endif // !_WIN32_WCE + +#include // NOLINT +#include // NOLINT +#include // NOLINT + +#define GTEST_DEV_EMAIL_ "googletestframework@@googlegroups.com" +#define GTEST_FLAG_PREFIX_ "gtest_" +#define GTEST_FLAG_PREFIX_DASH_ "gtest-" +#define GTEST_FLAG_PREFIX_UPPER_ "GTEST_" +#define GTEST_NAME_ "Google Test" +#define GTEST_PROJECT_URL_ "http://code.google.com/p/googletest/" + +// Determines the version of gcc that is used to compile this. +#ifdef __GNUC__ +// 40302 means version 4.3.2. +# define GTEST_GCC_VER_ \ + (__GNUC__*10000 + __GNUC_MINOR__*100 + __GNUC_PATCHLEVEL__) +#endif // __GNUC__ + +// Determines the platform on which Google Test is compiled. +#ifdef __CYGWIN__ +# define GTEST_OS_CYGWIN 1 +#elif defined __SYMBIAN32__ +# define GTEST_OS_SYMBIAN 1 +#elif defined _WIN32 +# define GTEST_OS_WINDOWS 1 +# ifdef _WIN32_WCE +# define GTEST_OS_WINDOWS_MOBILE 1 +# elif defined(__MINGW__) || defined(__MINGW32__) +# define GTEST_OS_WINDOWS_MINGW 1 +# else +# define GTEST_OS_WINDOWS_DESKTOP 1 +# endif // _WIN32_WCE +#elif defined __APPLE__ +# define GTEST_OS_MAC 1 +#elif defined __linux__ +# define GTEST_OS_LINUX 1 +# ifdef ANDROID +# define GTEST_OS_LINUX_ANDROID 1 +# endif // ANDROID +#elif defined __MVS__ +# define GTEST_OS_ZOS 1 +#elif defined(__sun) && defined(__SVR4) +# define GTEST_OS_SOLARIS 1 +#elif defined(_AIX) +# define GTEST_OS_AIX 1 +#elif defined(__hpux) +# define GTEST_OS_HPUX 1 +#elif defined __native_client__ +# define GTEST_OS_NACL 1 +#endif // __CYGWIN__ + +// Brings in definitions for functions used in the testing::internal::posix +// namespace (read, write, close, chdir, isatty, stat). We do not currently +// use them on Windows Mobile. +#if !GTEST_OS_WINDOWS +// This assumes that non-Windows OSes provide unistd.h. For OSes where this +// is not the case, we need to include headers that provide the functions +// mentioned above. +# include +# if !GTEST_OS_NACL +// TODO(vladl@google.com): Remove this condition when Native Client SDK adds +// strings.h (tracked in +// http://code.google.com/p/nativeclient/issues/detail?id=1175). +# include // Native Client doesn't provide strings.h. +# endif +#elif !GTEST_OS_WINDOWS_MOBILE +# include +# include +#endif + +// Defines this to true iff Google Test can use POSIX regular expressions. +#ifndef GTEST_HAS_POSIX_RE +# define GTEST_HAS_POSIX_RE (!GTEST_OS_WINDOWS) +#endif + +#if GTEST_HAS_POSIX_RE + +// On some platforms, needs someone to define size_t, and +// won't compile otherwise. We can #include it here as we already +// included , which is guaranteed to define size_t through +// . 
+# include // NOLINT + +# define GTEST_USES_POSIX_RE 1 + +#elif GTEST_OS_WINDOWS + +// is not available on Windows. Use our own simple regex +// implementation instead. +# define GTEST_USES_SIMPLE_RE 1 + +#else + +// may not be available on this platform. Use our own +// simple regex implementation instead. +# define GTEST_USES_SIMPLE_RE 1 + +#endif // GTEST_HAS_POSIX_RE + +#ifndef GTEST_HAS_EXCEPTIONS +// The user didn't tell us whether exceptions are enabled, so we need +// to figure it out. +# if defined(_MSC_VER) || defined(__BORLANDC__) +// MSVC's and C++Builder's implementations of the STL use the _HAS_EXCEPTIONS +// macro to enable exceptions, so we'll do the same. +// Assumes that exceptions are enabled by default. +# ifndef _HAS_EXCEPTIONS +# define _HAS_EXCEPTIONS 1 +# endif // _HAS_EXCEPTIONS +# define GTEST_HAS_EXCEPTIONS _HAS_EXCEPTIONS +# elif defined(__GNUC__) && __EXCEPTIONS +// gcc defines __EXCEPTIONS to 1 iff exceptions are enabled. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__SUNPRO_CC) +// Sun Pro CC supports exceptions. However, there is no compile-time way of +// detecting whether they are enabled or not. Therefore, we assume that +// they are enabled unless the user tells us otherwise. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__IBMCPP__) && __EXCEPTIONS +// xlC defines __EXCEPTIONS to 1 iff exceptions are enabled. +# define GTEST_HAS_EXCEPTIONS 1 +# elif defined(__HP_aCC) +// Exception handling is in effect by default in HP aCC compiler. It has to +// be turned of by +noeh compiler option if desired. +# define GTEST_HAS_EXCEPTIONS 1 +# else +// For other compilers, we assume exceptions are disabled to be +// conservative. +# define GTEST_HAS_EXCEPTIONS 0 +# endif // defined(_MSC_VER) || defined(__BORLANDC__) +#endif // GTEST_HAS_EXCEPTIONS + +#if !defined(GTEST_HAS_STD_STRING) +// Even though we don't use this macro any longer, we keep it in case +// some clients still depend on it. +# define GTEST_HAS_STD_STRING 1 +#elif !GTEST_HAS_STD_STRING +// The user told us that ::std::string isn't available. +# error "Google Test cannot be used where ::std::string isn't available." +#endif // !defined(GTEST_HAS_STD_STRING) + +#ifndef GTEST_HAS_GLOBAL_STRING +// The user didn't tell us whether ::string is available, so we need +// to figure it out. + +# define GTEST_HAS_GLOBAL_STRING 0 + +#endif // GTEST_HAS_GLOBAL_STRING + +#ifndef GTEST_HAS_STD_WSTRING +// The user didn't tell us whether ::std::wstring is available, so we need +// to figure it out. +// TODO(wan@google.com): uses autoconf to detect whether ::std::wstring +// is available. + +// Cygwin 1.7 and below doesn't support ::std::wstring. +// Solaris' libc++ doesn't support it either. Android has +// no support for it at least as recent as Froyo (2.2). +# define GTEST_HAS_STD_WSTRING \ + (!(GTEST_OS_LINUX_ANDROID || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS)) + +#endif // GTEST_HAS_STD_WSTRING + +#ifndef GTEST_HAS_GLOBAL_WSTRING +// The user didn't tell us whether ::wstring is available, so we need +// to figure it out. +# define GTEST_HAS_GLOBAL_WSTRING \ + (GTEST_HAS_STD_WSTRING && GTEST_HAS_GLOBAL_STRING) +#endif // GTEST_HAS_GLOBAL_WSTRING + +// Determines whether RTTI is available. +#ifndef GTEST_HAS_RTTI +// The user didn't tell us whether RTTI is enabled, so we need to +// figure it out. + +# ifdef _MSC_VER + +# ifdef _CPPRTTI // MSVC defines this macro iff RTTI is enabled. 
+# define GTEST_HAS_RTTI 1 +# else +# define GTEST_HAS_RTTI 0 +# endif + +// Starting with version 4.3.2, gcc defines __GXX_RTTI iff RTTI is enabled. +# elif defined(__GNUC__) && (GTEST_GCC_VER_ >= 40302) + +# ifdef __GXX_RTTI +# define GTEST_HAS_RTTI 1 +# else +# define GTEST_HAS_RTTI 0 +# endif // __GXX_RTTI + +// Starting with version 9.0 IBM Visual Age defines __RTTI_ALL__ to 1 if +// both the typeid and dynamic_cast features are present. +# elif defined(__IBMCPP__) && (__IBMCPP__ >= 900) + +# ifdef __RTTI_ALL__ +# define GTEST_HAS_RTTI 1 +# else +# define GTEST_HAS_RTTI 0 +# endif + +# else + +// For all other compilers, we assume RTTI is enabled. +# define GTEST_HAS_RTTI 1 + +# endif // _MSC_VER + +#endif // GTEST_HAS_RTTI + +// It's this header's responsibility to #include when RTTI +// is enabled. +#if GTEST_HAS_RTTI +# include +#endif + +// Determines whether Google Test can use the pthreads library. +#ifndef GTEST_HAS_PTHREAD +// The user didn't tell us explicitly, so we assume pthreads support is +// available on Linux and Mac. +// +// To disable threading support in Google Test, add -DGTEST_HAS_PTHREAD=0 +// to your compiler flags. +# define GTEST_HAS_PTHREAD (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_HPUX) +#endif // GTEST_HAS_PTHREAD + +#if GTEST_HAS_PTHREAD +// gtest-port.h guarantees to #include when GTEST_HAS_PTHREAD is +// true. +# include // NOLINT + +// For timespec and nanosleep, used below. +# include // NOLINT +#endif + +// Determines whether Google Test can use tr1/tuple. You can define +// this macro to 0 to prevent Google Test from using tuple (any +// feature depending on tuple with be disabled in this mode). +#ifndef GTEST_HAS_TR1_TUPLE +// The user didn't tell us not to do it, so we assume it's OK. +# define GTEST_HAS_TR1_TUPLE 1 +#endif // GTEST_HAS_TR1_TUPLE + +// Determines whether Google Test's own tr1 tuple implementation +// should be used. +#ifndef GTEST_USE_OWN_TR1_TUPLE +// The user didn't tell us, so we need to figure it out. + +// We use our own TR1 tuple if we aren't sure the user has an +// implementation of it already. At this time, GCC 4.0.0+ and MSVC +// 2010 are the only mainstream compilers that come with a TR1 tuple +// implementation. NVIDIA's CUDA NVCC compiler pretends to be GCC by +// defining __GNUC__ and friends, but cannot compile GCC's tuple +// implementation. MSVC 2008 (9.0) provides TR1 tuple in a 323 MB +// Feature Pack download, which we cannot assume the user has. +# if (defined(__GNUC__) && !defined(__CUDACC__) && (GTEST_GCC_VER_ >= 40000)) \ + || _MSC_VER >= 1600 +# define GTEST_USE_OWN_TR1_TUPLE 0 +# else +# define GTEST_USE_OWN_TR1_TUPLE 1 +# endif + +#endif // GTEST_USE_OWN_TR1_TUPLE + +// To avoid conditional compilation everywhere, we make it +// gtest-port.h's responsibility to #include the header implementing +// tr1/tuple. +#if GTEST_HAS_TR1_TUPLE + +# if GTEST_USE_OWN_TR1_TUPLE +// This file was GENERATED by a script. DO NOT EDIT BY HAND!!! + +// Copyright 2009 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Implements a subset of TR1 tuple needed by Google Test and Google Mock. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_ + +#include // For ::std::pair. + +// The compiler used in Symbian has a bug that prevents us from declaring the +// tuple template as a friend (it complains that tuple is redefined). This +// hack bypasses the bug by declaring the members that should otherwise be +// private as public. +// Sun Studio versions < 12 also have the above bug. +#if defined(__SYMBIAN32__) || (defined(__SUNPRO_CC) && __SUNPRO_CC < 0x590) +# define GTEST_DECLARE_TUPLE_AS_FRIEND_ public: +#else +# define GTEST_DECLARE_TUPLE_AS_FRIEND_ \ + template friend class tuple; \ + private: +#endif + +// GTEST_n_TUPLE_(T) is the type of an n-tuple. +#define GTEST_0_TUPLE_(T) tuple<> +#define GTEST_1_TUPLE_(T) tuple +#define GTEST_2_TUPLE_(T) tuple +#define GTEST_3_TUPLE_(T) tuple +#define GTEST_4_TUPLE_(T) tuple +#define GTEST_5_TUPLE_(T) tuple +#define GTEST_6_TUPLE_(T) tuple +#define GTEST_7_TUPLE_(T) tuple +#define GTEST_8_TUPLE_(T) tuple +#define GTEST_9_TUPLE_(T) tuple +#define GTEST_10_TUPLE_(T) tuple + +// GTEST_n_TYPENAMES_(T) declares a list of n typenames. +#define GTEST_0_TYPENAMES_(T) +#define GTEST_1_TYPENAMES_(T) typename T##0 +#define GTEST_2_TYPENAMES_(T) typename T##0, typename T##1 +#define GTEST_3_TYPENAMES_(T) typename T##0, typename T##1, typename T##2 +#define GTEST_4_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3 +#define GTEST_5_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4 +#define GTEST_6_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4, typename T##5 +#define GTEST_7_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4, typename T##5, typename T##6 +#define GTEST_8_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4, typename T##5, typename T##6, typename T##7 +#define GTEST_9_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4, typename T##5, typename T##6, \ + typename T##7, typename T##8 +#define GTEST_10_TYPENAMES_(T) typename T##0, typename T##1, typename T##2, \ + typename T##3, typename T##4, typename T##5, typename T##6, \ + typename T##7, typename T##8, typename T##9 + +// In theory, defining stuff in the ::std namespace is undefined +// behavior. 
We can do this as we are playing the role of a standard +// library vendor. +namespace std { +namespace tr1 { + +template +class tuple; + +// Anything in namespace gtest_internal is Google Test's INTERNAL +// IMPLEMENTATION DETAIL and MUST NOT BE USED DIRECTLY in user code. +namespace gtest_internal { + +// ByRef::type is T if T is a reference; otherwise it's const T&. +template +struct ByRef { typedef const T& type; }; // NOLINT +template +struct ByRef { typedef T& type; }; // NOLINT + +// A handy wrapper for ByRef. +#define GTEST_BY_REF_(T) typename ::std::tr1::gtest_internal::ByRef::type + +// AddRef::type is T if T is a reference; otherwise it's T&. This +// is the same as tr1::add_reference::type. +template +struct AddRef { typedef T& type; }; // NOLINT +template +struct AddRef { typedef T& type; }; // NOLINT + +// A handy wrapper for AddRef. +#define GTEST_ADD_REF_(T) typename ::std::tr1::gtest_internal::AddRef::type + +// A helper for implementing get(). +template class Get; + +// A helper for implementing tuple_element. kIndexValid is true +// iff k < the number of fields in tuple type T. +template +struct TupleElement; + +template +struct TupleElement { typedef T0 type; }; + +template +struct TupleElement { typedef T1 type; }; + +template +struct TupleElement { typedef T2 type; }; + +template +struct TupleElement { typedef T3 type; }; + +template +struct TupleElement { typedef T4 type; }; + +template +struct TupleElement { typedef T5 type; }; + +template +struct TupleElement { typedef T6 type; }; + +template +struct TupleElement { typedef T7 type; }; + +template +struct TupleElement { typedef T8 type; }; + +template +struct TupleElement { typedef T9 type; }; + +} // namespace gtest_internal + +template <> +class tuple<> { + public: + tuple() {} + tuple(const tuple& /* t */) {} + tuple& operator=(const tuple& /* t */) { return *this; } +}; + +template +class GTEST_1_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0) : f0_(f0) {} + + tuple(const tuple& t) : f0_(t.f0_) {} + + template + tuple(const GTEST_1_TUPLE_(U)& t) : f0_(t.f0_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_1_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_1_TUPLE_(U)& t) { + f0_ = t.f0_; + return *this; + } + + T0 f0_; +}; + +template +class GTEST_2_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1) : f0_(f0), + f1_(f1) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_) {} + + template + tuple(const GTEST_2_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_) {} + template + tuple(const ::std::pair& p) : f0_(p.first), f1_(p.second) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_2_TUPLE_(U)& t) { + return CopyFrom(t); + } + template + tuple& operator=(const ::std::pair& p) { + f0_ = p.first; + f1_ = p.second; + return *this; + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_2_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + return *this; + } + + T0 f0_; + T1 f1_; +}; + +template +class GTEST_3_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2) : f0_(f0), f1_(f1), f2_(f2) {} + + 
tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_) {} + + template + tuple(const GTEST_3_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_3_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_3_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; +}; + +template +class GTEST_4_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3) : f0_(f0), f1_(f1), f2_(f2), + f3_(f3) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_) {} + + template + tuple(const GTEST_4_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_4_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_4_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; +}; + +template +class GTEST_5_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, + GTEST_BY_REF_(T4) f4) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_) {} + + template + tuple(const GTEST_5_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_5_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_5_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; +}; + +template +class GTEST_6_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4, + GTEST_BY_REF_(T5) f5) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4), + f5_(f5) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_), f5_(t.f5_) {} + + template + tuple(const GTEST_6_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_), f5_(t.f5_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_6_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_6_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + f5_ = t.f5_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + T5 f5_; +}; + +template +class GTEST_7_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, 
GTEST_BY_REF_(T4) f4, + GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6) : f0_(f0), f1_(f1), f2_(f2), + f3_(f3), f4_(f4), f5_(f5), f6_(f6) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_), f5_(t.f5_), f6_(t.f6_) {} + + template + tuple(const GTEST_7_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_7_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_7_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + f5_ = t.f5_; + f6_ = t.f6_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + T5 f5_; + T6 f6_; +}; + +template +class GTEST_8_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4, + GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6, + GTEST_BY_REF_(T7) f7) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4), + f5_(f5), f6_(f6), f7_(f7) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_) {} + + template + tuple(const GTEST_8_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_8_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_8_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + f5_ = t.f5_; + f6_ = t.f6_; + f7_ = t.f7_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + T5 f5_; + T6 f6_; + T7 f7_; +}; + +template +class GTEST_9_TUPLE_(T) { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_(), f8_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4, + GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6, GTEST_BY_REF_(T7) f7, + GTEST_BY_REF_(T8) f8) : f0_(f0), f1_(f1), f2_(f2), f3_(f3), f4_(f4), + f5_(f5), f6_(f6), f7_(f7), f8_(f8) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_) {} + + template + tuple(const GTEST_9_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_9_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_9_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + f5_ = t.f5_; + f6_ = t.f6_; + f7_ = t.f7_; + f8_ = t.f8_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + T5 f5_; + T6 f6_; + T7 f7_; + T8 f8_; +}; + +template +class tuple { + public: + template friend class gtest_internal::Get; + + tuple() : f0_(), f1_(), f2_(), f3_(), f4_(), f5_(), f6_(), f7_(), f8_(), + f9_() {} + + explicit tuple(GTEST_BY_REF_(T0) f0, GTEST_BY_REF_(T1) f1, + 
GTEST_BY_REF_(T2) f2, GTEST_BY_REF_(T3) f3, GTEST_BY_REF_(T4) f4, + GTEST_BY_REF_(T5) f5, GTEST_BY_REF_(T6) f6, GTEST_BY_REF_(T7) f7, + GTEST_BY_REF_(T8) f8, GTEST_BY_REF_(T9) f9) : f0_(f0), f1_(f1), f2_(f2), + f3_(f3), f4_(f4), f5_(f5), f6_(f6), f7_(f7), f8_(f8), f9_(f9) {} + + tuple(const tuple& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), f3_(t.f3_), + f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_), f9_(t.f9_) {} + + template + tuple(const GTEST_10_TUPLE_(U)& t) : f0_(t.f0_), f1_(t.f1_), f2_(t.f2_), + f3_(t.f3_), f4_(t.f4_), f5_(t.f5_), f6_(t.f6_), f7_(t.f7_), f8_(t.f8_), + f9_(t.f9_) {} + + tuple& operator=(const tuple& t) { return CopyFrom(t); } + + template + tuple& operator=(const GTEST_10_TUPLE_(U)& t) { + return CopyFrom(t); + } + + GTEST_DECLARE_TUPLE_AS_FRIEND_ + + template + tuple& CopyFrom(const GTEST_10_TUPLE_(U)& t) { + f0_ = t.f0_; + f1_ = t.f1_; + f2_ = t.f2_; + f3_ = t.f3_; + f4_ = t.f4_; + f5_ = t.f5_; + f6_ = t.f6_; + f7_ = t.f7_; + f8_ = t.f8_; + f9_ = t.f9_; + return *this; + } + + T0 f0_; + T1 f1_; + T2 f2_; + T3 f3_; + T4 f4_; + T5 f5_; + T6 f6_; + T7 f7_; + T8 f8_; + T9 f9_; +}; + +// 6.1.3.2 Tuple creation functions. + +// Known limitations: we don't support passing an +// std::tr1::reference_wrapper to make_tuple(). And we don't +// implement tie(). + +inline tuple<> make_tuple() { return tuple<>(); } + +template +inline GTEST_1_TUPLE_(T) make_tuple(const T0& f0) { + return GTEST_1_TUPLE_(T)(f0); +} + +template +inline GTEST_2_TUPLE_(T) make_tuple(const T0& f0, const T1& f1) { + return GTEST_2_TUPLE_(T)(f0, f1); +} + +template +inline GTEST_3_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2) { + return GTEST_3_TUPLE_(T)(f0, f1, f2); +} + +template +inline GTEST_4_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3) { + return GTEST_4_TUPLE_(T)(f0, f1, f2, f3); +} + +template +inline GTEST_5_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4) { + return GTEST_5_TUPLE_(T)(f0, f1, f2, f3, f4); +} + +template +inline GTEST_6_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4, const T5& f5) { + return GTEST_6_TUPLE_(T)(f0, f1, f2, f3, f4, f5); +} + +template +inline GTEST_7_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4, const T5& f5, const T6& f6) { + return GTEST_7_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6); +} + +template +inline GTEST_8_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7) { + return GTEST_8_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7); +} + +template +inline GTEST_9_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7, + const T8& f8) { + return GTEST_9_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7, f8); +} + +template +inline GTEST_10_TUPLE_(T) make_tuple(const T0& f0, const T1& f1, const T2& f2, + const T3& f3, const T4& f4, const T5& f5, const T6& f6, const T7& f7, + const T8& f8, const T9& f9) { + return GTEST_10_TUPLE_(T)(f0, f1, f2, f3, f4, f5, f6, f7, f8, f9); +} + +// 6.1.3.3 Tuple helper classes. 
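For orientation, this is how the creation functions just defined are used. The sketch below is written against the standard C++11 <tuple> rather than this bundled tr1 subset, but the calling pattern is the same: make_tuple deduces the element types, get<k>() retrieves by index, and == compares element-wise.

    #include <cstdio>
    #include <string>
    #include <tuple>   // standard tuple; the bundled tr1 subset mirrors this API

    int main() {
      std::tuple<int, double, std::string> t =
          std::make_tuple(1, 2.5, std::string("hi"));
      std::printf("%d %.1f %s\n",
                  std::get<0>(t), std::get<1>(t), std::get<2>(t).c_str());

      // Element-wise comparison, analogous to the prefix comparator defined
      // for this tr1 subset further below.
      std::printf("equal? %d\n", t == std::make_tuple(1, 2.5, std::string("hi")));
      return 0;
    }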
+ +template struct tuple_size; + +template +struct tuple_size { static const int value = 0; }; + +template +struct tuple_size { static const int value = 1; }; + +template +struct tuple_size { static const int value = 2; }; + +template +struct tuple_size { static const int value = 3; }; + +template +struct tuple_size { static const int value = 4; }; + +template +struct tuple_size { static const int value = 5; }; + +template +struct tuple_size { static const int value = 6; }; + +template +struct tuple_size { static const int value = 7; }; + +template +struct tuple_size { static const int value = 8; }; + +template +struct tuple_size { static const int value = 9; }; + +template +struct tuple_size { static const int value = 10; }; + +template +struct tuple_element { + typedef typename gtest_internal::TupleElement< + k < (tuple_size::value), k, Tuple>::type type; +}; + +#define GTEST_TUPLE_ELEMENT_(k, Tuple) typename tuple_element::type + +// 6.1.3.4 Element access. + +namespace gtest_internal { + +template <> +class Get<0> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(0, Tuple)) + Field(Tuple& t) { return t.f0_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(0, Tuple)) + ConstField(const Tuple& t) { return t.f0_; } +}; + +template <> +class Get<1> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(1, Tuple)) + Field(Tuple& t) { return t.f1_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(1, Tuple)) + ConstField(const Tuple& t) { return t.f1_; } +}; + +template <> +class Get<2> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(2, Tuple)) + Field(Tuple& t) { return t.f2_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(2, Tuple)) + ConstField(const Tuple& t) { return t.f2_; } +}; + +template <> +class Get<3> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(3, Tuple)) + Field(Tuple& t) { return t.f3_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(3, Tuple)) + ConstField(const Tuple& t) { return t.f3_; } +}; + +template <> +class Get<4> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(4, Tuple)) + Field(Tuple& t) { return t.f4_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(4, Tuple)) + ConstField(const Tuple& t) { return t.f4_; } +}; + +template <> +class Get<5> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(5, Tuple)) + Field(Tuple& t) { return t.f5_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(5, Tuple)) + ConstField(const Tuple& t) { return t.f5_; } +}; + +template <> +class Get<6> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(6, Tuple)) + Field(Tuple& t) { return t.f6_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(6, Tuple)) + ConstField(const Tuple& t) { return t.f6_; } +}; + +template <> +class Get<7> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(7, Tuple)) + Field(Tuple& t) { return t.f7_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(7, Tuple)) + ConstField(const Tuple& t) { return t.f7_; } +}; + +template <> +class Get<8> { + public: + template + static GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(8, Tuple)) + Field(Tuple& t) { return t.f8_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(8, Tuple)) + ConstField(const Tuple& t) { return t.f8_; } +}; + +template <> +class Get<9> { + public: + template + static 
GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(9, Tuple)) + Field(Tuple& t) { return t.f9_; } // NOLINT + + template + static GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(9, Tuple)) + ConstField(const Tuple& t) { return t.f9_; } +}; + +} // namespace gtest_internal + +template +GTEST_ADD_REF_(GTEST_TUPLE_ELEMENT_(k, GTEST_10_TUPLE_(T))) +get(GTEST_10_TUPLE_(T)& t) { + return gtest_internal::Get::Field(t); +} + +template +GTEST_BY_REF_(GTEST_TUPLE_ELEMENT_(k, GTEST_10_TUPLE_(T))) +get(const GTEST_10_TUPLE_(T)& t) { + return gtest_internal::Get::ConstField(t); +} + +// 6.1.3.5 Relational operators + +// We only implement == and !=, as we don't have a need for the rest yet. + +namespace gtest_internal { + +// SameSizeTuplePrefixComparator::Eq(t1, t2) returns true if the +// first k fields of t1 equals the first k fields of t2. +// SameSizeTuplePrefixComparator(k1, k2) would be a compiler error if +// k1 != k2. +template +struct SameSizeTuplePrefixComparator; + +template <> +struct SameSizeTuplePrefixComparator<0, 0> { + template + static bool Eq(const Tuple1& /* t1 */, const Tuple2& /* t2 */) { + return true; + } +}; + +template +struct SameSizeTuplePrefixComparator { + template + static bool Eq(const Tuple1& t1, const Tuple2& t2) { + return SameSizeTuplePrefixComparator::Eq(t1, t2) && + ::std::tr1::get(t1) == ::std::tr1::get(t2); + } +}; + +} // namespace gtest_internal + +template +inline bool operator==(const GTEST_10_TUPLE_(T)& t, + const GTEST_10_TUPLE_(U)& u) { + return gtest_internal::SameSizeTuplePrefixComparator< + tuple_size::value, + tuple_size::value>::Eq(t, u); +} + +template +inline bool operator!=(const GTEST_10_TUPLE_(T)& t, + const GTEST_10_TUPLE_(U)& u) { return !(t == u); } + +// 6.1.4 Pairs. +// Unimplemented. + +} // namespace tr1 +} // namespace std + +#undef GTEST_0_TUPLE_ +#undef GTEST_1_TUPLE_ +#undef GTEST_2_TUPLE_ +#undef GTEST_3_TUPLE_ +#undef GTEST_4_TUPLE_ +#undef GTEST_5_TUPLE_ +#undef GTEST_6_TUPLE_ +#undef GTEST_7_TUPLE_ +#undef GTEST_8_TUPLE_ +#undef GTEST_9_TUPLE_ +#undef GTEST_10_TUPLE_ + +#undef GTEST_0_TYPENAMES_ +#undef GTEST_1_TYPENAMES_ +#undef GTEST_2_TYPENAMES_ +#undef GTEST_3_TYPENAMES_ +#undef GTEST_4_TYPENAMES_ +#undef GTEST_5_TYPENAMES_ +#undef GTEST_6_TYPENAMES_ +#undef GTEST_7_TYPENAMES_ +#undef GTEST_8_TYPENAMES_ +#undef GTEST_9_TYPENAMES_ +#undef GTEST_10_TYPENAMES_ + +#undef GTEST_DECLARE_TUPLE_AS_FRIEND_ +#undef GTEST_BY_REF_ +#undef GTEST_ADD_REF_ +#undef GTEST_TUPLE_ELEMENT_ + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TUPLE_H_ +# elif GTEST_OS_SYMBIAN + +// On Symbian, BOOST_HAS_TR1_TUPLE causes Boost's TR1 tuple library to +// use STLport's tuple implementation, which unfortunately doesn't +// work as the copy of STLport distributed with Symbian is incomplete. +// By making sure BOOST_HAS_TR1_TUPLE is undefined, we force Boost to +// use its own tuple implementation. +# ifdef BOOST_HAS_TR1_TUPLE +# undef BOOST_HAS_TR1_TUPLE +# endif // BOOST_HAS_TR1_TUPLE + +// This prevents , which defines +// BOOST_HAS_TR1_TUPLE, from being #included by Boost's . +# define BOOST_TR1_DETAIL_CONFIG_HPP_INCLUDED +# include + +# elif defined(__GNUC__) && (GTEST_GCC_VER_ >= 40000) +// GCC 4.0+ implements tr1/tuple in the header. This does +// not conform to the TR1 spec, which requires the header to be . + +# if !GTEST_HAS_RTTI && GTEST_GCC_VER_ < 40302 +// Until version 4.3.2, gcc has a bug that causes , +// which is #included by , to not compile when RTTI is +// disabled. _TR1_FUNCTIONAL is the header guard for +// . 
Hence the following #define is a hack to prevent +// from being included. +# define _TR1_FUNCTIONAL 1 +# include +# undef _TR1_FUNCTIONAL // Allows the user to #include + // if he chooses to. +# else +# include // NOLINT +# endif // !GTEST_HAS_RTTI && GTEST_GCC_VER_ < 40302 + +# else +// If the compiler is not GCC 4.0+, we assume the user is using a +// spec-conforming TR1 implementation. +# include // NOLINT +# endif // GTEST_USE_OWN_TR1_TUPLE + +#endif // GTEST_HAS_TR1_TUPLE + +// Determines whether clone(2) is supported. +// Usually it will only be available on Linux, excluding +// Linux on the Itanium architecture. +// Also see http://linux.die.net/man/2/clone. +#ifndef GTEST_HAS_CLONE +// The user didn't tell us, so we need to figure it out. + +# if GTEST_OS_LINUX && !defined(__ia64__) +# define GTEST_HAS_CLONE 1 +# else +# define GTEST_HAS_CLONE 0 +# endif // GTEST_OS_LINUX && !defined(__ia64__) + +#endif // GTEST_HAS_CLONE + +// Determines whether to support stream redirection. This is used to test +// output correctness and to implement death tests. +#ifndef GTEST_HAS_STREAM_REDIRECTION +// By default, we assume that stream redirection is supported on all +// platforms except known mobile ones. +# if GTEST_OS_WINDOWS_MOBILE || GTEST_OS_SYMBIAN +# define GTEST_HAS_STREAM_REDIRECTION 0 +# else +# define GTEST_HAS_STREAM_REDIRECTION 1 +# endif // !GTEST_OS_WINDOWS_MOBILE && !GTEST_OS_SYMBIAN +#endif // GTEST_HAS_STREAM_REDIRECTION + +// Determines whether to support death tests. +// Google Test does not support death tests for VC 7.1 and earlier as +// abort() in a VC 7.1 application compiled as GUI in debug config +// pops up a dialog window that cannot be suppressed programmatically. +#if (GTEST_OS_LINUX || GTEST_OS_MAC || GTEST_OS_CYGWIN || GTEST_OS_SOLARIS || \ + (GTEST_OS_WINDOWS_DESKTOP && _MSC_VER >= 1400) || \ + GTEST_OS_WINDOWS_MINGW || GTEST_OS_AIX || GTEST_OS_HPUX) +# define GTEST_HAS_DEATH_TEST 1 +# include // NOLINT +#endif + +// We don't support MSVC 7.1 with exceptions disabled now. Therefore +// all the compilers we care about are adequate for supporting +// value-parameterized tests. +#define GTEST_HAS_PARAM_TEST 1 + +// Determines whether to support type-driven tests. + +// Typed tests need and variadic macros, which GCC, VC++ 8.0, +// Sun Pro CC, IBM Visual Age, and HP aCC support. +#if defined(__GNUC__) || (_MSC_VER >= 1400) || defined(__SUNPRO_CC) || \ + defined(__IBMCPP__) || defined(__HP_aCC) +# define GTEST_HAS_TYPED_TEST 1 +# define GTEST_HAS_TYPED_TEST_P 1 +#endif + +// Determines whether to support Combine(). This only makes sense when +// value-parameterized tests are enabled. The implementation doesn't +// work on Sun Studio since it doesn't understand templated conversion +// operators. +#if GTEST_HAS_PARAM_TEST && GTEST_HAS_TR1_TUPLE && !defined(__SUNPRO_CC) +# define GTEST_HAS_COMBINE 1 +#endif + +// Determines whether the system compiler uses UTF-16 for encoding wide strings. +#define GTEST_WIDE_STRING_USES_UTF16_ \ + (GTEST_OS_WINDOWS || GTEST_OS_CYGWIN || GTEST_OS_SYMBIAN || GTEST_OS_AIX) + +// Determines whether test results can be streamed to a socket. +#if GTEST_OS_LINUX +# define GTEST_CAN_STREAM_RESULTS_ 1 +#endif + +// Defines some utility macros. + +// The GNU compiler emits a warning if nested "if" statements are followed by +// an "else" statement and braces are not used to explicitly disambiguate the +// "else" binding. 
This leads to problems with code like: +// +// if (gate) +// ASSERT_*(condition) << "Some message"; +// +// The "switch (0) case 0:" idiom is used to suppress this. +#ifdef __INTEL_COMPILER +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ +#else +# define GTEST_AMBIGUOUS_ELSE_BLOCKER_ switch (0) case 0: default: // NOLINT +#endif + +// Use this annotation at the end of a struct/class definition to +// prevent the compiler from optimizing away instances that are never +// used. This is useful when all interesting logic happens inside the +// c'tor and / or d'tor. Example: +// +// struct Foo { +// Foo() { ... } +// } GTEST_ATTRIBUTE_UNUSED_; +// +// Also use it after a variable or parameter declaration to tell the +// compiler the variable/parameter does not have to be used. +#if defined(__GNUC__) && !defined(COMPILER_ICC) +# define GTEST_ATTRIBUTE_UNUSED_ __attribute__ ((unused)) +#else +# define GTEST_ATTRIBUTE_UNUSED_ +#endif + +// A macro to disallow operator= +// This should be used in the private: declarations for a class. +#define GTEST_DISALLOW_ASSIGN_(type)\ + void operator=(type const &) + +// A macro to disallow copy constructor and operator= +// This should be used in the private: declarations for a class. +#define GTEST_DISALLOW_COPY_AND_ASSIGN_(type)\ + type(type const &);\ + GTEST_DISALLOW_ASSIGN_(type) + +// Tell the compiler to warn about unused return values for functions declared +// with this macro. The macro should be used on function declarations +// following the argument list: +// +// Sprocket* AllocateSprocket() GTEST_MUST_USE_RESULT_; +#if defined(__GNUC__) && (GTEST_GCC_VER_ >= 30400) && !defined(COMPILER_ICC) +# define GTEST_MUST_USE_RESULT_ __attribute__ ((warn_unused_result)) +#else +# define GTEST_MUST_USE_RESULT_ +#endif // __GNUC__ && (GTEST_GCC_VER_ >= 30400) && !COMPILER_ICC + +// Determine whether the compiler supports Microsoft's Structured Exception +// Handling. This is supported by several Windows compilers but generally +// does not exist on any other system. +#ifndef GTEST_HAS_SEH +// The user didn't tell us, so we need to figure it out. + +# if defined(_MSC_VER) || defined(__BORLANDC__) +// These two compilers are known to support SEH. +# define GTEST_HAS_SEH 1 +# else +// Assume no SEH. +# define GTEST_HAS_SEH 0 +# endif + +#endif // GTEST_HAS_SEH + +#ifdef _MSC_VER + +# if GTEST_LINKED_AS_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllimport) +# elif GTEST_CREATE_SHARED_LIBRARY +# define GTEST_API_ __declspec(dllexport) +# endif + +#endif // _MSC_VER + +#ifndef GTEST_API_ +# define GTEST_API_ +#endif + +#ifdef __GNUC__ +// Ask the compiler to never inline a given function. +# define GTEST_NO_INLINE_ __attribute__((noinline)) +#else +# define GTEST_NO_INLINE_ +#endif + +namespace testing { + +class Message; + +namespace internal { + +class String; + +// The GTEST_COMPILE_ASSERT_ macro can be used to verify that a compile time +// expression is true. For example, you could use it to verify the +// size of a static array: +// +// GTEST_COMPILE_ASSERT_(ARRAYSIZE(content_type_names) == CONTENT_NUM_TYPES, +// content_type_names_incorrect_size); +// +// or to make sure a struct is smaller than a certain size: +// +// GTEST_COMPILE_ASSERT_(sizeof(foo) < 128, foo_too_large); +// +// The second argument to the macro is the name of the variable. If +// the expression is false, most compilers will issue a warning/error +// containing the name of the variable. 
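The negative-array-size trick behind GTEST_COMPILE_ASSERT_ can be reproduced in a couple of lines. MY_STATIC_ASSERT below is an illustrative stand-in, not the gtest macro; the comment block that follows explains why the real macro additionally routes expr through a CompileAssert<> template, and on C++11 compilers static_assert makes the trick unnecessary.

    #include <cstdio>

    // Illustrative stand-in for the technique: a typedef of an array whose
    // size is -1 when the condition is false, which no compiler accepts.
    #define MY_STATIC_ASSERT(expr, msg) typedef char msg[(expr) ? 1 : -1]

    MY_STATIC_ASSERT(sizeof(int) >= 2, int_is_at_least_16_bits);
    // MY_STATIC_ASSERT(sizeof(int) == 1, would_fail);  // compile error if enabled

    int main() {
      std::printf("compile-time assertions passed\n");
      return 0;
    }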
+ +template +struct CompileAssert { +}; + +#define GTEST_COMPILE_ASSERT_(expr, msg) \ + typedef ::testing::internal::CompileAssert<(bool(expr))> \ + msg[bool(expr) ? 1 : -1] + +// Implementation details of GTEST_COMPILE_ASSERT_: +// +// - GTEST_COMPILE_ASSERT_ works by defining an array type that has -1 +// elements (and thus is invalid) when the expression is false. +// +// - The simpler definition +// +// #define GTEST_COMPILE_ASSERT_(expr, msg) typedef char msg[(expr) ? 1 : -1] +// +// does not work, as gcc supports variable-length arrays whose sizes +// are determined at run-time (this is gcc's extension and not part +// of the C++ standard). As a result, gcc fails to reject the +// following code with the simple definition: +// +// int foo; +// GTEST_COMPILE_ASSERT_(foo, msg); // not supposed to compile as foo is +// // not a compile-time constant. +// +// - By using the type CompileAssert<(bool(expr))>, we ensures that +// expr is a compile-time constant. (Template arguments must be +// determined at compile-time.) +// +// - The outter parentheses in CompileAssert<(bool(expr))> are necessary +// to work around a bug in gcc 3.4.4 and 4.0.1. If we had written +// +// CompileAssert +// +// instead, these compilers will refuse to compile +// +// GTEST_COMPILE_ASSERT_(5 > 0, some_message); +// +// (They seem to think the ">" in "5 > 0" marks the end of the +// template argument list.) +// +// - The array size is (bool(expr) ? 1 : -1), instead of simply +// +// ((expr) ? 1 : -1). +// +// This is to avoid running into a bug in MS VC 7.1, which +// causes ((0.0) ? 1 : -1) to incorrectly evaluate to 1. + +// StaticAssertTypeEqHelper is used by StaticAssertTypeEq defined in gtest.h. +// +// This template is declared, but intentionally undefined. +template +struct StaticAssertTypeEqHelper; + +template +struct StaticAssertTypeEqHelper {}; + +#if GTEST_HAS_GLOBAL_STRING +typedef ::string string; +#else +typedef ::std::string string; +#endif // GTEST_HAS_GLOBAL_STRING + +#if GTEST_HAS_GLOBAL_WSTRING +typedef ::wstring wstring; +#elif GTEST_HAS_STD_WSTRING +typedef ::std::wstring wstring; +#endif // GTEST_HAS_GLOBAL_WSTRING + +// A helper for suppressing warnings on constant condition. It just +// returns 'condition'. +GTEST_API_ bool IsTrue(bool condition); + +// Defines scoped_ptr. + +// This implementation of scoped_ptr is PARTIAL - it only contains +// enough stuff to satisfy Google Test's need. +template +class scoped_ptr { + public: + typedef T element_type; + + explicit scoped_ptr(T* p = NULL) : ptr_(p) {} + ~scoped_ptr() { reset(); } + + T& operator*() const { return *ptr_; } + T* operator->() const { return ptr_; } + T* get() const { return ptr_; } + + T* release() { + T* const ptr = ptr_; + ptr_ = NULL; + return ptr; + } + + void reset(T* p = NULL) { + if (p != ptr_) { + if (IsTrue(sizeof(T) > 0)) { // Makes sure T is a complete type. + delete ptr_; + } + ptr_ = p; + } + } + private: + T* ptr_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(scoped_ptr); +}; + +// Defines RE. + +// A simple C++ wrapper for . It uses the POSIX Extended +// Regular Expression syntax. +class GTEST_API_ RE { + public: + // A copy constructor is required by the Standard to initialize object + // references from r-values. + RE(const RE& other) { Init(other.pattern()); } + + // Constructs an RE from a string. 
+ RE(const ::std::string& regex) { Init(regex.c_str()); } // NOLINT + +#if GTEST_HAS_GLOBAL_STRING + + RE(const ::string& regex) { Init(regex.c_str()); } // NOLINT + +#endif // GTEST_HAS_GLOBAL_STRING + + RE(const char* regex) { Init(regex); } // NOLINT + ~RE(); + + // Returns the string representation of the regex. + const char* pattern() const { return pattern_; } + + // FullMatch(str, re) returns true iff regular expression re matches + // the entire str. + // PartialMatch(str, re) returns true iff regular expression re + // matches a substring of str (including str itself). + // + // TODO(wan@google.com): make FullMatch() and PartialMatch() work + // when str contains NUL characters. + static bool FullMatch(const ::std::string& str, const RE& re) { + return FullMatch(str.c_str(), re); + } + static bool PartialMatch(const ::std::string& str, const RE& re) { + return PartialMatch(str.c_str(), re); + } + +#if GTEST_HAS_GLOBAL_STRING + + static bool FullMatch(const ::string& str, const RE& re) { + return FullMatch(str.c_str(), re); + } + static bool PartialMatch(const ::string& str, const RE& re) { + return PartialMatch(str.c_str(), re); + } + +#endif // GTEST_HAS_GLOBAL_STRING + + static bool FullMatch(const char* str, const RE& re); + static bool PartialMatch(const char* str, const RE& re); + + private: + void Init(const char* regex); + + // We use a const char* instead of a string, as Google Test may be used + // where string is not available. We also do not use Google Test's own + // String type here, in order to simplify dependencies between the + // files. + const char* pattern_; + bool is_valid_; + +#if GTEST_USES_POSIX_RE + + regex_t full_regex_; // For FullMatch(). + regex_t partial_regex_; // For PartialMatch(). + +#else // GTEST_USES_SIMPLE_RE + + const char* full_pattern_; // For FullMatch(); + +#endif + + GTEST_DISALLOW_ASSIGN_(RE); +}; + +// Formats a source file path and a line number as they would appear +// in an error message from the compiler used to compile this code. +GTEST_API_ ::std::string FormatFileLocation(const char* file, int line); + +// Formats a file location for compiler-independent XML output. +// Although this function is not platform dependent, we put it next to +// FormatFileLocation in order to contrast the two functions. +GTEST_API_ ::std::string FormatCompilerIndependentFileLocation(const char* file, + int line); + +// Defines logging utilities: +// GTEST_LOG_(severity) - logs messages at the specified severity level. The +// message itself is streamed into the macro. +// LogToStderr() - directs all log messages to stderr. +// FlushInfoLog() - flushes informational log messages. + +enum GTestLogSeverity { + GTEST_INFO, + GTEST_WARNING, + GTEST_ERROR, + GTEST_FATAL +}; + +// Formats log entry severity, provides a stream object for streaming the +// log message, and terminates the message with a newline when going out of +// scope. +class GTEST_API_ GTestLog { + public: + GTestLog(GTestLogSeverity severity, const char* file, int line); + + // Flushes the buffers and, if severity is GTEST_FATAL, aborts the program. + ~GTestLog(); + + ::std::ostream& GetStream() { return ::std::cerr; } + + private: + const GTestLogSeverity severity_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestLog); +}; + +#define GTEST_LOG_(severity) \ + ::testing::internal::GTestLog(::testing::internal::GTEST_##severity, \ + __FILE__, __LINE__).GetStream() + +inline void LogToStderr() {} +inline void FlushInfoLog() { fflush(NULL); } + +// INTERNAL IMPLEMENTATION - DO NOT USE. 
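When GTEST_USES_POSIX_RE is set, the RE class above wraps POSIX <regex.h> Extended Regular Expressions. A self-contained sketch of that underlying API (POSIX systems only; the helper names and the ^(...)$ anchoring used for the full match are illustrative, not necessarily how RE implements it):

    #include <regex.h>   // POSIX Extended Regular Expressions
    #include <cstdio>
    #include <string>

    // Returns true iff the pattern matches somewhere in str (partial match).
    static bool PartialMatch(const char* str, const char* pattern) {
      regex_t re;
      if (regcomp(&re, pattern, REG_EXTENDED | REG_NOSUB) != 0) return false;
      const bool matched = regexec(&re, str, 0, NULL, 0) == 0;
      regfree(&re);
      return matched;
    }

    // Full match: anchor the pattern so it must cover the entire string.
    static bool FullMatch(const char* str, const char* pattern) {
      const std::string anchored = std::string("^(") + pattern + ")$";
      return PartialMatch(str, anchored.c_str());
    }

    int main() {
      std::printf("partial: %d\n", PartialMatch("abc123", "[0-9]+"));  // 1
      std::printf("full:    %d\n", FullMatch("abc123", "[0-9]+"));     // 0
      return 0;
    }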
+// +// GTEST_CHECK_ is an all-mode assert. It aborts the program if the condition +// is not satisfied. +// Synopsys: +// GTEST_CHECK_(boolean_condition); +// or +// GTEST_CHECK_(boolean_condition) << "Additional message"; +// +// This checks the condition and if the condition is not satisfied +// it prints message about the condition violation, including the +// condition itself, plus additional message streamed into it, if any, +// and then it aborts the program. It aborts the program irrespective of +// whether it is built in the debug mode or not. +#define GTEST_CHECK_(condition) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::IsTrue(condition)) \ + ; \ + else \ + GTEST_LOG_(FATAL) << "Condition " #condition " failed. " + +// An all-mode assert to verify that the given POSIX-style function +// call returns 0 (indicating success). Known limitation: this +// doesn't expand to a balanced 'if' statement, so enclose the macro +// in {} if you need to use it as the only statement in an 'if' +// branch. +#define GTEST_CHECK_POSIX_SUCCESS_(posix_call) \ + if (const int gtest_error = (posix_call)) \ + GTEST_LOG_(FATAL) << #posix_call << "failed with error " \ + << gtest_error + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Use ImplicitCast_ as a safe version of static_cast for upcasting in +// the type hierarchy (e.g. casting a Foo* to a SuperclassOfFoo* or a +// const Foo*). When you use ImplicitCast_, the compiler checks that +// the cast is safe. Such explicit ImplicitCast_s are necessary in +// surprisingly many situations where C++ demands an exact type match +// instead of an argument type convertable to a target type. +// +// The syntax for using ImplicitCast_ is the same as for static_cast: +// +// ImplicitCast_(expr) +// +// ImplicitCast_ would have been part of the C++ standard library, +// but the proposal was submitted too late. It will probably make +// its way into the language in the future. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., implicit_cast). The internal +// namespace alone is not enough because the function can be found by ADL. +template +inline To ImplicitCast_(To x) { return x; } + +// When you upcast (that is, cast a pointer from type Foo to type +// SuperclassOfFoo), it's fine to use ImplicitCast_<>, since upcasts +// always succeed. When you downcast (that is, cast a pointer from +// type Foo to type SubclassOfFoo), static_cast<> isn't safe, because +// how do you know the pointer is really of type SubclassOfFoo? It +// could be a bare Foo, or of type DifferentSubclassOfFoo. Thus, +// when you downcast, you should use this macro. In debug mode, we +// use dynamic_cast<> to double-check the downcast is legal (we die +// if it's not). In normal mode, we do the efficient static_cast<> +// instead. Thus, it's important to test in debug mode to make sure +// the cast is legal! +// This is the only place in the code we should use dynamic_cast<>. +// In particular, you SHOULDN'T be using dynamic_cast<> in order to +// do RTTI (eg code like this: +// if (dynamic_cast(foo)) HandleASubclass1Object(foo); +// if (dynamic_cast(foo)) HandleASubclass2Object(foo); +// You should design the code some other way not to need this. +// +// This relatively ugly name is intentional. It prevents clashes with +// similar functions users may have (e.g., down_cast). The internal +// namespace alone is not enough because the function can be found by ADL. 
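A compact way to see the split described above: upcasts are accepted only when the conversion is already implicit, while downcasts are verified with dynamic_cast before the cheap static_cast. The names ImplicitCast and CheckedDownCast below are my own; this is a sketch of the idea, not the gtest code itself:

    #include <cassert>
    #include <cstdio>

    struct Animal { virtual ~Animal() {} };
    struct Dog : Animal { void Bark() const { std::printf("woof\n"); } };

    // The compiler accepts the call only if From* converts to To*
    // implicitly, i.e. a genuine upcast.
    template <typename To>
    inline To ImplicitCast(To x) { return x; }

    // Checked downcast: verify with dynamic_cast (debug-style check), then
    // perform the inexpensive static_cast.
    template <typename To, typename From>
    inline To* CheckedDownCast(From* f) {
      assert(f == NULL || dynamic_cast<To*>(f) != NULL);
      return static_cast<To*>(f);
    }

    int main() {
      Dog d;
      Animal* a = ImplicitCast<Animal*>(&d);   // upcast: always safe
      CheckedDownCast<Dog>(a)->Bark();         // downcast: verified, then fast
      return 0;
    }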
+template // use like this: DownCast_(foo); +inline To DownCast_(From* f) { // so we only accept pointers + // Ensures that To is a sub-type of From *. This test is here only + // for compile-time type checking, and has no overhead in an + // optimized build at run-time, as it will be optimized away + // completely. + if (false) { + const To to = NULL; + ::testing::internal::ImplicitCast_(to); + } + +#if GTEST_HAS_RTTI + // RTTI: debug mode only! + GTEST_CHECK_(f == NULL || dynamic_cast(f) != NULL); +#endif + return static_cast(f); +} + +// Downcasts the pointer of type Base to Derived. +// Derived must be a subclass of Base. The parameter MUST +// point to a class of type Derived, not any subclass of it. +// When RTTI is available, the function performs a runtime +// check to enforce this. +template +Derived* CheckedDowncastToActualType(Base* base) { +#if GTEST_HAS_RTTI + GTEST_CHECK_(typeid(*base) == typeid(Derived)); + return dynamic_cast(base); // NOLINT +#else + return static_cast(base); // Poor man's downcast. +#endif +} + +#if GTEST_HAS_STREAM_REDIRECTION + +// Defines the stderr capturer: +// CaptureStdout - starts capturing stdout. +// GetCapturedStdout - stops capturing stdout and returns the captured string. +// CaptureStderr - starts capturing stderr. +// GetCapturedStderr - stops capturing stderr and returns the captured string. +// +GTEST_API_ void CaptureStdout(); +GTEST_API_ String GetCapturedStdout(); +GTEST_API_ void CaptureStderr(); +GTEST_API_ String GetCapturedStderr(); + +#endif // GTEST_HAS_STREAM_REDIRECTION + + +#if GTEST_HAS_DEATH_TEST + +// A copy of all command line arguments. Set by InitGoogleTest(). +extern ::std::vector g_argvs; + +// GTEST_HAS_DEATH_TEST implies we have ::std::string. +const ::std::vector& GetArgvs(); + +#endif // GTEST_HAS_DEATH_TEST + +// Defines synchronization primitives. + +#if GTEST_HAS_PTHREAD + +// Sleeps for (roughly) n milli-seconds. This function is only for +// testing Google Test's own constructs. Don't use it in user tests, +// either directly or indirectly. +inline void SleepMilliseconds(int n) { + const timespec time = { + 0, // 0 seconds. + n * 1000L * 1000L, // And n ms. + }; + nanosleep(&time, NULL); +} + +// Allows a controller thread to pause execution of newly created +// threads until notified. Instances of this class must be created +// and destroyed in the controller thread. +// +// This class is only for testing Google Test's own constructs. Do not +// use it in user tests, either directly or indirectly. +class Notification { + public: + Notification() : notified_(false) {} + + // Notifies all threads created with this notification to start. Must + // be called from the controller thread. + void Notify() { notified_ = true; } + + // Blocks until the controller thread notifies. Must be called from a test + // thread. + void WaitForNotification() { + while(!notified_) { + SleepMilliseconds(10); + } + } + + private: + volatile bool notified_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(Notification); +}; + +// As a C-function, ThreadFuncWithCLinkage cannot be templated itself. +// Consequently, it cannot select a correct instantiation of ThreadWithParam +// in order to call its Run(). Introducing ThreadWithParamBase as a +// non-templated base class for ThreadWithParam allows us to bypass this +// problem. +class ThreadWithParamBase { + public: + virtual ~ThreadWithParamBase() {} + virtual void Run() = 0; +}; + +// pthread_create() accepts a pointer to a function type with the C linkage. 
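The constraint just mentioned — pthread_create() wants a C-linkage void* (*)(void*) — is the whole reason for the trampoline defined next. A minimal self-contained illustration of that pattern (pthreads only, compile with -pthread; Runnable, Trampoline, and HelloTask are illustrative names):

    #include <pthread.h>
    #include <cstdio>

    // A small extern "C" trampoline forwards from pthread's C-style entry
    // point to a C++ object's virtual Run().
    class Runnable {
     public:
      virtual ~Runnable() {}
      virtual void Run() = 0;
    };

    extern "C" void* Trampoline(void* arg) {
      static_cast<Runnable*>(arg)->Run();
      return NULL;
    }

    class HelloTask : public Runnable {
     public:
      virtual void Run() { std::printf("hello from a worker thread\n"); }
    };

    int main() {
      HelloTask task;
      pthread_t tid;
      if (pthread_create(&tid, NULL, &Trampoline, &task) != 0) return 1;
      pthread_join(tid, NULL);
      return 0;
    }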
+// According to the Standard (7.5/1), function types with different linkages +// are different even if they are otherwise identical. Some compilers (for +// example, SunStudio) treat them as different types. Since class methods +// cannot be defined with C-linkage we need to define a free C-function to +// pass into pthread_create(). +extern "C" inline void* ThreadFuncWithCLinkage(void* thread) { + static_cast(thread)->Run(); + return NULL; +} + +// Helper class for testing Google Test's multi-threading constructs. +// To use it, write: +// +// void ThreadFunc(int param) { /* Do things with param */ } +// Notification thread_can_start; +// ... +// // The thread_can_start parameter is optional; you can supply NULL. +// ThreadWithParam thread(&ThreadFunc, 5, &thread_can_start); +// thread_can_start.Notify(); +// +// These classes are only for testing Google Test's own constructs. Do +// not use them in user tests, either directly or indirectly. +template +class ThreadWithParam : public ThreadWithParamBase { + public: + typedef void (*UserThreadFunc)(T); + + ThreadWithParam( + UserThreadFunc func, T param, Notification* thread_can_start) + : func_(func), + param_(param), + thread_can_start_(thread_can_start), + finished_(false) { + ThreadWithParamBase* const base = this; + // The thread can be created only after all fields except thread_ + // have been initialized. + GTEST_CHECK_POSIX_SUCCESS_( + pthread_create(&thread_, 0, &ThreadFuncWithCLinkage, base)); + } + ~ThreadWithParam() { Join(); } + + void Join() { + if (!finished_) { + GTEST_CHECK_POSIX_SUCCESS_(pthread_join(thread_, 0)); + finished_ = true; + } + } + + virtual void Run() { + if (thread_can_start_ != NULL) + thread_can_start_->WaitForNotification(); + func_(param_); + } + + private: + const UserThreadFunc func_; // User-supplied thread function. + const T param_; // User-supplied parameter to the thread function. + // When non-NULL, used to block execution until the controller thread + // notifies. + Notification* const thread_can_start_; + bool finished_; // true iff we know that the thread function has finished. + pthread_t thread_; // The native thread object. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadWithParam); +}; + +// MutexBase and Mutex implement mutex on pthreads-based platforms. They +// are used in conjunction with class MutexLock: +// +// Mutex mutex; +// ... +// MutexLock lock(&mutex); // Acquires the mutex and releases it at the end +// // of the current scope. +// +// MutexBase implements behavior for both statically and dynamically +// allocated mutexes. Do not use MutexBase directly. Instead, write +// the following to define a static mutex: +// +// GTEST_DEFINE_STATIC_MUTEX_(g_some_mutex); +// +// You can forward declare a static mutex like this: +// +// GTEST_DECLARE_STATIC_MUTEX_(g_some_mutex); +// +// To create a dynamic mutex, just define an object of type Mutex. +class MutexBase { + public: + // Acquires this mutex. + void Lock() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_lock(&mutex_)); + owner_ = pthread_self(); + } + + // Releases this mutex. + void Unlock() { + // We don't protect writing to owner_ here, as it's the caller's + // responsibility to ensure that the current thread holds the + // mutex when this is called. + owner_ = 0; + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_unlock(&mutex_)); + } + + // Does nothing if the current thread holds the mutex. Otherwise, crashes + // with high probability. 
+ void AssertHeld() const { + GTEST_CHECK_(owner_ == pthread_self()) + << "The current thread is not holding the mutex @" << this; + } + + // A static mutex may be used before main() is entered. It may even + // be used before the dynamic initialization stage. Therefore we + // must be able to initialize a static mutex object at link time. + // This means MutexBase has to be a POD and its member variables + // have to be public. + public: + pthread_mutex_t mutex_; // The underlying pthread mutex. + pthread_t owner_; // The thread holding the mutex; 0 means no one holds it. +}; + +// Forward-declares a static mutex. +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::MutexBase mutex + +// Defines and statically (i.e. at link time) initializes a static mutex. +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) \ + ::testing::internal::MutexBase mutex = { PTHREAD_MUTEX_INITIALIZER, 0 } + +// The Mutex class can only be used for mutexes created at runtime. It +// shares its API with MutexBase otherwise. +class Mutex : public MutexBase { + public: + Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_init(&mutex_, NULL)); + owner_ = 0; + } + ~Mutex() { + GTEST_CHECK_POSIX_SUCCESS_(pthread_mutex_destroy(&mutex_)); + } + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(Mutex); +}; + +// We cannot name this class MutexLock as the ctor declaration would +// conflict with a macro named MutexLock, which is defined on some +// platforms. Hence the typedef trick below. +class GTestMutexLock { + public: + explicit GTestMutexLock(MutexBase* mutex) + : mutex_(mutex) { mutex_->Lock(); } + + ~GTestMutexLock() { mutex_->Unlock(); } + + private: + MutexBase* const mutex_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(GTestMutexLock); +}; + +typedef GTestMutexLock MutexLock; + +// Helpers for ThreadLocal. + +// pthread_key_create() requires DeleteThreadLocalValue() to have +// C-linkage. Therefore it cannot be templatized to access +// ThreadLocal. Hence the need for class +// ThreadLocalValueHolderBase. +class ThreadLocalValueHolderBase { + public: + virtual ~ThreadLocalValueHolderBase() {} +}; + +// Called by pthread to delete thread-local data stored by +// pthread_setspecific(). +extern "C" inline void DeleteThreadLocalValue(void* value_holder) { + delete static_cast(value_holder); +} + +// Implements thread-local storage on pthreads-based systems. +// +// // Thread 1 +// ThreadLocal tl(100); // 100 is the default value for each thread. +// +// // Thread 2 +// tl.set(150); // Changes the value for thread 2 only. +// EXPECT_EQ(150, tl.get()); +// +// // Thread 1 +// EXPECT_EQ(100, tl.get()); // In thread 1, tl has the original value. +// tl.set(200); +// EXPECT_EQ(200, tl.get()); +// +// The template type argument T must have a public copy constructor. +// In addition, the default ThreadLocal constructor requires T to have +// a public default constructor. +// +// An object managed for a thread by a ThreadLocal instance is deleted +// when the thread exits. Or, if the ThreadLocal instance dies in +// that thread, when the ThreadLocal dies. It's the user's +// responsibility to ensure that all other threads using a ThreadLocal +// have exited when it dies, or the per-thread objects for those +// threads will not be deleted. +// +// Google Test only uses global ThreadLocal objects. That means they +// will die after main() has returned. Therefore, no per-thread +// object managed by Google Test will be leaked as long as all threads +// using Google Test have exited when main() returns. 
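The ThreadLocal template defined next is a wrapper over the raw pthread TLS calls. This standalone sketch, assuming a POSIX system (compile with -pthread), shows those calls directly: pthread_key_create() registers a per-thread slot plus a destructor callback, each thread lazily heap-allocates its own value, and writes in one thread are invisible to another.

    #include <pthread.h>
    #include <cstdio>

    // Destructor run by pthreads when a thread that stored a value exits.
    extern "C" void DeleteIntValue(void* value) {
      delete static_cast<int*>(value);
    }

    static pthread_key_t g_key;

    static int* GetOrCreateValue(int default_value) {
      int* value = static_cast<int*>(pthread_getspecific(g_key));
      if (value == NULL) {
        value = new int(default_value);
        pthread_setspecific(g_key, value);
      }
      return value;
    }

    extern "C" void* Worker(void*) {
      *GetOrCreateValue(100) = 150;                 // this thread's copy only
      std::printf("worker sees %d\n", *GetOrCreateValue(100));
      return NULL;
    }

    int main() {
      pthread_key_create(&g_key, &DeleteIntValue);
      *GetOrCreateValue(100) = 200;                 // main thread's copy
      pthread_t tid;
      pthread_create(&tid, NULL, &Worker, NULL);
      pthread_join(tid, NULL);
      std::printf("main still sees %d\n", *GetOrCreateValue(100));  // prints 200
      DeleteIntValue(pthread_getspecific(g_key));   // clean up main's value
      pthread_key_delete(g_key);
      return 0;
    }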
+template +class ThreadLocal { + public: + ThreadLocal() : key_(CreateKey()), + default_() {} + explicit ThreadLocal(const T& value) : key_(CreateKey()), + default_(value) {} + + ~ThreadLocal() { + // Destroys the managed object for the current thread, if any. + DeleteThreadLocalValue(pthread_getspecific(key_)); + + // Releases resources associated with the key. This will *not* + // delete managed objects for other threads. + GTEST_CHECK_POSIX_SUCCESS_(pthread_key_delete(key_)); + } + + T* pointer() { return GetOrCreateValue(); } + const T* pointer() const { return GetOrCreateValue(); } + const T& get() const { return *pointer(); } + void set(const T& value) { *pointer() = value; } + + private: + // Holds a value of type T. + class ValueHolder : public ThreadLocalValueHolderBase { + public: + explicit ValueHolder(const T& value) : value_(value) {} + + T* pointer() { return &value_; } + + private: + T value_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ValueHolder); + }; + + static pthread_key_t CreateKey() { + pthread_key_t key; + // When a thread exits, DeleteThreadLocalValue() will be called on + // the object managed for that thread. + GTEST_CHECK_POSIX_SUCCESS_( + pthread_key_create(&key, &DeleteThreadLocalValue)); + return key; + } + + T* GetOrCreateValue() const { + ThreadLocalValueHolderBase* const holder = + static_cast(pthread_getspecific(key_)); + if (holder != NULL) { + return CheckedDowncastToActualType(holder)->pointer(); + } + + ValueHolder* const new_holder = new ValueHolder(default_); + ThreadLocalValueHolderBase* const holder_base = new_holder; + GTEST_CHECK_POSIX_SUCCESS_(pthread_setspecific(key_, holder_base)); + return new_holder->pointer(); + } + + // A key pthreads uses for looking up per-thread values. + const pthread_key_t key_; + const T default_; // The default value for each thread. + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ThreadLocal); +}; + +# define GTEST_IS_THREADSAFE 1 + +#else // GTEST_HAS_PTHREAD + +// A dummy implementation of synchronization primitives (mutex, lock, +// and thread-local variable). Necessary for compiling Google Test where +// mutex is not supported - using Google Test in multiple threads is not +// supported on such platforms. + +class Mutex { + public: + Mutex() {} + void AssertHeld() const {} +}; + +# define GTEST_DECLARE_STATIC_MUTEX_(mutex) \ + extern ::testing::internal::Mutex mutex + +# define GTEST_DEFINE_STATIC_MUTEX_(mutex) ::testing::internal::Mutex mutex + +class GTestMutexLock { + public: + explicit GTestMutexLock(Mutex*) {} // NOLINT +}; + +typedef GTestMutexLock MutexLock; + +template +class ThreadLocal { + public: + ThreadLocal() : value_() {} + explicit ThreadLocal(const T& value) : value_(value) {} + T* pointer() { return &value_; } + const T* pointer() const { return &value_; } + const T& get() const { return value_; } + void set(const T& value) { value_ = value; } + private: + T value_; +}; + +// The above synchronization primitives have dummy implementations. +// Therefore Google Test is not thread-safe. +# define GTEST_IS_THREADSAFE 0 + +#endif // GTEST_HAS_PTHREAD + +// Returns the number of threads running in the process, or 0 to indicate that +// we cannot detect it. +GTEST_API_ size_t GetThreadCount(); + +// Passing non-POD classes through ellipsis (...) crashes the ARM +// compiler and generates a warning in Sun Studio. The Nokia Symbian +// and the IBM XL C/C++ compiler try to instantiate a copy constructor +// for objects passed through ellipsis (...), failing for uncopyable +// objects. 
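The dummy primitives above illustrate a pattern worth noting: keep the class names and call sites identical and make the bodies no-ops, so code written for the threaded build compiles unchanged and simply runs single-threaded. A sketch of that pattern with illustrative names (not the gtest classes themselves):

    #include <cstdio>

    // Single-threaded stand-ins; a threaded build would wrap a real mutex
    // (e.g. pthread_mutex_t) behind the same interface.
    class Mutex {
     public:
      void Lock() {}
      void Unlock() {}
    };

    class MutexLock {   // RAII lock: acquire in ctor, release in dtor
     public:
      explicit MutexLock(Mutex* m) : mutex_(m) { mutex_->Lock(); }
      ~MutexLock() { mutex_->Unlock(); }
     private:
      Mutex* const mutex_;
    };

    static Mutex g_mutex;
    static int g_counter = 0;

    int main() {
      MutexLock lock(&g_mutex);   // identical call site in a threaded build
      ++g_counter;
      std::printf("counter = %d\n", g_counter);
      return 0;
    }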
+
+// Passing non-POD classes through ellipsis (...) crashes the ARM
+// compiler and generates a warning in Sun Studio.  The Nokia Symbian
+// and the IBM XL C/C++ compiler try to instantiate a copy constructor
+// for objects passed through ellipsis (...), failing for uncopyable
+// objects.  We define this to ensure that only POD is passed through
+// ellipsis on these systems.
+#if defined(__SYMBIAN32__) || defined(__IBMCPP__) || defined(__SUNPRO_CC)
+// We lose support for NULL detection where the compiler doesn't like
+// passing non-POD classes through ellipsis (...).
+# define GTEST_ELLIPSIS_NEEDS_POD_ 1
+#else
+# define GTEST_CAN_COMPARE_NULL 1
+#endif
+
+// The Nokia Symbian and IBM XL C/C++ compilers cannot decide between
+// const T& and const T* in a function template.  These compilers
+// _can_ decide between class template specializations for T and T*,
+// so a tr1::type_traits-like is_pointer works.
+#if defined(__SYMBIAN32__) || defined(__IBMCPP__)
+# define GTEST_NEEDS_IS_POINTER_ 1
+#endif
+
+template <bool bool_value>
+struct bool_constant {
+  typedef bool_constant<bool_value> type;
+  static const bool value = bool_value;
+};
+template <bool bool_value> const bool bool_constant<bool_value>::value;
+
+typedef bool_constant<false> false_type;
+typedef bool_constant<true> true_type;
+
+template <typename T>
+struct is_pointer : public false_type {};
+
+template <typename T>
+struct is_pointer<T*> : public true_type {};
+
+template <typename Iterator>
+struct IteratorTraits {
+  typedef typename Iterator::value_type value_type;
+};
+
+template <typename T>
+struct IteratorTraits<T*> {
+  typedef T value_type;
+};
+
+template <typename T>
+struct IteratorTraits<const T*> {
+  typedef T value_type;
+};
+
+#if GTEST_OS_WINDOWS
+# define GTEST_PATH_SEP_ "\\"
+# define GTEST_HAS_ALT_PATH_SEP_ 1
+// The biggest signed integer type the compiler supports.
+typedef __int64 BiggestInt;
+#else
+# define GTEST_PATH_SEP_ "/"
+# define GTEST_HAS_ALT_PATH_SEP_ 0
+typedef long long BiggestInt;  // NOLINT
+#endif  // GTEST_OS_WINDOWS
+
+// Utilities for char.
+
+// isspace(int ch) and friends accept an unsigned char or EOF.  char
+// may be signed, depending on the compiler (or compiler flags).
+// Therefore we need to cast a char to unsigned char before calling
+// isspace(), etc.
+
+inline bool IsAlpha(char ch) {
+  return isalpha(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsAlNum(char ch) {
+  return isalnum(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsDigit(char ch) {
+  return isdigit(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsLower(char ch) {
+  return islower(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsSpace(char ch) {
+  return isspace(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsUpper(char ch) {
+  return isupper(static_cast<unsigned char>(ch)) != 0;
+}
+inline bool IsXDigit(char ch) {
+  return isxdigit(static_cast<unsigned char>(ch)) != 0;
+}
+
+inline char ToLower(char ch) {
+  return static_cast<char>(tolower(static_cast<unsigned char>(ch)));
+}
+inline char ToUpper(char ch) {
+  return static_cast<char>(toupper(static_cast<unsigned char>(ch)));
+}
+
+// The testing::internal::posix namespace holds wrappers for common
+// POSIX functions.  These wrappers hide the differences between
+// Windows/MSVC and POSIX systems.  Since some compilers define these
+// standard functions as macros, the wrapper cannot have the same name
+// as the wrapped function.
+
+namespace posix {
+
+// Functions with a different name on Windows.
+
+#if GTEST_OS_WINDOWS
+
+typedef struct _stat StatStruct;
+
+# ifdef __BORLANDC__
+inline int IsATTY(int fd) { return isatty(fd); }
+inline int StrCaseCmp(const char* s1, const char* s2) {
+  return stricmp(s1, s2);
+}
+inline char* StrDup(const char* src) { return strdup(src); }
+# else  // !__BORLANDC__
+#  if GTEST_OS_WINDOWS_MOBILE
+inline int IsATTY(int /* fd */) { return 0; }
+#  else
+inline int IsATTY(int fd) { return _isatty(fd); }
+#  endif  // GTEST_OS_WINDOWS_MOBILE
+inline int StrCaseCmp(const char* s1, const char* s2) {
+  return _stricmp(s1, s2);
+}
+inline char* StrDup(const char* src) { return _strdup(src); }
+# endif  // __BORLANDC__
+
+# if GTEST_OS_WINDOWS_MOBILE
+inline int FileNo(FILE* file) { return reinterpret_cast<int>(_fileno(file)); }
+// Stat(), RmDir(), and IsDir() are not needed on Windows CE at this
+// time and thus not defined there.
+# else
+inline int FileNo(FILE* file) { return _fileno(file); }
+inline int Stat(const char* path, StatStruct* buf) { return _stat(path, buf); }
+inline int RmDir(const char* dir) { return _rmdir(dir); }
+inline bool IsDir(const StatStruct& st) {
+  return (_S_IFDIR & st.st_mode) != 0;
+}
+# endif  // GTEST_OS_WINDOWS_MOBILE
+
+#else
+
+typedef struct stat StatStruct;
+
+inline int FileNo(FILE* file) { return fileno(file); }
+inline int IsATTY(int fd) { return isatty(fd); }
+inline int Stat(const char* path, StatStruct* buf) { return stat(path, buf); }
+inline int StrCaseCmp(const char* s1, const char* s2) {
+  return strcasecmp(s1, s2);
+}
+inline char* StrDup(const char* src) { return strdup(src); }
+inline int RmDir(const char* dir) { return rmdir(dir); }
+inline bool IsDir(const StatStruct& st) { return S_ISDIR(st.st_mode); }
+
+#endif  // GTEST_OS_WINDOWS
+
+// Functions deprecated by MSVC 8.0.
+
+#ifdef _MSC_VER
+// Temporarily disable warning 4996 (deprecated function).
+# pragma warning(push)
+# pragma warning(disable:4996)
+#endif
+
+inline const char* StrNCpy(char* dest, const char* src, size_t n) {
+  return strncpy(dest, src, n);
+}
+
+// ChDir(), FReopen(), FDOpen(), Read(), Write(), Close(), and
+// StrError() aren't needed on Windows CE at this time and thus not
+// defined there.
+
+#if !GTEST_OS_WINDOWS_MOBILE
+inline int ChDir(const char* dir) { return chdir(dir); }
+#endif
+inline FILE* FOpen(const char* path, const char* mode) {
+  return fopen(path, mode);
+}
+#if !GTEST_OS_WINDOWS_MOBILE
+inline FILE *FReopen(const char* path, const char* mode, FILE* stream) {
+  return freopen(path, mode, stream);
+}
+inline FILE* FDOpen(int fd, const char* mode) { return fdopen(fd, mode); }
+#endif
+inline int FClose(FILE* fp) { return fclose(fp); }
+#if !GTEST_OS_WINDOWS_MOBILE
+inline int Read(int fd, void* buf, unsigned int count) {
+  return static_cast<int>(read(fd, buf, count));
+}
+inline int Write(int fd, const void* buf, unsigned int count) {
+  return static_cast<int>(write(fd, buf, count));
+}
+inline int Close(int fd) { return close(fd); }
+inline const char* StrError(int errnum) { return strerror(errnum); }
+#endif
+inline const char* GetEnv(const char* name) {
+#if GTEST_OS_WINDOWS_MOBILE
+  // We are on Windows CE, which has no environment variables.
+  return NULL;
+#elif defined(__BORLANDC__) || defined(__SunOS_5_8) || defined(__SunOS_5_9)
+  // Environment variables which we programmatically clear will be set to the
+  // empty string rather than unset (NULL).  Handle that case.
+  const char* const env = getenv(name);
+  return (env != NULL && env[0] != '\0') ? env : NULL;
+#else
+  return getenv(name);
+#endif
+}
+
+#ifdef _MSC_VER
+# pragma warning(pop)  // Restores the warning state.
+#endif
+
+#if GTEST_OS_WINDOWS_MOBILE
+// Windows CE has no C library. The abort() function is used in
+// several places in Google Test. This implementation provides a reasonable
+// imitation of standard behaviour.
+void Abort();
+#else
+inline void Abort() { abort(); }
+#endif  // GTEST_OS_WINDOWS_MOBILE
+
+}  // namespace posix
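The posix namespace above lets Google Test call a single portable spelling (posix::StrCaseCmp, posix::GetEnv, and so on) while the Windows/MSVC vs. POSIX differences, and any macro definitions of the wrapped names, stay hidden behind it. Below is a cut-down standalone sketch of the same wrapper-namespace idea, assuming only a Windows or POSIX toolchain; the namespace name portable and everything in it are invented for this illustration and are not part of the header.

#include <stdio.h>
#include <string.h>
#ifdef _WIN32
# include <io.h>
#else
# include <strings.h>
# include <unistd.h>
#endif

namespace portable {
// Case-insensitive comparison: _stricmp on MSVC, strcasecmp on POSIX.
inline int StrCaseCmp(const char* a, const char* b) {
#ifdef _WIN32
  return _stricmp(a, b);
#else
  return strcasecmp(a, b);
#endif
}
// Terminal check: _isatty on MSVC, isatty on POSIX.
inline int IsATTY(int fd) {
#ifdef _WIN32
  return _isatty(fd);
#else
  return isatty(fd);
#endif
}
}  // namespace portable

int main() {
  printf("equal ignoring case: %d\n",
         portable::StrCaseCmp("GTest", "gtest") == 0);
  printf("stdout is a tty: %d\n", portable::IsATTY(1) != 0);
  return 0;
}

Giving the wrapper a name different from the wrapped function matters because some compilers define functions such as isatty or stricmp as macros, which would otherwise mangle the wrapper's own declaration.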
+
+// The maximum number a BiggestInt can represent.  This definition
+// works no matter BiggestInt is represented in one's complement or
+// two's complement.
+//
+// We cannot rely on numeric_limits in STL, as __int64 and long long
+// are not part of standard C++ and numeric_limits doesn't need to be
+// defined for them.
+const BiggestInt kMaxBiggestInt =
+    ~(static_cast<BiggestInt>(1) << (8*sizeof(BiggestInt) - 1));
+
+// This template class serves as a compile-time function from size to
+// type.  It maps a size in bytes to a primitive type with that
+// size. e.g.
+//
+//   TypeWithSize<4>::UInt
+//
+// is typedef-ed to be unsigned int (unsigned integer made up of 4
+// bytes).
+//
+// Such functionality should belong to STL, but I cannot find it
+// there.
+//
+// Google Test uses this class in the implementation of floating-point
+// comparison.
+//
+// For now it only handles UInt (unsigned int) as that's all Google Test
+// needs.  Other types can be easily added in the future if need
+// arises.
+template <size_t size>
+class TypeWithSize {
+ public:
+  // This prevents the user from using TypeWithSize<N> with incorrect
+  // values of N.
+  typedef void UInt;
+};
+
+// The specialization for size 4.
+template <>
+class TypeWithSize<4> {
+ public:
+  // unsigned int has size 4 in both gcc and MSVC.
+  //
+  // As base/basictypes.h doesn't compile on Windows, we cannot use
+  // uint32, uint64, and etc here.
+  typedef int Int;
+  typedef unsigned int UInt;
+};
+
+// The specialization for size 8.
+template <>
+class TypeWithSize<8> {
+ public:
+
+#if GTEST_OS_WINDOWS
+  typedef __int64 Int;
+  typedef unsigned __int64 UInt;
+#else
+  typedef long long Int;  // NOLINT
+  typedef unsigned long long UInt;  // NOLINT
+#endif  // GTEST_OS_WINDOWS
+};
+
+// Integer types of known sizes.
+typedef TypeWithSize<4>::Int Int32;
+typedef TypeWithSize<4>::UInt UInt32;
+typedef TypeWithSize<8>::Int Int64;
+typedef TypeWithSize<8>::UInt UInt64;
+typedef TypeWithSize<8>::Int TimeInMillis;  // Represents time in milliseconds.
+
+// Utilities for command line flags and environment variables.
+
+// Macro for referencing flags.
+#define GTEST_FLAG(name) FLAGS_gtest_##name
+
+// Macros for declaring flags.
+#define GTEST_DECLARE_bool_(name) GTEST_API_ extern bool GTEST_FLAG(name)
+#define GTEST_DECLARE_int32_(name) \
+    GTEST_API_ extern ::testing::internal::Int32 GTEST_FLAG(name)
+#define GTEST_DECLARE_string_(name) \
+    GTEST_API_ extern ::testing::internal::String GTEST_FLAG(name)
+
+// Macros for defining flags.
+#define GTEST_DEFINE_bool_(name, default_val, doc) \
+    GTEST_API_ bool GTEST_FLAG(name) = (default_val)
+#define GTEST_DEFINE_int32_(name, default_val, doc) \
+    GTEST_API_ ::testing::internal::Int32 GTEST_FLAG(name) = (default_val)
+#define GTEST_DEFINE_string_(name, default_val, doc) \
+    GTEST_API_ ::testing::internal::String GTEST_FLAG(name) = (default_val)
+
+// Parses 'str' for a 32-bit signed integer.  If successful, writes the result
+// to *value and returns true; otherwise leaves *value unchanged and returns
+// false.
+// TODO(chandlerc): Find a better way to refactor flag and environment parsing +// out of both gtest-port.cc and gtest.cc to avoid exporting this utility +// function. +bool ParseInt32(const Message& src_text, const char* str, Int32* value); + +// Parses a bool/Int32/string from the environment variable +// corresponding to the given Google Test flag. +bool BoolFromGTestEnv(const char* flag, bool default_val); +GTEST_API_ Int32 Int32FromGTestEnv(const char* flag, Int32 default_val); +const char* StringFromGTestEnv(const char* flag, const char* default_val); + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PORT_H_ + +#if GTEST_OS_LINUX +# include +# include +# include +# include +#endif // GTEST_OS_LINUX + +#include +#include +#include +#include +#include + +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file declares the String class and functions used internally by +// Google Test. They are subject to change without notice. They should not used +// by code external to Google Test. +// +// This header file is #included by . +// It should not be #included by other files. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ + +#ifdef __BORLANDC__ +// string.h is not guaranteed to provide strcpy on C++ Builder. +# include +#endif + +#include + +#include + +namespace testing { +namespace internal { + +// String - a UTF-8 string class. +// +// For historic reasons, we don't use std::string. +// +// TODO(wan@google.com): replace this class with std::string or +// implement it in terms of the latter. +// +// Note that String can represent both NULL and the empty string, +// while std::string cannot represent NULL. +// +// NULL and the empty string are considered different. 
NULL is less +// than anything (including the empty string) except itself. +// +// This class only provides minimum functionality necessary for +// implementing Google Test. We do not intend to implement a full-fledged +// string class here. +// +// Since the purpose of this class is to provide a substitute for +// std::string on platforms where it cannot be used, we define a copy +// constructor and assignment operators such that we don't need +// conditional compilation in a lot of places. +// +// In order to make the representation efficient, the d'tor of String +// is not virtual. Therefore DO NOT INHERIT FROM String. +class GTEST_API_ String { + public: + // Static utility methods + + // Returns the input enclosed in double quotes if it's not NULL; + // otherwise returns "(null)". For example, "\"Hello\"" is returned + // for input "Hello". + // + // This is useful for printing a C string in the syntax of a literal. + // + // Known issue: escape sequences are not handled yet. + static String ShowCStringQuoted(const char* c_str); + + // Clones a 0-terminated C string, allocating memory using new. The + // caller is responsible for deleting the return value using + // delete[]. Returns the cloned string, or NULL if the input is + // NULL. + // + // This is different from strdup() in string.h, which allocates + // memory using malloc(). + static const char* CloneCString(const char* c_str); + +#if GTEST_OS_WINDOWS_MOBILE + // Windows CE does not have the 'ANSI' versions of Win32 APIs. To be + // able to pass strings to Win32 APIs on CE we need to convert them + // to 'Unicode', UTF-16. + + // Creates a UTF-16 wide string from the given ANSI string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the wide string, or NULL if the + // input is NULL. + // + // The wide string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static LPCWSTR AnsiToUtf16(const char* c_str); + + // Creates an ANSI string from the given wide string, allocating + // memory using new. The caller is responsible for deleting the return + // value using delete[]. Returns the ANSI string, or NULL if the + // input is NULL. + // + // The returned string is created using the ANSI codepage (CP_ACP) to + // match the behaviour of the ANSI versions of Win32 calls and the + // C runtime. + static const char* Utf16ToAnsi(LPCWSTR utf16_str); +#endif + + // Compares two C strings. Returns true iff they have the same content. + // + // Unlike strcmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CStringEquals(const char* lhs, const char* rhs); + + // Converts a wide C string to a String using the UTF-8 encoding. + // NULL will be converted to "(null)". If an error occurred during + // the conversion, "(failed to convert from wide string)" is + // returned. + static String ShowWideCString(const wchar_t* wide_c_str); + + // Similar to ShowWideCString(), except that this function encloses + // the converted string in double quotes. + static String ShowWideCStringQuoted(const wchar_t* wide_c_str); + + // Compares two wide C strings. Returns true iff they have the same + // content. + // + // Unlike wcscmp(), this function can handle NULL argument(s). A + // NULL C string is considered different to any non-NULL C string, + // including the empty string. 
+ static bool WideCStringEquals(const wchar_t* lhs, const wchar_t* rhs); + + // Compares two C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike strcasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL C string, + // including the empty string. + static bool CaseInsensitiveCStringEquals(const char* lhs, + const char* rhs); + + // Compares two wide C strings, ignoring case. Returns true iff they + // have the same content. + // + // Unlike wcscasecmp(), this function can handle NULL argument(s). + // A NULL C string is considered different to any non-NULL wide C string, + // including the empty string. + // NB: The implementations on different platforms slightly differ. + // On windows, this method uses _wcsicmp which compares according to LC_CTYPE + // environment variable. On GNU platform this method uses wcscasecmp + // which compares according to LC_CTYPE category of the current locale. + // On MacOS X, it uses towlower, which also uses LC_CTYPE category of the + // current locale. + static bool CaseInsensitiveWideCStringEquals(const wchar_t* lhs, + const wchar_t* rhs); + + // Formats a list of arguments to a String, using the same format + // spec string as for printf. + // + // We do not use the StringPrintf class as it is not universally + // available. + // + // The result is limited to 4096 characters (including the tailing + // 0). If 4096 characters are not enough to format the input, + // "" is returned. + static String Format(const char* format, ...); + + // C'tors + + // The default c'tor constructs a NULL string. + String() : c_str_(NULL), length_(0) {} + + // Constructs a String by cloning a 0-terminated C string. + String(const char* a_c_str) { // NOLINT + if (a_c_str == NULL) { + c_str_ = NULL; + length_ = 0; + } else { + ConstructNonNull(a_c_str, strlen(a_c_str)); + } + } + + // Constructs a String by copying a given number of chars from a + // buffer. E.g. String("hello", 3) creates the string "hel", + // String("a\0bcd", 4) creates "a\0bc", String(NULL, 0) creates "", + // and String(NULL, 1) results in access violation. + String(const char* buffer, size_t a_length) { + ConstructNonNull(buffer, a_length); + } + + // The copy c'tor creates a new copy of the string. The two + // String objects do not share content. + String(const String& str) : c_str_(NULL), length_(0) { *this = str; } + + // D'tor. String is intended to be a final class, so the d'tor + // doesn't need to be virtual. + ~String() { delete[] c_str_; } + + // Allows a String to be implicitly converted to an ::std::string or + // ::string, and vice versa. Converting a String containing a NULL + // pointer to ::std::string or ::string is undefined behavior. + // Converting a ::std::string or ::string containing an embedded NUL + // character to a String will result in the prefix up to the first + // NUL character. + String(const ::std::string& str) { + ConstructNonNull(str.c_str(), str.length()); + } + + operator ::std::string() const { return ::std::string(c_str(), length()); } + +#if GTEST_HAS_GLOBAL_STRING + String(const ::string& str) { + ConstructNonNull(str.c_str(), str.length()); + } + + operator ::string() const { return ::string(c_str(), length()); } +#endif // GTEST_HAS_GLOBAL_STRING + + // Returns true iff this is an empty string (i.e. ""). + bool empty() const { return (c_str() != NULL) && (length() == 0); } + + // Compares this with another String. 
+ // Returns < 0 if this is less than rhs, 0 if this is equal to rhs, or > 0 + // if this is greater than rhs. + int Compare(const String& rhs) const; + + // Returns true iff this String equals the given C string. A NULL + // string and a non-NULL string are considered not equal. + bool operator==(const char* a_c_str) const { return Compare(a_c_str) == 0; } + + // Returns true iff this String is less than the given String. A + // NULL string is considered less than "". + bool operator<(const String& rhs) const { return Compare(rhs) < 0; } + + // Returns true iff this String doesn't equal the given C string. A NULL + // string and a non-NULL string are considered not equal. + bool operator!=(const char* a_c_str) const { return !(*this == a_c_str); } + + // Returns true iff this String ends with the given suffix. *Any* + // String is considered to end with a NULL or empty suffix. + bool EndsWith(const char* suffix) const; + + // Returns true iff this String ends with the given suffix, not considering + // case. Any String is considered to end with a NULL or empty suffix. + bool EndsWithCaseInsensitive(const char* suffix) const; + + // Returns the length of the encapsulated string, or 0 if the + // string is NULL. + size_t length() const { return length_; } + + // Gets the 0-terminated C string this String object represents. + // The String object still owns the string. Therefore the caller + // should NOT delete the return value. + const char* c_str() const { return c_str_; } + + // Assigns a C string to this object. Self-assignment works. + const String& operator=(const char* a_c_str) { + return *this = String(a_c_str); + } + + // Assigns a String object to this object. Self-assignment works. + const String& operator=(const String& rhs) { + if (this != &rhs) { + delete[] c_str_; + if (rhs.c_str() == NULL) { + c_str_ = NULL; + length_ = 0; + } else { + ConstructNonNull(rhs.c_str(), rhs.length()); + } + } + + return *this; + } + + private: + // Constructs a non-NULL String from the given content. This + // function can only be called when c_str_ has not been allocated. + // ConstructNonNull(NULL, 0) results in an empty string (""). + // ConstructNonNull(NULL, non_zero) is undefined behavior. + void ConstructNonNull(const char* buffer, size_t a_length) { + char* const str = new char[a_length + 1]; + memcpy(str, buffer, a_length); + str[a_length] = '\0'; + c_str_ = str; + length_ = a_length; + } + + const char* c_str_; + size_t length_; +}; // class String + +// Streams a String to an ostream. Each '\0' character in the String +// is replaced with "\\0". +inline ::std::ostream& operator<<(::std::ostream& os, const String& str) { + if (str.c_str() == NULL) { + os << "(null)"; + } else { + const char* const c_str = str.c_str(); + for (size_t i = 0; i != str.length(); i++) { + if (c_str[i] == '\0') { + os << "\\0"; + } else { + os << c_str[i]; + } + } + } + return os; +} + +// Gets the content of the stringstream's buffer as a String. Each '\0' +// character in the buffer is replaced with "\\0". +GTEST_API_ String StringStreamToString(::std::stringstream* stream); + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". + +// Declared here but defined in gtest.h, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. 
+template +String StreamableToString(const T& streamable); + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_STRING_H_ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: keith.ray@gmail.com (Keith Ray) +// +// Google Test filepath utilities +// +// This header file declares classes and functions used internally by +// Google Test. They are subject to change without notice. +// +// This file is #included in . +// Do not include this header file separately! + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ + + +namespace testing { +namespace internal { + +// FilePath - a class for file and directory pathname manipulation which +// handles platform-specific conventions (like the pathname separator). +// Used for helper functions for naming files in a directory for xml output. +// Except for Set methods, all methods are const or static, which provides an +// "immutable value object" -- useful for peace of mind. +// A FilePath with a value ending in a path separator ("like/this/") represents +// a directory, otherwise it is assumed to represent a file. In either case, +// it may or may not represent an actual file or directory in the file system. +// Names are NOT checked for syntax correctness -- no checking for illegal +// characters, malformed paths, etc. 
+ +class GTEST_API_ FilePath { + public: + FilePath() : pathname_("") { } + FilePath(const FilePath& rhs) : pathname_(rhs.pathname_) { } + + explicit FilePath(const char* pathname) : pathname_(pathname) { + Normalize(); + } + + explicit FilePath(const String& pathname) : pathname_(pathname) { + Normalize(); + } + + FilePath& operator=(const FilePath& rhs) { + Set(rhs); + return *this; + } + + void Set(const FilePath& rhs) { + pathname_ = rhs.pathname_; + } + + String ToString() const { return pathname_; } + const char* c_str() const { return pathname_.c_str(); } + + // Returns the current working directory, or "" if unsuccessful. + static FilePath GetCurrentDir(); + + // Given directory = "dir", base_name = "test", number = 0, + // extension = "xml", returns "dir/test.xml". If number is greater + // than zero (e.g., 12), returns "dir/test_12.xml". + // On Windows platform, uses \ as the separator rather than /. + static FilePath MakeFileName(const FilePath& directory, + const FilePath& base_name, + int number, + const char* extension); + + // Given directory = "dir", relative_path = "test.xml", + // returns "dir/test.xml". + // On Windows, uses \ as the separator rather than /. + static FilePath ConcatPaths(const FilePath& directory, + const FilePath& relative_path); + + // Returns a pathname for a file that does not currently exist. The pathname + // will be directory/base_name.extension or + // directory/base_name_.extension if directory/base_name.extension + // already exists. The number will be incremented until a pathname is found + // that does not already exist. + // Examples: 'dir/foo_test.xml' or 'dir/foo_test_1.xml'. + // There could be a race condition if two or more processes are calling this + // function at the same time -- they could both pick the same filename. + static FilePath GenerateUniqueFileName(const FilePath& directory, + const FilePath& base_name, + const char* extension); + + // Returns true iff the path is NULL or "". + bool IsEmpty() const { return c_str() == NULL || *c_str() == '\0'; } + + // If input name has a trailing separator character, removes it and returns + // the name, otherwise return the name string unmodified. + // On Windows platform, uses \ as the separator, other platforms use /. + FilePath RemoveTrailingPathSeparator() const; + + // Returns a copy of the FilePath with the directory part removed. + // Example: FilePath("path/to/file").RemoveDirectoryName() returns + // FilePath("file"). If there is no directory part ("just_a_file"), it returns + // the FilePath unmodified. If there is no file part ("just_a_dir/") it + // returns an empty FilePath (""). + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveDirectoryName() const; + + // RemoveFileName returns the directory path with the filename removed. + // Example: FilePath("path/to/file").RemoveFileName() returns "path/to/". + // If the FilePath is "a_file" or "/a_file", RemoveFileName returns + // FilePath("./") or, on Windows, FilePath(".\\"). If the filepath does + // not have a file, like "just/a/dir/", it returns the FilePath unmodified. + // On Windows platform, '\' is the path separator, otherwise it is '/'. + FilePath RemoveFileName() const; + + // Returns a copy of the FilePath with the case-insensitive extension removed. + // Example: FilePath("dir/file.exe").RemoveExtension("EXE") returns + // FilePath("dir/file"). If a case-insensitive extension is not + // found, returns a copy of the original FilePath. 
+ FilePath RemoveExtension(const char* extension) const; + + // Creates directories so that path exists. Returns true if successful or if + // the directories already exist; returns false if unable to create + // directories for any reason. Will also return false if the FilePath does + // not represent a directory (that is, it doesn't end with a path separator). + bool CreateDirectoriesRecursively() const; + + // Create the directory so that path exists. Returns true if successful or + // if the directory already exists; returns false if unable to create the + // directory for any reason, including if the parent directory does not + // exist. Not named "CreateDirectory" because that's a macro on Windows. + bool CreateFolder() const; + + // Returns true if FilePath describes something in the file-system, + // either a file, directory, or whatever, and that something exists. + bool FileOrDirectoryExists() const; + + // Returns true if pathname describes a directory in the file-system + // that exists. + bool DirectoryExists() const; + + // Returns true if FilePath ends with a path separator, which indicates that + // it is intended to represent a directory. Returns false otherwise. + // This does NOT check that a directory (or file) actually exists. + bool IsDirectory() const; + + // Returns true if pathname describes a root directory. (Windows has one + // root directory per disk drive.) + bool IsRootDirectory() const; + + // Returns true if pathname describes an absolute path. + bool IsAbsolutePath() const; + + private: + // Replaces multiple consecutive separators with a single separator. + // For example, "bar///foo" becomes "bar/foo". Does not eliminate other + // redundancies that might be in a pathname involving "." or "..". + // + // A pathname with multiple consecutive separators may occur either through + // user error or as a result of some scripts or APIs that generate a pathname + // with a trailing separator. On other platforms the same API or script + // may NOT generate a pathname with a trailing "/". Then elsewhere that + // pathname may have another "/" and pathname components added to it, + // without checking for the separator already being there. + // The script language and operating system may allow paths like "foo//bar" + // but some of the functions in FilePath will not handle that correctly. In + // particular, RemoveTrailingPathSeparator() only removes one separator, and + // it is called in CreateDirectoriesRecursively() assuming that it will change + // a pathname from directory syntax (trailing separator) to filename syntax. + // + // On Windows this method also replaces the alternate path separator '/' with + // the primary path separator '\\', so that for example "bar\\/\\foo" becomes + // "bar\\foo". + + void Normalize(); + + // Returns a pointer to the last occurence of a valid path separator in + // the FilePath. On Windows, for example, both '/' and '\' are valid path + // separators. Returns NULL if no path separator was found. + const char* FindLastPathSeparator() const; + + String pathname_; +}; // class FilePath + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_FILEPATH_H_ +// This file was GENERATED by command: +// pump.py gtest-type-util.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008 Google Inc. +// All Rights Reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Type utilities needed for implementing typed and type-parameterized +// tests. This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +// Currently we support at most 50 types in a list, and at most 50 +// type-parameterized tests in one type-parameterized test case. +// Please contact googletestframework@googlegroups.com if you need +// more. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ + + +// #ifdef __GNUC__ is too general here. It is possible to use gcc without using +// libstdc++ (which is where cxxabi.h comes from). +# ifdef __GLIBCXX__ +# include +# elif defined(__HP_aCC) +# include +# endif // __GLIBCXX__ + +namespace testing { +namespace internal { + +// GetTypeName() returns a human-readable name of type T. +// NB: This function is also used in Google Mock, so don't move it inside of +// the typed-test-only section below. +template +String GetTypeName() { +# if GTEST_HAS_RTTI + + const char* const name = typeid(T).name(); +# if defined(__GLIBCXX__) || defined(__HP_aCC) + int status = 0; + // gcc's implementation of typeid(T).name() mangles the type name, + // so we have to demangle it. +# ifdef __GLIBCXX__ + using abi::__cxa_demangle; +# endif // __GLIBCXX__ + char* const readable_name = __cxa_demangle(name, 0, 0, &status); + const String name_str(status == 0 ? readable_name : name); + free(readable_name); + return name_str; +# else + return name; +# endif // __GLIBCXX__ || __HP_aCC + +# else + + return ""; + +# endif // GTEST_HAS_RTTI +} + +#if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +// AssertyTypeEq::type is defined iff T1 and T2 are the same +// type. This can be used as a compile-time assertion to ensure that +// two types are equal. + +template +struct AssertTypeEq; + +template +struct AssertTypeEq { + typedef bool type; +}; + +// A unique type used as the default value for the arguments of class +// template Types. 
This allows us to simulate variadic templates +// (e.g. Types, Type, and etc), which C++ doesn't +// support directly. +struct None {}; + +// The following family of struct and struct templates are used to +// represent type lists. In particular, TypesN +// represents a type list with N types (T1, T2, ..., and TN) in it. +// Except for Types0, every struct in the family has two member types: +// Head for the first type in the list, and Tail for the rest of the +// list. + +// The empty type list. +struct Types0 {}; + +// Type lists of length 1, 2, 3, and so on. + +template +struct Types1 { + typedef T1 Head; + typedef Types0 Tail; +}; +template +struct Types2 { + typedef T1 Head; + typedef Types1 Tail; +}; + +template +struct Types3 { + typedef T1 Head; + typedef Types2 Tail; +}; + +template +struct Types4 { + typedef T1 Head; + typedef Types3 Tail; +}; + +template +struct Types5 { + typedef T1 Head; + typedef Types4 Tail; +}; + +template +struct Types6 { + typedef T1 Head; + typedef Types5 Tail; +}; + +template +struct Types7 { + typedef T1 Head; + typedef Types6 Tail; +}; + +template +struct Types8 { + typedef T1 Head; + typedef Types7 Tail; +}; + +template +struct Types9 { + typedef T1 Head; + typedef Types8 Tail; +}; + +template +struct Types10 { + typedef T1 Head; + typedef Types9 Tail; +}; + +template +struct Types11 { + typedef T1 Head; + typedef Types10 Tail; +}; + +template +struct Types12 { + typedef T1 Head; + typedef Types11 Tail; +}; + +template +struct Types13 { + typedef T1 Head; + typedef Types12 Tail; +}; + +template +struct Types14 { + typedef T1 Head; + typedef Types13 Tail; +}; + +template +struct Types15 { + typedef T1 Head; + typedef Types14 Tail; +}; + +template +struct Types16 { + typedef T1 Head; + typedef Types15 Tail; +}; + +template +struct Types17 { + typedef T1 Head; + typedef Types16 Tail; +}; + +template +struct Types18 { + typedef T1 Head; + typedef Types17 Tail; +}; + +template +struct Types19 { + typedef T1 Head; + typedef Types18 Tail; +}; + +template +struct Types20 { + typedef T1 Head; + typedef Types19 Tail; +}; + +template +struct Types21 { + typedef T1 Head; + typedef Types20 Tail; +}; + +template +struct Types22 { + typedef T1 Head; + typedef Types21 Tail; +}; + +template +struct Types23 { + typedef T1 Head; + typedef Types22 Tail; +}; + +template +struct Types24 { + typedef T1 Head; + typedef Types23 Tail; +}; + +template +struct Types25 { + typedef T1 Head; + typedef Types24 Tail; +}; + +template +struct Types26 { + typedef T1 Head; + typedef Types25 Tail; +}; + +template +struct Types27 { + typedef T1 Head; + typedef Types26 Tail; +}; + +template +struct Types28 { + typedef T1 Head; + typedef Types27 Tail; +}; + +template +struct Types29 { + typedef T1 Head; + typedef Types28 Tail; +}; + +template +struct Types30 { + typedef T1 Head; + typedef Types29 Tail; +}; + +template +struct Types31 { + typedef T1 Head; + typedef Types30 Tail; +}; + +template +struct Types32 { + typedef T1 Head; + typedef Types31 Tail; +}; + +template +struct Types33 { + typedef T1 Head; + typedef Types32 Tail; +}; + +template +struct Types34 { + typedef T1 Head; + typedef Types33 Tail; +}; + +template +struct Types35 { + typedef T1 Head; + typedef Types34 Tail; +}; + +template +struct Types36 { + typedef T1 Head; + typedef Types35 Tail; +}; + +template +struct Types37 { + typedef T1 Head; + typedef Types36 Tail; +}; + +template +struct Types38 { + typedef T1 Head; + typedef Types37 Tail; +}; + +template +struct Types39 { + typedef T1 Head; + typedef Types38 Tail; +}; + 
+template +struct Types40 { + typedef T1 Head; + typedef Types39 Tail; +}; + +template +struct Types41 { + typedef T1 Head; + typedef Types40 Tail; +}; + +template +struct Types42 { + typedef T1 Head; + typedef Types41 Tail; +}; + +template +struct Types43 { + typedef T1 Head; + typedef Types42 Tail; +}; + +template +struct Types44 { + typedef T1 Head; + typedef Types43 Tail; +}; + +template +struct Types45 { + typedef T1 Head; + typedef Types44 Tail; +}; + +template +struct Types46 { + typedef T1 Head; + typedef Types45 Tail; +}; + +template +struct Types47 { + typedef T1 Head; + typedef Types46 Tail; +}; + +template +struct Types48 { + typedef T1 Head; + typedef Types47 Tail; +}; + +template +struct Types49 { + typedef T1 Head; + typedef Types48 Tail; +}; + +template +struct Types50 { + typedef T1 Head; + typedef Types49 Tail; +}; + + +} // namespace internal + +// We don't want to require the users to write TypesN<...> directly, +// as that would require them to count the length. Types<...> is much +// easier to write, but generates horrible messages when there is a +// compiler error, as gcc insists on printing out each template +// argument, even if it has the default value (this means Types +// will appear as Types in the compiler +// errors). +// +// Our solution is to combine the best part of the two approaches: a +// user would write Types, and Google Test will translate +// that to TypesN internally to make error messages +// readable. The translation is done by the 'type' member of the +// Types template. +template +struct Types { + typedef internal::Types50 type; +}; + +template <> +struct Types { + typedef internal::Types0 type; +}; +template +struct Types { + typedef internal::Types1 type; +}; +template +struct Types { + typedef internal::Types2 type; +}; +template +struct Types { + typedef internal::Types3 type; +}; +template +struct Types { + typedef internal::Types4 type; +}; +template +struct Types { + typedef internal::Types5 type; +}; +template +struct Types { + typedef internal::Types6 type; +}; +template +struct Types { + typedef internal::Types7 type; +}; +template +struct Types { + typedef internal::Types8 type; +}; +template +struct Types { + typedef internal::Types9 type; +}; +template +struct Types { + typedef internal::Types10 type; +}; +template +struct Types { + typedef internal::Types11 type; +}; +template +struct Types { + typedef internal::Types12 type; +}; +template +struct Types { + typedef internal::Types13 type; +}; +template +struct Types { + typedef internal::Types14 type; +}; +template +struct Types { + typedef internal::Types15 type; +}; +template +struct Types { + typedef internal::Types16 type; +}; +template +struct Types { + typedef internal::Types17 type; +}; +template +struct Types { + typedef internal::Types18 type; +}; +template +struct Types { + typedef internal::Types19 type; +}; +template +struct Types { + typedef internal::Types20 type; +}; +template +struct Types { + typedef internal::Types21 type; +}; +template +struct Types { + typedef internal::Types22 type; +}; +template +struct Types { + typedef internal::Types23 type; +}; +template +struct Types { + typedef internal::Types24 type; +}; +template +struct Types { + typedef internal::Types25 type; +}; +template +struct Types { + typedef internal::Types26 type; +}; +template +struct Types { + typedef internal::Types27 type; +}; +template +struct Types { + typedef internal::Types28 type; +}; +template +struct Types { + typedef internal::Types29 type; +}; +template +struct Types { + 
typedef internal::Types30 type; +}; +template +struct Types { + typedef internal::Types31 type; +}; +template +struct Types { + typedef internal::Types32 type; +}; +template +struct Types { + typedef internal::Types33 type; +}; +template +struct Types { + typedef internal::Types34 type; +}; +template +struct Types { + typedef internal::Types35 type; +}; +template +struct Types { + typedef internal::Types36 type; +}; +template +struct Types { + typedef internal::Types37 type; +}; +template +struct Types { + typedef internal::Types38 type; +}; +template +struct Types { + typedef internal::Types39 type; +}; +template +struct Types { + typedef internal::Types40 type; +}; +template +struct Types { + typedef internal::Types41 type; +}; +template +struct Types { + typedef internal::Types42 type; +}; +template +struct Types { + typedef internal::Types43 type; +}; +template +struct Types { + typedef internal::Types44 type; +}; +template +struct Types { + typedef internal::Types45 type; +}; +template +struct Types { + typedef internal::Types46 type; +}; +template +struct Types { + typedef internal::Types47 type; +}; +template +struct Types { + typedef internal::Types48 type; +}; +template +struct Types { + typedef internal::Types49 type; +}; + +namespace internal { + +# define GTEST_TEMPLATE_ template class + +// The template "selector" struct TemplateSel is used to +// represent Tmpl, which must be a class template with one type +// parameter, as a type. TemplateSel::Bind::type is defined +// as the type Tmpl. This allows us to actually instantiate the +// template "selected" by TemplateSel. +// +// This trick is necessary for simulating typedef for class templates, +// which C++ doesn't support directly. +template +struct TemplateSel { + template + struct Bind { + typedef Tmpl type; + }; +}; + +# define GTEST_BIND_(TmplSel, T) \ + TmplSel::template Bind::type + +// A unique struct template used as the default value for the +// arguments of class template Templates. This allows us to simulate +// variadic templates (e.g. Templates, Templates, +// and etc), which C++ doesn't support directly. +template +struct NoneT {}; + +// The following family of struct and struct templates are used to +// represent template lists. In particular, TemplatesN represents a list of N templates (T1, T2, ..., and TN). Except +// for Templates0, every struct in the family has two member types: +// Head for the selector of the first template in the list, and Tail +// for the rest of the list. + +// The empty template list. +struct Templates0 {}; + +// Template lists of length 1, 2, 3, and so on. 
+ +template +struct Templates1 { + typedef TemplateSel Head; + typedef Templates0 Tail; +}; +template +struct Templates2 { + typedef TemplateSel Head; + typedef Templates1 Tail; +}; + +template +struct Templates3 { + typedef TemplateSel Head; + typedef Templates2 Tail; +}; + +template +struct Templates4 { + typedef TemplateSel Head; + typedef Templates3 Tail; +}; + +template +struct Templates5 { + typedef TemplateSel Head; + typedef Templates4 Tail; +}; + +template +struct Templates6 { + typedef TemplateSel Head; + typedef Templates5 Tail; +}; + +template +struct Templates7 { + typedef TemplateSel Head; + typedef Templates6 Tail; +}; + +template +struct Templates8 { + typedef TemplateSel Head; + typedef Templates7 Tail; +}; + +template +struct Templates9 { + typedef TemplateSel Head; + typedef Templates8 Tail; +}; + +template +struct Templates10 { + typedef TemplateSel Head; + typedef Templates9 Tail; +}; + +template +struct Templates11 { + typedef TemplateSel Head; + typedef Templates10 Tail; +}; + +template +struct Templates12 { + typedef TemplateSel Head; + typedef Templates11 Tail; +}; + +template +struct Templates13 { + typedef TemplateSel Head; + typedef Templates12 Tail; +}; + +template +struct Templates14 { + typedef TemplateSel Head; + typedef Templates13 Tail; +}; + +template +struct Templates15 { + typedef TemplateSel Head; + typedef Templates14 Tail; +}; + +template +struct Templates16 { + typedef TemplateSel Head; + typedef Templates15 Tail; +}; + +template +struct Templates17 { + typedef TemplateSel Head; + typedef Templates16 Tail; +}; + +template +struct Templates18 { + typedef TemplateSel Head; + typedef Templates17 Tail; +}; + +template +struct Templates19 { + typedef TemplateSel Head; + typedef Templates18 Tail; +}; + +template +struct Templates20 { + typedef TemplateSel Head; + typedef Templates19 Tail; +}; + +template +struct Templates21 { + typedef TemplateSel Head; + typedef Templates20 Tail; +}; + +template +struct Templates22 { + typedef TemplateSel Head; + typedef Templates21 Tail; +}; + +template +struct Templates23 { + typedef TemplateSel Head; + typedef Templates22 Tail; +}; + +template +struct Templates24 { + typedef TemplateSel Head; + typedef Templates23 Tail; +}; + +template +struct Templates25 { + typedef TemplateSel Head; + typedef Templates24 Tail; +}; + +template +struct Templates26 { + typedef TemplateSel Head; + typedef Templates25 Tail; +}; + +template +struct Templates27 { + typedef TemplateSel Head; + typedef Templates26 Tail; +}; + +template +struct Templates28 { + typedef TemplateSel Head; + typedef Templates27 Tail; +}; + +template +struct Templates29 { + typedef TemplateSel Head; + typedef Templates28 Tail; +}; + +template +struct Templates30 { + typedef TemplateSel Head; + typedef Templates29 Tail; +}; + +template +struct Templates31 { + typedef TemplateSel Head; + typedef Templates30 Tail; +}; + +template +struct Templates32 { + typedef TemplateSel Head; + typedef Templates31 Tail; +}; + +template +struct Templates33 { + typedef TemplateSel Head; + typedef Templates32 Tail; +}; + +template +struct Templates34 { + typedef TemplateSel Head; + typedef Templates33 Tail; +}; + +template +struct Templates35 { + typedef TemplateSel Head; + typedef Templates34 Tail; +}; + +template +struct Templates36 { + typedef TemplateSel Head; + typedef Templates35 Tail; +}; + +template +struct Templates37 { + typedef TemplateSel Head; + typedef Templates36 Tail; +}; + +template +struct Templates38 { + typedef TemplateSel Head; + typedef Templates37 Tail; +}; + 
+template +struct Templates39 { + typedef TemplateSel Head; + typedef Templates38 Tail; +}; + +template +struct Templates40 { + typedef TemplateSel Head; + typedef Templates39 Tail; +}; + +template +struct Templates41 { + typedef TemplateSel Head; + typedef Templates40 Tail; +}; + +template +struct Templates42 { + typedef TemplateSel Head; + typedef Templates41 Tail; +}; + +template +struct Templates43 { + typedef TemplateSel Head; + typedef Templates42 Tail; +}; + +template +struct Templates44 { + typedef TemplateSel Head; + typedef Templates43 Tail; +}; + +template +struct Templates45 { + typedef TemplateSel Head; + typedef Templates44 Tail; +}; + +template +struct Templates46 { + typedef TemplateSel Head; + typedef Templates45 Tail; +}; + +template +struct Templates47 { + typedef TemplateSel Head; + typedef Templates46 Tail; +}; + +template +struct Templates48 { + typedef TemplateSel Head; + typedef Templates47 Tail; +}; + +template +struct Templates49 { + typedef TemplateSel Head; + typedef Templates48 Tail; +}; + +template +struct Templates50 { + typedef TemplateSel Head; + typedef Templates49 Tail; +}; + + +// We don't want to require the users to write TemplatesN<...> directly, +// as that would require them to count the length. Templates<...> is much +// easier to write, but generates horrible messages when there is a +// compiler error, as gcc insists on printing out each template +// argument, even if it has the default value (this means Templates +// will appear as Templates in the compiler +// errors). +// +// Our solution is to combine the best part of the two approaches: a +// user would write Templates, and Google Test will translate +// that to TemplatesN internally to make error messages +// readable. The translation is done by the 'type' member of the +// Templates template. 
+template +struct Templates { + typedef Templates50 type; +}; + +template <> +struct Templates { + typedef Templates0 type; +}; +template +struct Templates { + typedef Templates1 type; +}; +template +struct Templates { + typedef Templates2 type; +}; +template +struct Templates { + typedef Templates3 type; +}; +template +struct Templates { + typedef Templates4 type; +}; +template +struct Templates { + typedef Templates5 type; +}; +template +struct Templates { + typedef Templates6 type; +}; +template +struct Templates { + typedef Templates7 type; +}; +template +struct Templates { + typedef Templates8 type; +}; +template +struct Templates { + typedef Templates9 type; +}; +template +struct Templates { + typedef Templates10 type; +}; +template +struct Templates { + typedef Templates11 type; +}; +template +struct Templates { + typedef Templates12 type; +}; +template +struct Templates { + typedef Templates13 type; +}; +template +struct Templates { + typedef Templates14 type; +}; +template +struct Templates { + typedef Templates15 type; +}; +template +struct Templates { + typedef Templates16 type; +}; +template +struct Templates { + typedef Templates17 type; +}; +template +struct Templates { + typedef Templates18 type; +}; +template +struct Templates { + typedef Templates19 type; +}; +template +struct Templates { + typedef Templates20 type; +}; +template +struct Templates { + typedef Templates21 type; +}; +template +struct Templates { + typedef Templates22 type; +}; +template +struct Templates { + typedef Templates23 type; +}; +template +struct Templates { + typedef Templates24 type; +}; +template +struct Templates { + typedef Templates25 type; +}; +template +struct Templates { + typedef Templates26 type; +}; +template +struct Templates { + typedef Templates27 type; +}; +template +struct Templates { + typedef Templates28 type; +}; +template +struct Templates { + typedef Templates29 type; +}; +template +struct Templates { + typedef Templates30 type; +}; +template +struct Templates { + typedef Templates31 type; +}; +template +struct Templates { + typedef Templates32 type; +}; +template +struct Templates { + typedef Templates33 type; +}; +template +struct Templates { + typedef Templates34 type; +}; +template +struct Templates { + typedef Templates35 type; +}; +template +struct Templates { + typedef Templates36 type; +}; +template +struct Templates { + typedef Templates37 type; +}; +template +struct Templates { + typedef Templates38 type; +}; +template +struct Templates { + typedef Templates39 type; +}; +template +struct Templates { + typedef Templates40 type; +}; +template +struct Templates { + typedef Templates41 type; +}; +template +struct Templates { + typedef Templates42 type; +}; +template +struct Templates { + typedef Templates43 type; +}; +template +struct Templates { + typedef Templates44 type; +}; +template +struct Templates { + typedef Templates45 type; +}; +template +struct Templates { + typedef Templates46 type; +}; +template +struct Templates { + typedef Templates47 type; +}; +template +struct Templates { + typedef Templates48 type; +}; +template +struct Templates { + typedef Templates49 type; +}; + +// The TypeList template makes it possible to use either a single type +// or a Types<...> list in TYPED_TEST_CASE() and +// INSTANTIATE_TYPED_TEST_CASE_P(). 
+ +template +struct TypeList { typedef Types1 type; }; + +template +struct TypeList > { + typedef typename Types::type type; +}; + +#endif // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_TYPE_UTIL_H_ + +// Due to C++ preprocessor weirdness, we need double indirection to +// concatenate two tokens when one of them is __LINE__. Writing +// +// foo ## __LINE__ +// +// will result in the token foo__LINE__, instead of foo followed by +// the current line number. For more details, see +// http://www.parashift.com/c++-faq-lite/misc-technical-issues.html#faq-39.6 +#define GTEST_CONCAT_TOKEN_(foo, bar) GTEST_CONCAT_TOKEN_IMPL_(foo, bar) +#define GTEST_CONCAT_TOKEN_IMPL_(foo, bar) foo ## bar + +// Google Test defines the testing::Message class to allow construction of +// test messages via the << operator. The idea is that anything +// streamable to std::ostream can be streamed to a testing::Message. +// This allows a user to use his own types in Google Test assertions by +// overloading the << operator. +// +// util/gtl/stl_logging-inl.h overloads << for STL containers. These +// overloads cannot be defined in the std namespace, as that will be +// undefined behavior. Therefore, they are defined in the global +// namespace instead. +// +// C++'s symbol lookup rule (i.e. Koenig lookup) says that these +// overloads are visible in either the std namespace or the global +// namespace, but not other namespaces, including the testing +// namespace which Google Test's Message class is in. +// +// To allow STL containers (and other types that has a << operator +// defined in the global namespace) to be used in Google Test assertions, +// testing::Message must access the custom << operator from the global +// namespace. Hence this helper function. +// +// Note: Jeffrey Yasskin suggested an alternative fix by "using +// ::operator<<;" in the definition of Message's operator<<. That fix +// doesn't require a helper function, but unfortunately doesn't +// compile with MSVC. +template +inline void GTestStreamToHelper(std::ostream* os, const T& val) { + *os << val; +} + +class ProtocolMessage; +namespace proto2 { class Message; } + +namespace testing { + +// Forward declarations. + +class AssertionResult; // Result of an assertion. +class Message; // Represents a failure message. +class Test; // Represents a test. +class TestInfo; // Information about a test. +class TestPartResult; // Result of a test part. +class UnitTest; // A collection of test cases. + +template +::std::string PrintToString(const T& value); + +namespace internal { + +struct TraceInfo; // Information about a trace point. +class ScopedTrace; // Implements scoped trace. +class TestInfoImpl; // Opaque implementation of TestInfo +class UnitTestImpl; // Opaque implementation of UnitTest + +// How many times InitGoogleTest() has been called. +extern int g_init_gtest_count; + +// The text used in failure messages to indicate the start of the +// stack trace. +GTEST_API_ extern const char kStackTraceMarker[]; + +// A secret type that Google Test users don't know about. It has no +// definition on purpose. Therefore it's impossible to create a +// Secret object, which is what we want. +class Secret; + +// Two overloaded helpers for checking at compile time whether an +// expression is a null pointer literal (i.e. NULL or any 0-valued +// compile-time integral constant). 
Their return values have +// different sizes, so we can use sizeof() to test which version is +// picked by the compiler. These helpers have no implementations, as +// we only need their signatures. +// +// Given IsNullLiteralHelper(x), the compiler will pick the first +// version if x can be implicitly converted to Secret*, and pick the +// second version otherwise. Since Secret is a secret and incomplete +// type, the only expression a user can write that has type Secret* is +// a null pointer literal. Therefore, we know that x is a null +// pointer literal if and only if the first version is picked by the +// compiler. +char IsNullLiteralHelper(Secret* p); +char (&IsNullLiteralHelper(...))[2]; // NOLINT + +// A compile-time bool constant that is true if and only if x is a +// null pointer literal (i.e. NULL or any 0-valued compile-time +// integral constant). +#ifdef GTEST_ELLIPSIS_NEEDS_POD_ +// We lose support for NULL detection where the compiler doesn't like +// passing non-POD classes through ellipsis (...). +# define GTEST_IS_NULL_LITERAL_(x) false +#else +# define GTEST_IS_NULL_LITERAL_(x) \ + (sizeof(::testing::internal::IsNullLiteralHelper(x)) == 1) +#endif // GTEST_ELLIPSIS_NEEDS_POD_ + +// Appends the user-supplied message to the Google-Test-generated message. +GTEST_API_ String AppendUserMessage(const String& gtest_msg, + const Message& user_msg); + +// A helper class for creating scoped traces in user programs. +class GTEST_API_ ScopedTrace { + public: + // The c'tor pushes the given source file location and message onto + // a trace stack maintained by Google Test. + ScopedTrace(const char* file, int line, const Message& message); + + // The d'tor pops the info pushed by the c'tor. + // + // Note that the d'tor is not virtual in order to be efficient. + // Don't inherit from ScopedTrace! + ~ScopedTrace(); + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ScopedTrace); +} GTEST_ATTRIBUTE_UNUSED_; // A ScopedTrace object does its job in its + // c'tor and d'tor. Therefore it doesn't + // need to be used otherwise. + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". +// Declared here but defined in gtest.h, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. +template +String StreamableToString(const T& streamable); + +// The Symbian compiler has a bug that prevents it from selecting the +// correct overload of FormatForComparisonFailureMessage (see below) +// unless we pass the first argument by reference. If we do that, +// however, Visual Age C++ 10.1 generates a compiler error. Therefore +// we only apply the work-around for Symbian. +#if defined(__SYMBIAN32__) +# define GTEST_CREF_WORKAROUND_ const& +#else +# define GTEST_CREF_WORKAROUND_ +#endif + +// When this operand is a const char* or char*, if the other operand +// is a ::std::string or ::string, we print this operand as a C string +// rather than a pointer (we do the same for wide strings); otherwise +// we print it as a pointer to be safe. + +// This internal macro is used to avoid duplicated code. 
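// Illustrative aside (not from the original source): user code normally reaches
// the ScopedTrace class above through the public SCOPED_TRACE() macro, which
// pushes a file/line/message frame for the lifetime of the enclosing scope.
// The helper and test names below are invented for the example.
//
//   void CheckIsPositive(int n) { EXPECT_GT(n, 0); }
//
//   TEST(TraceTest, AddsContextToFailures) {
//     for (int i = 0; i < 3; i++) {
//       SCOPED_TRACE(i);       // any streamable value works as the message
//       CheckIsPositive(i);    // the failure for i == 0 is reported with "0"
//     }
//   }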
+#define GTEST_FORMAT_IMPL_(operand2_type, operand1_printer)\ +inline String FormatForComparisonFailureMessage(\ + operand2_type::value_type* GTEST_CREF_WORKAROUND_ str, \ + const operand2_type& /*operand2*/) {\ + return operand1_printer(str);\ +}\ +inline String FormatForComparisonFailureMessage(\ + const operand2_type::value_type* GTEST_CREF_WORKAROUND_ str, \ + const operand2_type& /*operand2*/) {\ + return operand1_printer(str);\ +} + +GTEST_FORMAT_IMPL_(::std::string, String::ShowCStringQuoted) +#if GTEST_HAS_STD_WSTRING +GTEST_FORMAT_IMPL_(::std::wstring, String::ShowWideCStringQuoted) +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_STRING +GTEST_FORMAT_IMPL_(::string, String::ShowCStringQuoted) +#endif // GTEST_HAS_GLOBAL_STRING +#if GTEST_HAS_GLOBAL_WSTRING +GTEST_FORMAT_IMPL_(::wstring, String::ShowWideCStringQuoted) +#endif // GTEST_HAS_GLOBAL_WSTRING + +#undef GTEST_FORMAT_IMPL_ + +// The next four overloads handle the case where the operand being +// printed is a char/wchar_t pointer and the other operand is not a +// string/wstring object. In such cases, we just print the operand as +// a pointer to be safe. +#define GTEST_FORMAT_CHAR_PTR_IMPL_(CharType) \ + template \ + String FormatForComparisonFailureMessage(CharType* GTEST_CREF_WORKAROUND_ p, \ + const T&) { \ + return PrintToString(static_cast(p)); \ + } + +GTEST_FORMAT_CHAR_PTR_IMPL_(char) +GTEST_FORMAT_CHAR_PTR_IMPL_(const char) +GTEST_FORMAT_CHAR_PTR_IMPL_(wchar_t) +GTEST_FORMAT_CHAR_PTR_IMPL_(const wchar_t) + +#undef GTEST_FORMAT_CHAR_PTR_IMPL_ + +// Constructs and returns the message for an equality assertion +// (e.g. ASSERT_EQ, EXPECT_STREQ, etc) failure. +// +// The first four parameters are the expressions used in the assertion +// and their values, as strings. For example, for ASSERT_EQ(foo, bar) +// where foo is 5 and bar is 6, we have: +// +// expected_expression: "foo" +// actual_expression: "bar" +// expected_value: "5" +// actual_value: "6" +// +// The ignoring_case parameter is true iff the assertion is a +// *_STRCASEEQ*. When it's true, the string " (ignoring case)" will +// be inserted into the message. +GTEST_API_ AssertionResult EqFailure(const char* expected_expression, + const char* actual_expression, + const String& expected_value, + const String& actual_value, + bool ignoring_case); + +// Constructs a failure message for Boolean assertions such as EXPECT_TRUE. +GTEST_API_ String GetBoolAssertionFailureMessage( + const AssertionResult& assertion_result, + const char* expression_text, + const char* actual_predicate_value, + const char* expected_predicate_value); + +// This template class represents an IEEE floating-point number +// (either single-precision or double-precision, depending on the +// template parameters). +// +// The purpose of this class is to do more sophisticated number +// comparison. (Due to round-off error, etc, it's very unlikely that +// two floating-points will be equal exactly. Hence a naive +// comparison by the == operation often doesn't work.) +// +// Format of IEEE floating-point: +// +// The most-significant bit being the leftmost, an IEEE +// floating-point looks like +// +// sign_bit exponent_bits fraction_bits +// +// Here, sign_bit is a single bit that designates the sign of the +// number. +// +// For float, there are 8 exponent bits and 23 fraction bits. +// +// For double, there are 11 exponent bits and 52 fraction bits. +// +// More details can be found at +// http://en.wikipedia.org/wiki/IEEE_floating-point_standard. 
+//
+// Template parameter:
+//
+//   RawType: the raw floating-point type (either float or double)
+template <typename RawType>
+class FloatingPoint {
+ public:
+  // Defines the unsigned integer type that has the same size as the
+  // floating point number.
+  typedef typename TypeWithSize<sizeof(RawType)>::UInt Bits;
+
+  // Constants.
+
+  // # of bits in a number.
+  static const size_t kBitCount = 8*sizeof(RawType);
+
+  // # of fraction bits in a number.
+  static const size_t kFractionBitCount =
+    std::numeric_limits<RawType>::digits - 1;
+
+  // # of exponent bits in a number.
+  static const size_t kExponentBitCount = kBitCount - 1 - kFractionBitCount;
+
+  // The mask for the sign bit.
+  static const Bits kSignBitMask = static_cast<Bits>(1) << (kBitCount - 1);
+
+  // The mask for the fraction bits.
+  static const Bits kFractionBitMask =
+    ~static_cast<Bits>(0) >> (kExponentBitCount + 1);
+
+  // The mask for the exponent bits.
+  static const Bits kExponentBitMask = ~(kSignBitMask | kFractionBitMask);
+
+  // How many ULP's (Units in the Last Place) we want to tolerate when
+  // comparing two numbers.  The larger the value, the more error we
+  // allow.  A 0 value means that two numbers must be exactly the same
+  // to be considered equal.
+  //
+  // The maximum error of a single floating-point operation is 0.5
+  // units in the last place.  On Intel CPU's, all floating-point
+  // calculations are done with 80-bit precision, while double has 64
+  // bits.  Therefore, 4 should be enough for ordinary use.
+  //
+  // See the following article for more details on ULP:
+  // http://www.cygnus-software.com/papers/comparingfloats/comparingfloats.htm.
+  static const size_t kMaxUlps = 4;
+
+  // Constructs a FloatingPoint from a raw floating-point number.
+  //
+  // On an Intel CPU, passing a non-normalized NAN (Not a Number)
+  // around may change its bits, although the new value is guaranteed
+  // to be also a NAN.  Therefore, don't expect this constructor to
+  // preserve the bits in x when x is a NAN.
+  explicit FloatingPoint(const RawType& x) { u_.value_ = x; }
+
+  // Static methods
+
+  // Reinterprets a bit pattern as a floating-point number.
+  //
+  // This function is needed to test the AlmostEquals() method.
+  static RawType ReinterpretBits(const Bits bits) {
+    FloatingPoint fp(0);
+    fp.u_.bits_ = bits;
+    return fp.u_.value_;
+  }
+
+  // Returns the floating-point number that represent positive infinity.
+  static RawType Infinity() {
+    return ReinterpretBits(kExponentBitMask);
+  }
+
+  // Non-static methods
+
+  // Returns the bits that represents this number.
+  const Bits &bits() const { return u_.bits_; }
+
+  // Returns the exponent bits of this number.
+  Bits exponent_bits() const { return kExponentBitMask & u_.bits_; }
+
+  // Returns the fraction bits of this number.
+  Bits fraction_bits() const { return kFractionBitMask & u_.bits_; }
+
+  // Returns the sign bit of this number.
+  Bits sign_bit() const { return kSignBitMask & u_.bits_; }
+
+  // Returns true iff this is NAN (not a number).
+  bool is_nan() const {
+    // It's a NAN if the exponent bits are all ones and the fraction
+    // bits are not entirely zeros.
+    return (exponent_bits() == kExponentBitMask) && (fraction_bits() != 0);
+  }
+
+  // Returns true iff this number is at most kMaxUlps ULP's away from
+  // rhs.  In particular, this function:
+  //
+  //   - returns false if either number is (or both are) NAN.
+  //   - treats really large numbers as almost equal to infinity.
+  //   - thinks +0.0 and -0.0 are 0 DLP's apart.
+  bool AlmostEquals(const FloatingPoint& rhs) const {
+    // The IEEE standard says that any comparison operation involving
+    // a NAN must return false.
+    if (is_nan() || rhs.is_nan()) return false;
+
+    return DistanceBetweenSignAndMagnitudeNumbers(u_.bits_, rhs.u_.bits_)
+        <= kMaxUlps;
+  }
+
+ private:
+  // The data type used to store the actual floating-point number.
+  union FloatingPointUnion {
+    RawType value_;  // The raw floating-point number.
+    Bits bits_;      // The bits that represent the number.
+  };
+
+  // Converts an integer from the sign-and-magnitude representation to
+  // the biased representation.  More precisely, let N be 2 to the
+  // power of (kBitCount - 1), an integer x is represented by the
+  // unsigned number x + N.
+  //
+  // For instance,
+  //
+  //   -N + 1 (the most negative number representable using
+  //          sign-and-magnitude) is represented by 1;
+  //   0      is represented by N; and
+  //   N - 1  (the biggest number representable using
+  //          sign-and-magnitude) is represented by 2N - 1.
+  //
+  // Read http://en.wikipedia.org/wiki/Signed_number_representations
+  // for more details on signed number representations.
+  static Bits SignAndMagnitudeToBiased(const Bits &sam) {
+    if (kSignBitMask & sam) {
+      // sam represents a negative number.
+      return ~sam + 1;
+    } else {
+      // sam represents a positive number.
+      return kSignBitMask | sam;
+    }
+  }
+
+  // Given two numbers in the sign-and-magnitude representation,
+  // returns the distance between them as an unsigned number.
+  static Bits DistanceBetweenSignAndMagnitudeNumbers(const Bits &sam1,
+                                                     const Bits &sam2) {
+    const Bits biased1 = SignAndMagnitudeToBiased(sam1);
+    const Bits biased2 = SignAndMagnitudeToBiased(sam2);
+    return (biased1 >= biased2) ? (biased1 - biased2) : (biased2 - biased1);
+  }
+
+  FloatingPointUnion u_;
+};
+
+// Typedefs the instances of the FloatingPoint template class that we
+// care to use.
+typedef FloatingPoint<float> Float;
+typedef FloatingPoint<double> Double;
+
+// In order to catch the mistake of putting tests that use different
+// test fixture classes in the same test case, we need to assign
+// unique IDs to fixture classes and compare them.  The TypeId type is
+// used to hold such IDs.  The user should treat TypeId as an opaque
+// type: the only operation allowed on TypeId values is to compare
+// them for equality using the == operator.
+typedef const void* TypeId;
+
+template <typename T>
+class TypeIdHelper {
+ public:
+  // dummy_ must not have a const type.  Otherwise an overly eager
+  // compiler (e.g. MSVC 7.1 & 8.0) may try to merge
+  // TypeIdHelper<T>::dummy_ for different Ts as an "optimization".
+  static bool dummy_;
+};
+
+template <typename T>
+bool TypeIdHelper<T>::dummy_ = false;
+
+// GetTypeId<T>() returns the ID of type T.  Different values will be
+// returned for different types.  Calling the function twice with the
+// same type argument is guaranteed to return the same ID.
+template <typename T>
+TypeId GetTypeId() {
+  // The compiler is required to allocate a different
+  // TypeIdHelper<T>::dummy_ variable for each T used to instantiate
+  // the template.  Therefore, the address of dummy_ is guaranteed to
+  // be unique.
+  return &(TypeIdHelper<T>::dummy_);
+}
+
+// Returns the type ID of ::testing::Test.  Always call this instead
+// of GetTypeId< ::testing::Test>() to get the type ID of
+// ::testing::Test, as the latter may give the wrong result due to a
+// suspected linker bug when compiling Google Test as a Mac OS X
+// framework.
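// Illustrative aside (not from the original source): what the ULP-based
// comparison above buys us. The public EXPECT_FLOAT_EQ/EXPECT_DOUBLE_EQ
// assertions are built on AlmostEquals(), so values that differ only by a few
// representable steps still compare equal.
//
//   using testing::internal::Float;
//   const float f = 1.0f;
//   const float g = Float::ReinterpretBits(Float(f).bits() + 1);  // 1 ULP away
//   EXPECT_TRUE(Float(f).AlmostEquals(Float(g)));  // within kMaxUlps (4)
//   EXPECT_FLOAT_EQ(f, g);                         // same check via public API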
+GTEST_API_ TypeId GetTestTypeId(); + +// Defines the abstract factory interface that creates instances +// of a Test object. +class TestFactoryBase { + public: + virtual ~TestFactoryBase() {} + + // Creates a test instance to run. The instance is both created and destroyed + // within TestInfoImpl::Run() + virtual Test* CreateTest() = 0; + + protected: + TestFactoryBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestFactoryBase); +}; + +// This class provides implementation of TeastFactoryBase interface. +// It is used in TEST and TEST_F macros. +template +class TestFactoryImpl : public TestFactoryBase { + public: + virtual Test* CreateTest() { return new TestClass; } +}; + +#if GTEST_OS_WINDOWS + +// Predicate-formatters for implementing the HRESULT checking macros +// {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED} +// We pass a long instead of HRESULT to avoid causing an +// include dependency for the HRESULT type. +GTEST_API_ AssertionResult IsHRESULTSuccess(const char* expr, + long hr); // NOLINT +GTEST_API_ AssertionResult IsHRESULTFailure(const char* expr, + long hr); // NOLINT + +#endif // GTEST_OS_WINDOWS + +// Types of SetUpTestCase() and TearDownTestCase() functions. +typedef void (*SetUpTestCaseFunc)(); +typedef void (*TearDownTestCaseFunc)(); + +// Creates a new TestInfo object and registers it with Google Test; +// returns the created object. +// +// Arguments: +// +// test_case_name: name of the test case +// name: name of the test +// type_param the name of the test's type parameter, or NULL if +// this is not a typed or a type-parameterized test. +// value_param text representation of the test's value parameter, +// or NULL if this is not a type-parameterized test. +// fixture_class_id: ID of the test fixture class +// set_up_tc: pointer to the function that sets up the test case +// tear_down_tc: pointer to the function that tears down the test case +// factory: pointer to the factory that creates a test object. +// The newly created TestInfo instance will assume +// ownership of the factory object. +GTEST_API_ TestInfo* MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + TypeId fixture_class_id, + SetUpTestCaseFunc set_up_tc, + TearDownTestCaseFunc tear_down_tc, + TestFactoryBase* factory); + +// If *pstr starts with the given prefix, modifies *pstr to be right +// past the prefix and returns true; otherwise leaves *pstr unchanged +// and returns false. None of pstr, *pstr, and prefix can be NULL. +GTEST_API_ bool SkipPrefix(const char* prefix, const char** pstr); + +#if GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P + +// State of the definition of a type-parameterized test case. +class GTEST_API_ TypedTestCasePState { + public: + TypedTestCasePState() : registered_(false) {} + + // Adds the given test name to defined_test_names_ and return true + // if the test case hasn't been registered; otherwise aborts the + // program. + bool AddTestName(const char* file, int line, const char* case_name, + const char* test_name) { + if (registered_) { + fprintf(stderr, "%s Test %s must be defined before " + "REGISTER_TYPED_TEST_CASE_P(%s, ...).\n", + FormatFileLocation(file, line).c_str(), test_name, case_name); + fflush(stderr); + posix::Abort(); + } + defined_test_names_.insert(test_name); + return true; + } + + // Verifies that registered_tests match the test names in + // defined_test_names_; returns registered_tests if successful, or + // aborts the program otherwise. 
+  const char* VerifyRegisteredTestNames(
+      const char* file, int line, const char* registered_tests);
+
+ private:
+  bool registered_;
+  ::std::set<const char*> defined_test_names_;
+};
+
+// Skips to the first non-space char after the first comma in 'str';
+// returns NULL if no comma is found in 'str'.
+inline const char* SkipComma(const char* str) {
+  const char* comma = strchr(str, ',');
+  if (comma == NULL) {
+    return NULL;
+  }
+  while (IsSpace(*(++comma))) {}
+  return comma;
+}
+
+// Returns the prefix of 'str' before the first comma in it; returns
+// the entire string if it contains no comma.
+inline String GetPrefixUntilComma(const char* str) {
+  const char* comma = strchr(str, ',');
+  return comma == NULL ? String(str) : String(str, comma - str);
+}
+
+// TypeParameterizedTest<Fixture, TestSel, Types>::Register()
+// registers a list of type-parameterized tests with Google Test.  The
+// return value is insignificant - we just need to return something
+// such that we can call this function in a namespace scope.
+//
+// Implementation note: The GTEST_TEMPLATE_ macro declares a template
+// template parameter.  It's defined in gtest-type-util.h.
+template <GTEST_TEMPLATE_ Fixture, class TestSel, typename Types>
+class TypeParameterizedTest {
+ public:
+  // 'index' is the index of the test in the type list 'Types'
+  // specified in INSTANTIATE_TYPED_TEST_CASE_P(Prefix, TestCase,
+  // Types).  Valid values for 'index' are [0, N - 1] where N is the
+  // length of Types.
+  static bool Register(const char* prefix, const char* case_name,
+                       const char* test_names, int index) {
+    typedef typename Types::Head Type;
+    typedef Fixture<Type> FixtureClass;
+    typedef typename GTEST_BIND_(TestSel, Type) TestClass;
+
+    // First, registers the first type-parameterized test in the type
+    // list.
+    MakeAndRegisterTestInfo(
+        String::Format("%s%s%s/%d", prefix, prefix[0] == '\0' ? "" : "/",
+                       case_name, index).c_str(),
+        GetPrefixUntilComma(test_names).c_str(),
+        GetTypeName<Type>().c_str(),
+        NULL,  // No value parameter.
+        GetTypeId<FixtureClass>(),
+        TestClass::SetUpTestCase,
+        TestClass::TearDownTestCase,
+        new TestFactoryImpl<TestClass>);
+
+    // Next, recurses (at compile time) with the tail of the type list.
+    return TypeParameterizedTest<Fixture, TestSel, typename Types::Tail>
+        ::Register(prefix, case_name, test_names, index + 1);
+  }
+};
+
+// The base case for the compile time recursion.
+template <GTEST_TEMPLATE_ Fixture, class TestSel>
+class TypeParameterizedTest<Fixture, TestSel, Types0> {
+ public:
+  static bool Register(const char* /*prefix*/, const char* /*case_name*/,
+                       const char* /*test_names*/, int /*index*/) {
+    return true;
+  }
+};
+
+// TypeParameterizedTestCase<Fixture, Tests, Types>::Register()
+// registers *all combinations* of 'Tests' and 'Types' with Google
+// Test.  The return value is insignificant - we just need to return
+// something such that we can call this function in a namespace scope.
+template <GTEST_TEMPLATE_ Fixture, typename Tests, typename Types>
+class TypeParameterizedTestCase {
+ public:
+  static bool Register(const char* prefix, const char* case_name,
+                       const char* test_names) {
+    typedef typename Tests::Head Head;
+
+    // First, register the first test in 'Test' for each type in 'Types'.
+    TypeParameterizedTest<Fixture, Head, Types>::Register(
+        prefix, case_name, test_names, 0);
+
+    // Next, recurses (at compile time) with the tail of the test list.
+    return TypeParameterizedTestCase<Fixture, typename Tests::Tail, Types>
+        ::Register(prefix, case_name, SkipComma(test_names));
+  }
+};
+
+// The base case for the compile time recursion.
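// Illustrative sketch (not from the original source): the compile-time
// recursion above is what the public type-parameterized test macros expand
// into. A typical use, with invented names, looks like:
//
//   template <typename T>
//   class ContainerTest : public ::testing::Test {};
//
//   TYPED_TEST_CASE_P(ContainerTest);
//
//   TYPED_TEST_P(ContainerTest, DefaultConstructedIsEmpty) {
//     TypeParam container;
//     EXPECT_EQ(0u, container.size());
//   }
//
//   REGISTER_TYPED_TEST_CASE_P(ContainerTest, DefaultConstructedIsEmpty);
//
//   typedef ::testing::Types<std::vector<int>, std::list<int> > VectorAndList;
//   INSTANTIATE_TYPED_TEST_CASE_P(My, ContainerTest, VectorAndList);
//
// INSTANTIATE_TYPED_TEST_CASE_P is the call site that ends up invoking
// TypeParameterizedTestCase<...>::Register() for every test/type combination.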
+template <GTEST_TEMPLATE_ Fixture, typename Types>
+class TypeParameterizedTestCase<Fixture, Templates0, Types> {
+ public:
+  static bool Register(const char* /*prefix*/, const char* /*case_name*/,
+                       const char* /*test_names*/) {
+    return true;
+  }
+};
+
+#endif  // GTEST_HAS_TYPED_TEST || GTEST_HAS_TYPED_TEST_P
+
+// Returns the current OS stack trace as a String.
+//
+// The maximum number of stack frames to be included is specified by
+// the gtest_stack_trace_depth flag.  The skip_count parameter
+// specifies the number of top frames to be skipped, which doesn't
+// count against the number of frames to be included.
+//
+// For example, if Foo() calls Bar(), which in turn calls
+// GetCurrentOsStackTraceExceptTop(..., 1), Foo() will be included in
+// the trace but Bar() and GetCurrentOsStackTraceExceptTop() won't.
+GTEST_API_ String GetCurrentOsStackTraceExceptTop(UnitTest* unit_test,
+                                                  int skip_count);
+
+// Helpers for suppressing warnings on unreachable code or constant
+// condition.
+
+// Always returns true.
+GTEST_API_ bool AlwaysTrue();
+
+// Always returns false.
+inline bool AlwaysFalse() { return !AlwaysTrue(); }
+
+// Helper for suppressing false warning from Clang on a const char*
+// variable declared in a conditional expression always being NULL in
+// the else branch.
+struct GTEST_API_ ConstCharPtr {
+  ConstCharPtr(const char* str) : value(str) {}
+  operator bool() const { return true; }
+  const char* value;
+};
+
+// A simple Linear Congruential Generator for generating random
+// numbers with a uniform distribution.  Unlike rand() and srand(), it
+// doesn't use global state (and therefore can't interfere with user
+// code).  Unlike rand_r(), it's portable.  An LCG isn't very random,
+// but it's good enough for our purposes.
+class GTEST_API_ Random {
+ public:
+  static const UInt32 kMaxRange = 1u << 31;
+
+  explicit Random(UInt32 seed) : state_(seed) {}
+
+  void Reseed(UInt32 seed) { state_ = seed; }
+
+  // Generates a random number from [0, range).  Crashes if 'range' is
+  // 0 or greater than kMaxRange.
+  UInt32 Generate(UInt32 range);
+
+ private:
+  UInt32 state_;
+  GTEST_DISALLOW_COPY_AND_ASSIGN_(Random);
+};
+
+// Defining a variable of type CompileAssertTypesEqual<T1, T2> will cause a
+// compiler error iff T1 and T2 are different types.
+template <typename T1, typename T2>
+struct CompileAssertTypesEqual;
+
+template <typename T>
+struct CompileAssertTypesEqual<T, T> {
+};
+
+// Removes the reference from a type if it is a reference type,
+// otherwise leaves it unchanged.  This is the same as
+// tr1::remove_reference, which is not widely available yet.
+template <typename T>
+struct RemoveReference { typedef T type; };  // NOLINT
+template <typename T>
+struct RemoveReference<T&> { typedef T type; };  // NOLINT
+
+// A handy wrapper around RemoveReference that works when the argument
+// T depends on template parameters.
+#define GTEST_REMOVE_REFERENCE_(T) \
+    typename ::testing::internal::RemoveReference<T>::type
+
+// Removes const from a type if it is a const type, otherwise leaves
+// it unchanged.  This is the same as tr1::remove_const, which is not
+// widely available yet.
+template <typename T>
+struct RemoveConst { typedef T type; };  // NOLINT
+template <typename T>
+struct RemoveConst<const T> { typedef T type; };  // NOLINT
+
+// MSVC 8.0, Sun C++, and IBM XL C++ have a bug which causes the above
+// definition to fail to remove the const in 'const int[3]' and 'const
+// char[3][4]'.  The following specialization works around the bug.
+// However, it causes trouble with GCC and thus needs to be
+// conditionally compiled.
+#if defined(_MSC_VER) || defined(__SUNPRO_CC) || defined(__IBMCPP__)
+template <typename T, size_t N>
+struct RemoveConst<const T[N]> {
+  typedef typename RemoveConst<T>::type type[N];
+};
+#endif
+
+// A handy wrapper around RemoveConst that works when the argument
+// T depends on template parameters.
+#define GTEST_REMOVE_CONST_(T) \
+    typename ::testing::internal::RemoveConst<T>::type
+
+// Turns const U&, U&, const U, and U all into U.
+#define GTEST_REMOVE_REFERENCE_AND_CONST_(T) \
+    GTEST_REMOVE_CONST_(GTEST_REMOVE_REFERENCE_(T))
+
+// Adds reference to a type if it is not a reference type,
+// otherwise leaves it unchanged.  This is the same as
+// tr1::add_reference, which is not widely available yet.
+template <typename T>
+struct AddReference { typedef T& type; };  // NOLINT
+template <typename T>
+struct AddReference<T&> { typedef T& type; };  // NOLINT
+
+// A handy wrapper around AddReference that works when the argument T
+// depends on template parameters.
+#define GTEST_ADD_REFERENCE_(T) \
+    typename ::testing::internal::AddReference<T>::type
+
+// Adds a reference to const on top of T as necessary.  For example,
+// it transforms
+//
+//   char         ==> const char&
+//   const char   ==> const char&
+//   char&        ==> const char&
+//   const char&  ==> const char&
+//
+// The argument T must depend on some template parameters.
+#define GTEST_REFERENCE_TO_CONST_(T) \
+    GTEST_ADD_REFERENCE_(const GTEST_REMOVE_REFERENCE_(T))
+
+// ImplicitlyConvertible<From, To>::value is a compile-time bool
+// constant that's true iff type From can be implicitly converted to
+// type To.
+template <typename From, typename To>
+class ImplicitlyConvertible {
+ private:
+  // We need the following helper functions only for their types.
+  // They have no implementations.
+
+  // MakeFrom() is an expression whose type is From.  We cannot simply
+  // use From(), as the type From may not have a public default
+  // constructor.
+  static From MakeFrom();
+
+  // These two functions are overloaded.  Given an expression
+  // Helper(x), the compiler will pick the first version if x can be
+  // implicitly converted to type To; otherwise it will pick the
+  // second version.
+  //
+  // The first version returns a value of size 1, and the second
+  // version returns a value of size 2.  Therefore, by checking the
+  // size of Helper(x), which can be done at compile time, we can tell
+  // which version of Helper() is used, and hence whether x can be
+  // implicitly converted to type To.
+  static char Helper(To);
+  static char (&Helper(...))[2];  // NOLINT
+
+  // We have to put the 'public' section after the 'private' section,
+  // or MSVC refuses to compile the code.
+ public:
+  // MSVC warns about implicitly converting from double to int for
+  // possible loss of data, so we need to temporarily disable the
+  // warning.
+#ifdef _MSC_VER
+# pragma warning(push)          // Saves the current warning state.
+# pragma warning(disable:4244)  // Temporarily disables warning 4244.
+
+  static const bool value =
+      sizeof(Helper(ImplicitlyConvertible::MakeFrom())) == 1;
+# pragma warning(pop)           // Restores the warning state.
+#elif defined(__BORLANDC__)
+  // C++Builder cannot use member overload resolution during template
+  // instantiation.  The simplest workaround is to use its C++0x type traits
+  // functions (C++Builder 2009 and above only).
+  static const bool value = __is_convertible(From, To);
+#else
+  static const bool value =
+      sizeof(Helper(ImplicitlyConvertible::MakeFrom())) == 1;
+#endif  // _MSV_VER
+};
+template <typename From, typename To>
+const bool ImplicitlyConvertible<From, To>::value;
+
+// IsAProtocolMessage<T>::value is a compile-time bool constant that's
+// true iff T is type ProtocolMessage, proto2::Message, or a subclass
+// of those.
+template <typename T>
+struct IsAProtocolMessage
+    : public bool_constant<
+  ImplicitlyConvertible<const T*, const ::ProtocolMessage*>::value ||
+  ImplicitlyConvertible<const T*, const ::proto2::Message*>::value> {
+};
+
+// When the compiler sees expression IsContainerTest<C>(0), if C is an
+// STL-style container class, the first overload of IsContainerTest
+// will be viable (since both C::iterator* and C::const_iterator* are
+// valid types and NULL can be implicitly converted to them).  It will
+// be picked over the second overload as 'int' is a perfect match for
+// the type of argument 0.  If C::iterator or C::const_iterator is not
+// a valid type, the first overload is not viable, and the second
+// overload will be picked.  Therefore, we can determine whether C is
+// a container class by checking the type of IsContainerTest<C>(0).
+// The value of the expression is insignificant.
+//
+// Note that we look for both C::iterator and C::const_iterator.  The
+// reason is that C++ injects the name of a class as a member of the
+// class itself (e.g. you can refer to class iterator as either
+// 'iterator' or 'iterator::iterator').  If we look for C::iterator
+// only, for example, we would mistakenly think that a class named
+// iterator is an STL container.
+//
+// Also note that the simpler approach of overloading
+// IsContainerTest(typename C::const_iterator*) and
+// IsContainerTest(...) doesn't work with Visual Age C++ and Sun C++.
+typedef int IsContainer;
+template <class C>
+IsContainer IsContainerTest(int /* dummy */,
+                            typename C::iterator* /* it */ = NULL,
+                            typename C::const_iterator* /* const_it */ = NULL) {
+  return 0;
+}
+
+typedef char IsNotContainer;
+template <class C>
+IsNotContainer IsContainerTest(long /* dummy */) { return '\0'; }
+
+// EnableIf<condition>::type is void when 'Cond' is true, and
+// undefined when 'Cond' is false.  To use SFINAE to make a function
+// overload only apply when a particular expression is true, add
+// "typename EnableIf<expression>::type* = 0" as the last parameter.
+template<bool> struct EnableIf;
+template<> struct EnableIf<true> { typedef void type; };  // NOLINT
+
+// Utilities for native arrays.
+
+// ArrayEq() compares two k-dimensional native arrays using the
+// elements' operator==, where k can be any integer >= 0.  When k is
+// 0, ArrayEq() degenerates into comparing a single pair of values.
+
+template <typename T, typename U>
+bool ArrayEq(const T* lhs, size_t size, const U* rhs);
+
+// This generic version is used when k is 0.
+template <typename T, typename U>
+inline bool ArrayEq(const T& lhs, const U& rhs) { return lhs == rhs; }
+
+// This overload is used when k >= 1.
+template <typename T, typename U, size_t N>
+inline bool ArrayEq(const T(&lhs)[N], const U(&rhs)[N]) {
+  return internal::ArrayEq(lhs, N, rhs);
+}
+
+// This helper reduces code bloat.  If we instead put its logic inside
+// the previous ArrayEq() function, arrays with different sizes would
+// lead to different copies of the template code.
+template <typename T, typename U>
+bool ArrayEq(const T* lhs, size_t size, const U* rhs) {
+  for (size_t i = 0; i != size; i++) {
+    if (!internal::ArrayEq(lhs[i], rhs[i]))
+      return false;
+  }
+  return true;
+}
+
+// Finds the first element in the iterator range [begin, end) that
+// equals elem.  Element may be a native array type itself.
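// Illustrative aside (not from the original source): how the traits defined in
// the last two sections behave. Each line is a compile-time fact;
// CompileAssertTypesEqual<T, T> (defined earlier) instantiates only when the
// two types are identical.
//
//   using namespace testing::internal;
//   CompileAssertTypesEqual<int, RemoveReference<int&>::type>();
//   CompileAssertTypesEqual<int, RemoveConst<const int>::type>();
//   CompileAssertTypesEqual<int&, AddReference<int>::type>();
//   const bool float_to_double = ImplicitlyConvertible<float, double>::value;  // true
//   const bool intp_to_charp   = ImplicitlyConvertible<int*, char*>::value;    // false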
+template <typename Iter, typename Element>
+Iter ArrayAwareFind(Iter begin, Iter end, const Element& elem) {
+  for (Iter it = begin; it != end; ++it) {
+    if (internal::ArrayEq(*it, elem))
+      return it;
+  }
+  return end;
+}
+
+// CopyArray() copies a k-dimensional native array using the elements'
+// operator=, where k can be any integer >= 0.  When k is 0,
+// CopyArray() degenerates into copying a single value.
+
+template <typename T, typename U>
+void CopyArray(const T* from, size_t size, U* to);
+
+// This generic version is used when k is 0.
+template <typename T, typename U>
+inline void CopyArray(const T& from, U* to) { *to = from; }
+
+// This overload is used when k >= 1.
+template <typename T, typename U, size_t N>
+inline void CopyArray(const T(&from)[N], U(*to)[N]) {
+  internal::CopyArray(from, N, *to);
+}
+
+// This helper reduces code bloat.  If we instead put its logic inside
+// the previous CopyArray() function, arrays with different sizes
+// would lead to different copies of the template code.
+template <typename T, typename U>
+void CopyArray(const T* from, size_t size, U* to) {
+  for (size_t i = 0; i != size; i++) {
+    internal::CopyArray(from[i], to + i);
+  }
+}
+
+// The relation between an NativeArray object (see below) and the
+// native array it represents.
+enum RelationToSource {
+  kReference,  // The NativeArray references the native array.
+  kCopy        // The NativeArray makes a copy of the native array and
+               // owns the copy.
+};
+
+// Adapts a native array to a read-only STL-style container.  Instead
+// of the complete STL container concept, this adaptor only implements
+// members useful for Google Mock's container matchers.  New members
+// should be added as needed.  To simplify the implementation, we only
+// support Element being a raw type (i.e. having no top-level const or
+// reference modifier).  It's the client's responsibility to satisfy
+// this requirement.  Element can be an array type itself (hence
+// multi-dimensional arrays are supported).
+template <typename Element>
+class NativeArray {
+ public:
+  // STL-style container typedefs.
+  typedef Element value_type;
+  typedef Element* iterator;
+  typedef const Element* const_iterator;
+
+  // Constructs from a native array.
+  NativeArray(const Element* array, size_t count, RelationToSource relation) {
+    Init(array, count, relation);
+  }
+
+  // Copy constructor.
+  NativeArray(const NativeArray& rhs) {
+    Init(rhs.array_, rhs.size_, rhs.relation_to_source_);
+  }
+
+  ~NativeArray() {
+    // Ensures that the user doesn't instantiate NativeArray with a
+    // const or reference type.
+    static_cast<void>(StaticAssertTypeEqHelper<Element,
+        GTEST_REMOVE_REFERENCE_AND_CONST_(Element)>());
+    if (relation_to_source_ == kCopy)
+      delete[] array_;
+  }
+
+  // STL-style container methods.
+  size_t size() const { return size_; }
+  const_iterator begin() const { return array_; }
+  const_iterator end() const { return array_ + size_; }
+  bool operator==(const NativeArray& rhs) const {
+    return size() == rhs.size() &&
+        ArrayEq(begin(), size(), rhs.begin());
+  }
+
+ private:
+  // Initializes this object; makes a copy of the input array if
+  // 'relation' is kCopy.
+ void Init(const Element* array, size_t a_size, RelationToSource relation) { + if (relation == kReference) { + array_ = array; + } else { + Element* const copy = new Element[a_size]; + CopyArray(array, a_size, copy); + array_ = copy; + } + size_ = a_size; + relation_to_source_ = relation; + } + + const Element* array_; + size_t size_; + RelationToSource relation_to_source_; + + GTEST_DISALLOW_ASSIGN_(NativeArray); +}; + +} // namespace internal +} // namespace testing + +#define GTEST_MESSAGE_AT_(file, line, message, result_type) \ + ::testing::internal::AssertHelper(result_type, file, line, message) \ + = ::testing::Message() + +#define GTEST_MESSAGE_(message, result_type) \ + GTEST_MESSAGE_AT_(__FILE__, __LINE__, message, result_type) + +#define GTEST_FATAL_FAILURE_(message) \ + return GTEST_MESSAGE_(message, ::testing::TestPartResult::kFatalFailure) + +#define GTEST_NONFATAL_FAILURE_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kNonFatalFailure) + +#define GTEST_SUCCESS_(message) \ + GTEST_MESSAGE_(message, ::testing::TestPartResult::kSuccess) + +// Suppresses MSVC warnings 4072 (unreachable code) for the code following +// statement if it returns or throws (or doesn't return or throw in some +// situations). +#define GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) \ + if (::testing::internal::AlwaysTrue()) { statement; } + +#define GTEST_TEST_THROW_(statement, expected_exception, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::ConstCharPtr gtest_msg = "") { \ + bool gtest_caught_expected = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (expected_exception const&) { \ + gtest_caught_expected = true; \ + } \ + catch (...) { \ + gtest_msg.value = \ + "Expected: " #statement " throws an exception of type " \ + #expected_exception ".\n Actual: it throws a different type."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + if (!gtest_caught_expected) { \ + gtest_msg.value = \ + "Expected: " #statement " throws an exception of type " \ + #expected_exception ".\n Actual: it throws nothing."; \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testthrow_, __LINE__): \ + fail(gtest_msg.value) + +#define GTEST_TEST_NO_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (...) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnothrow_, __LINE__): \ + fail("Expected: " #statement " doesn't throw an exception.\n" \ + " Actual: it throws.") + +#define GTEST_TEST_ANY_THROW_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + bool gtest_caught_any = false; \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } \ + catch (...) { \ + gtest_caught_any = true; \ + } \ + if (!gtest_caught_any) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testanythrow_, __LINE__): \ + fail("Expected: " #statement " throws an exception.\n" \ + " Actual: it doesn't.") + + +// Implements Boolean test assertions such as EXPECT_TRUE. expression can be +// either a boolean expression or an AssertionResult. 
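// Illustrative aside (not from the original source): the GTEST_TEST_THROW_,
// GTEST_TEST_NO_THROW_ and GTEST_TEST_ANY_THROW_ macros above back the public
// exception assertions, which read as:
//
//   TEST(ExceptionTest, ReportsThrowBehaviour) {
//     EXPECT_THROW(throw std::runtime_error("boom"), std::runtime_error);
//     EXPECT_NO_THROW(std::string("safe"));
//     EXPECT_ANY_THROW(throw 42);
//   }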
text is a textual +// represenation of expression as it was passed into the EXPECT_TRUE. +#define GTEST_TEST_BOOLEAN_(expression, text, actual, expected, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar_ = \ + ::testing::AssertionResult(expression)) \ + ; \ + else \ + fail(::testing::internal::GetBoolAssertionFailureMessage(\ + gtest_ar_, text, #actual, #expected).c_str()) + +#define GTEST_TEST_NO_FATAL_FAILURE_(statement, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + ::testing::internal::HasNewFatalFailureHelper gtest_fatal_failure_checker; \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + if (gtest_fatal_failure_checker.has_new_fatal_failure()) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__); \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_testnofatal_, __LINE__): \ + fail("Expected: " #statement " doesn't generate new fatal " \ + "failures in the current thread.\n" \ + " Actual: it does.") + +// Expands to the name of the class that implements the given test. +#define GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \ + test_case_name##_##test_name##_Test + +// Helper macro for defining tests. +#define GTEST_TEST_(test_case_name, test_name, parent_class, parent_id)\ +class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) : public parent_class {\ + public:\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {}\ + private:\ + virtual void TestBody();\ + static ::testing::TestInfo* const test_info_ GTEST_ATTRIBUTE_UNUSED_;\ + GTEST_DISALLOW_COPY_AND_ASSIGN_(\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name));\ +};\ +\ +::testing::TestInfo* const GTEST_TEST_CLASS_NAME_(test_case_name, test_name)\ + ::test_info_ =\ + ::testing::internal::MakeAndRegisterTestInfo(\ + #test_case_name, #test_name, NULL, NULL, \ + (parent_id), \ + parent_class::SetUpTestCase, \ + parent_class::TearDownTestCase, \ + new ::testing::internal::TestFactoryImpl<\ + GTEST_TEST_CLASS_NAME_(test_case_name, test_name)>);\ +void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody() + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_INTERNAL_H_ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the public API for death tests. It is +// #included by gtest.h so a user doesn't need to include this +// directly. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ + +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: wan@google.com (Zhanyong Wan), eefacm@gmail.com (Sean Mcafee) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines internal utilities needed for implementing +// death tests. They are subject to change without notice. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ + + +#include + +namespace testing { +namespace internal { + +GTEST_DECLARE_string_(internal_run_death_test); + +// Names of the flags (needed for parsing Google Test flags). +const char kDeathTestStyleFlag[] = "death_test_style"; +const char kDeathTestUseFork[] = "death_test_use_fork"; +const char kInternalRunDeathTestFlag[] = "internal_run_death_test"; + +#if GTEST_HAS_DEATH_TEST + +// DeathTest is a class that hides much of the complexity of the +// GTEST_DEATH_TEST_ macro. 
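// Illustrative sketch (not from the original source): the assertions built on
// top of DeathTest are used like ordinary test macros. The function and
// message below are invented; gtest only requires that the statement
// terminates the process and that its stderr output matches the given regular
// expression.
//
//   void CheckPositive(int n) {
//     if (n <= 0) {
//       fprintf(stderr, "value must be positive\n");
//       abort();
//     }
//   }
//
//   TEST(CheckPositiveDeathTest, DiesOnNegativeInput) {
//     EXPECT_DEATH(CheckPositive(-1), "value must be positive");
//   }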
It is abstract; its static Create method +// returns a concrete class that depends on the prevailing death test +// style, as defined by the --gtest_death_test_style and/or +// --gtest_internal_run_death_test flags. + +// In describing the results of death tests, these terms are used with +// the corresponding definitions: +// +// exit status: The integer exit information in the format specified +// by wait(2) +// exit code: The integer code passed to exit(3), _exit(2), or +// returned from main() +class GTEST_API_ DeathTest { + public: + // Create returns false if there was an error determining the + // appropriate action to take for the current death test; for example, + // if the gtest_death_test_style flag is set to an invalid value. + // The LastMessage method will return a more detailed message in that + // case. Otherwise, the DeathTest pointer pointed to by the "test" + // argument is set. If the death test should be skipped, the pointer + // is set to NULL; otherwise, it is set to the address of a new concrete + // DeathTest object that controls the execution of the current test. + static bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test); + DeathTest(); + virtual ~DeathTest() { } + + // A helper class that aborts a death test when it's deleted. + class ReturnSentinel { + public: + explicit ReturnSentinel(DeathTest* test) : test_(test) { } + ~ReturnSentinel() { test_->Abort(TEST_ENCOUNTERED_RETURN_STATEMENT); } + private: + DeathTest* const test_; + GTEST_DISALLOW_COPY_AND_ASSIGN_(ReturnSentinel); + } GTEST_ATTRIBUTE_UNUSED_; + + // An enumeration of possible roles that may be taken when a death + // test is encountered. EXECUTE means that the death test logic should + // be executed immediately. OVERSEE means that the program should prepare + // the appropriate environment for a child process to execute the death + // test, then wait for it to complete. + enum TestRole { OVERSEE_TEST, EXECUTE_TEST }; + + // An enumeration of the three reasons that a test might be aborted. + enum AbortReason { + TEST_ENCOUNTERED_RETURN_STATEMENT, + TEST_THREW_EXCEPTION, + TEST_DID_NOT_DIE + }; + + // Assumes one of the above roles. + virtual TestRole AssumeRole() = 0; + + // Waits for the death test to finish and returns its status. + virtual int Wait() = 0; + + // Returns true if the death test passed; that is, the test process + // exited during the test, its exit status matches a user-supplied + // predicate, and its stderr output matches a user-supplied regular + // expression. + // The user-supplied predicate may be a macro expression rather + // than a function pointer or functor, or else Wait and Passed could + // be combined. + virtual bool Passed(bool exit_status_ok) = 0; + + // Signals that the death test did not die as expected. + virtual void Abort(AbortReason reason) = 0; + + // Returns a human-readable outcome message regarding the outcome of + // the last death test. + static const char* LastMessage(); + + static void set_last_death_test_message(const String& message); + + private: + // A string containing a description of the outcome of the last death test. + static String last_death_test_message_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(DeathTest); +}; + +// Factory interface for death tests. May be mocked out for testing. 
+class DeathTestFactory { + public: + virtual ~DeathTestFactory() { } + virtual bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test) = 0; +}; + +// A concrete DeathTestFactory implementation for normal use. +class DefaultDeathTestFactory : public DeathTestFactory { + public: + virtual bool Create(const char* statement, const RE* regex, + const char* file, int line, DeathTest** test); +}; + +// Returns true if exit_status describes a process that was terminated +// by a signal, or exited normally with a nonzero exit code. +GTEST_API_ bool ExitedUnsuccessfully(int exit_status); + +// Traps C++ exceptions escaping statement and reports them as test +// failures. Note that trapping SEH exceptions is not implemented here. +# if GTEST_HAS_EXCEPTIONS +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + try { \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + } catch (const ::std::exception& gtest_exception) { \ + fprintf(\ + stderr, \ + "\n%s: Caught std::exception-derived exception escaping the " \ + "death test statement. Exception message: %s\n", \ + ::testing::internal::FormatFileLocation(__FILE__, __LINE__).c_str(), \ + gtest_exception.what()); \ + fflush(stderr); \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } catch (...) { \ + death_test->Abort(::testing::internal::DeathTest::TEST_THREW_EXCEPTION); \ + } + +# else +# define GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, death_test) \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement) + +# endif + +// This macro is for implementing ASSERT_DEATH*, EXPECT_DEATH*, +// ASSERT_EXIT*, and EXPECT_EXIT*. +# define GTEST_DEATH_TEST_(statement, predicate, regex, fail) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + const ::testing::internal::RE& gtest_regex = (regex); \ + ::testing::internal::DeathTest* gtest_dt; \ + if (!::testing::internal::DeathTest::Create(#statement, >est_regex, \ + __FILE__, __LINE__, >est_dt)) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + if (gtest_dt != NULL) { \ + ::testing::internal::scoped_ptr< ::testing::internal::DeathTest> \ + gtest_dt_ptr(gtest_dt); \ + switch (gtest_dt->AssumeRole()) { \ + case ::testing::internal::DeathTest::OVERSEE_TEST: \ + if (!gtest_dt->Passed(predicate(gtest_dt->Wait()))) { \ + goto GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__); \ + } \ + break; \ + case ::testing::internal::DeathTest::EXECUTE_TEST: { \ + ::testing::internal::DeathTest::ReturnSentinel \ + gtest_sentinel(gtest_dt); \ + GTEST_EXECUTE_DEATH_TEST_STATEMENT_(statement, gtest_dt); \ + gtest_dt->Abort(::testing::internal::DeathTest::TEST_DID_NOT_DIE); \ + break; \ + } \ + default: \ + break; \ + } \ + } \ + } else \ + GTEST_CONCAT_TOKEN_(gtest_label_, __LINE__): \ + fail(::testing::internal::DeathTest::LastMessage()) +// The symbol "fail" here expands to something into which a message +// can be streamed. + +// A class representing the parsed contents of the +// --gtest_internal_run_death_test flag, as it existed when +// RUN_ALL_TESTS was called. 
+class InternalRunDeathTestFlag { + public: + InternalRunDeathTestFlag(const String& a_file, + int a_line, + int an_index, + int a_write_fd) + : file_(a_file), line_(a_line), index_(an_index), + write_fd_(a_write_fd) {} + + ~InternalRunDeathTestFlag() { + if (write_fd_ >= 0) + posix::Close(write_fd_); + } + + String file() const { return file_; } + int line() const { return line_; } + int index() const { return index_; } + int write_fd() const { return write_fd_; } + + private: + String file_; + int line_; + int index_; + int write_fd_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(InternalRunDeathTestFlag); +}; + +// Returns a newly created InternalRunDeathTestFlag object with fields +// initialized from the GTEST_FLAG(internal_run_death_test) flag if +// the flag is specified; otherwise returns NULL. +InternalRunDeathTestFlag* ParseInternalRunDeathTestFlag(); + +#else // GTEST_HAS_DEATH_TEST + +// This macro is used for implementing macros such as +// EXPECT_DEATH_IF_SUPPORTED and ASSERT_DEATH_IF_SUPPORTED on systems where +// death tests are not supported. Those macros must compile on such systems +// iff EXPECT_DEATH and ASSERT_DEATH compile with the same parameters on +// systems that support death tests. This allows one to write such a macro +// on a system that does not support death tests and be sure that it will +// compile on a death-test supporting system. +// +// Parameters: +// statement - A statement that a macro such as EXPECT_DEATH would test +// for program termination. This macro has to make sure this +// statement is compiled but not executed, to ensure that +// EXPECT_DEATH_IF_SUPPORTED compiles with a certain +// parameter iff EXPECT_DEATH compiles with it. +// regex - A regex that a macro such as EXPECT_DEATH would use to test +// the output of statement. This parameter has to be +// compiled but not evaluated by this macro, to ensure that +// this macro only accepts expressions that a macro such as +// EXPECT_DEATH would accept. +// terminator - Must be an empty statement for EXPECT_DEATH_IF_SUPPORTED +// and a return statement for ASSERT_DEATH_IF_SUPPORTED. +// This ensures that ASSERT_DEATH_IF_SUPPORTED will not +// compile inside functions where ASSERT_DEATH doesn't +// compile. +// +// The branch that has an always false condition is used to ensure that +// statement and regex are compiled (and thus syntactically correct) but +// never executed. The unreachable code macro protects the terminator +// statement from generating an 'unreachable code' warning in case +// statement unconditionally returns or throws. The Message constructor at +// the end allows the syntax of streaming additional messages into the +// macro, for compilational compatibility with EXPECT_DEATH/ASSERT_DEATH. +# define GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, terminator) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (::testing::internal::AlwaysTrue()) { \ + GTEST_LOG_(WARNING) \ + << "Death tests are not supported on this platform.\n" \ + << "Statement '" #statement "' cannot be verified."; \ + } else if (::testing::internal::AlwaysFalse()) { \ + ::testing::internal::RE::PartialMatch(".*", (regex)); \ + GTEST_SUPPRESS_UNREACHABLE_CODE_WARNING_BELOW_(statement); \ + terminator; \ + } else \ + ::testing::Message() + +#endif // GTEST_HAS_DEATH_TEST + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_DEATH_TEST_INTERNAL_H_ + +namespace testing { + +// This flag controls the style of death tests. 
Valid values are "threadsafe", +// meaning that the death test child process will re-execute the test binary +// from the start, running only a single death test, or "fast", +// meaning that the child process will execute the test logic immediately +// after forking. +GTEST_DECLARE_string_(death_test_style); + +#if GTEST_HAS_DEATH_TEST + +// The following macros are useful for writing death tests. + +// Here's what happens when an ASSERT_DEATH* or EXPECT_DEATH* is +// executed: +// +// 1. It generates a warning if there is more than one active +// thread. This is because it's safe to fork() or clone() only +// when there is a single thread. +// +// 2. The parent process clone()s a sub-process and runs the death +// test in it; the sub-process exits with code 0 at the end of the +// death test, if it hasn't exited already. +// +// 3. The parent process waits for the sub-process to terminate. +// +// 4. The parent process checks the exit code and error message of +// the sub-process. +// +// Examples: +// +// ASSERT_DEATH(server.SendMessage(56, "Hello"), "Invalid port number"); +// for (int i = 0; i < 5; i++) { +// EXPECT_DEATH(server.ProcessRequest(i), +// "Invalid request .* in ProcessRequest()") +// << "Failed to die on request " << i); +// } +// +// ASSERT_EXIT(server.ExitNow(), ::testing::ExitedWithCode(0), "Exiting"); +// +// bool KilledBySIGHUP(int exit_code) { +// return WIFSIGNALED(exit_code) && WTERMSIG(exit_code) == SIGHUP; +// } +// +// ASSERT_EXIT(client.HangUpServer(), KilledBySIGHUP, "Hanging up!"); +// +// On the regular expressions used in death tests: +// +// On POSIX-compliant systems (*nix), we use the library, +// which uses the POSIX extended regex syntax. +// +// On other platforms (e.g. Windows), we only support a simple regex +// syntax implemented as part of Google Test. This limited +// implementation should be enough most of the time when writing +// death tests; though it lacks many features you can find in PCRE +// or POSIX extended regex syntax. For example, we don't support +// union ("x|y"), grouping ("(xy)"), brackets ("[xy]"), and +// repetition count ("x{5,7}"), among others. +// +// Below is the syntax that we do support. We chose it to be a +// subset of both PCRE and POSIX extended regex, so it's easy to +// learn wherever you come from. In the following: 'A' denotes a +// literal character, period (.), or a single \\ escape sequence; +// 'x' and 'y' denote regular expressions; 'm' and 'n' are for +// natural numbers. +// +// c matches any literal character c +// \\d matches any decimal digit +// \\D matches any character that's not a decimal digit +// \\f matches \f +// \\n matches \n +// \\r matches \r +// \\s matches any ASCII whitespace, including \n +// \\S matches any character that's not a whitespace +// \\t matches \t +// \\v matches \v +// \\w matches any letter, _, or decimal digit +// \\W matches any character that \\w doesn't match +// \\c matches any literal character c, which must be a punctuation +// . matches any single character except \n +// A? matches 0 or 1 occurrences of A +// A* matches 0 or many occurrences of A +// A+ matches 1 or many occurrences of A +// ^ matches the beginning of a string (not that of each line) +// $ matches the end of a string (not that of each line) +// xy matches x followed by y +// +// If you accidentally use PCRE or POSIX extended regex features +// not implemented by us, you will get a run-time failure. In that +// case, please try to rewrite your regular expression within the +// above syntax. 
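// Illustrative aside (not from the original source): a regex passed to a death
// assertion stays within the subset documented above, and ASSERT_EXIT can
// check the exact exit status via ExitedWithCode. The function name is
// invented; FLAGS_gtest_death_test_style corresponds to the death_test_style
// flag declared earlier in this header.
//
//   TEST(ShutdownDeathTest, ExitsCleanlyWithMessage) {
//     ::testing::FLAGS_gtest_death_test_style = "threadsafe";
//     ASSERT_EXIT(ShutdownAndExit(0),              // made-up function
//                 ::testing::ExitedWithCode(0),
//                 "shutting down\\.$");            // '\\.' and '$' are in the subset above
//   }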
+// +// This implementation is *not* meant to be as highly tuned or robust +// as a compiled regex library, but should perform well enough for a +// death test, which already incurs significant overhead by launching +// a child process. +// +// Known caveats: +// +// A "threadsafe" style death test obtains the path to the test +// program from argv[0] and re-executes it in the sub-process. For +// simplicity, the current implementation doesn't search the PATH +// when launching the sub-process. This means that the user must +// invoke the test program via a path that contains at least one +// path separator (e.g. path/to/foo_test and +// /absolute/path/to/bar_test are fine, but foo_test is not). This +// is rarely a problem as people usually don't put the test binary +// directory in PATH. +// +// TODO(wan@google.com): make thread-safe death tests search the PATH. + +// Asserts that a given statement causes the program to exit, with an +// integer exit status that satisfies predicate, and emitting error output +// that matches regex. +# define ASSERT_EXIT(statement, predicate, regex) \ + GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_FATAL_FAILURE_) + +// Like ASSERT_EXIT, but continues on to successive tests in the +// test case, if any: +# define EXPECT_EXIT(statement, predicate, regex) \ + GTEST_DEATH_TEST_(statement, predicate, regex, GTEST_NONFATAL_FAILURE_) + +// Asserts that a given statement causes the program to exit, either by +// explicitly exiting with a nonzero exit code or being killed by a +// signal, and emitting error output that matches regex. +# define ASSERT_DEATH(statement, regex) \ + ASSERT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) + +// Like ASSERT_DEATH, but continues on to successive tests in the +// test case, if any: +# define EXPECT_DEATH(statement, regex) \ + EXPECT_EXIT(statement, ::testing::internal::ExitedUnsuccessfully, regex) + +// Two predicate classes that can be used in {ASSERT,EXPECT}_EXIT*: + +// Tests that an exit code describes a normal exit with a given exit code. +class GTEST_API_ ExitedWithCode { + public: + explicit ExitedWithCode(int exit_code); + bool operator()(int exit_status) const; + private: + // No implementation - assignment is unsupported. + void operator=(const ExitedWithCode& other); + + const int exit_code_; +}; + +# if !GTEST_OS_WINDOWS +// Tests that an exit code describes an exit due to termination by a +// given signal. +class GTEST_API_ KilledBySignal { + public: + explicit KilledBySignal(int signum); + bool operator()(int exit_status) const; + private: + const int signum_; +}; +# endif // !GTEST_OS_WINDOWS + +// EXPECT_DEBUG_DEATH asserts that the given statements die in debug mode. +// The death testing framework causes this to have interesting semantics, +// since the sideeffects of the call are only visible in opt mode, and not +// in debug mode. +// +// In practice, this can be used to test functions that utilize the +// LOG(DFATAL) macro using the following style: +// +// int DieInDebugOr12(int* sideeffect) { +// if (sideeffect) { +// *sideeffect = 12; +// } +// LOG(DFATAL) << "death"; +// return 12; +// } +// +// TEST(TestCase, TestDieOr12WorksInDgbAndOpt) { +// int sideeffect = 0; +// // Only asserts in dbg. +// EXPECT_DEBUG_DEATH(DieInDebugOr12(&sideeffect), "death"); +// +// #ifdef NDEBUG +// // opt-mode has sideeffect visible. +// EXPECT_EQ(12, sideeffect); +// #else +// // dbg-mode no visible sideeffect. 
+// EXPECT_EQ(0, sideeffect); +// #endif +// } +// +// This will assert that DieInDebugReturn12InOpt() crashes in debug +// mode, usually due to a DCHECK or LOG(DFATAL), but returns the +// appropriate fallback value (12 in this case) in opt mode. If you +// need to test that a function has appropriate side-effects in opt +// mode, include assertions against the side-effects. A general +// pattern for this is: +// +// EXPECT_DEBUG_DEATH({ +// // Side-effects here will have an effect after this statement in +// // opt mode, but none in debug mode. +// EXPECT_EQ(12, DieInDebugOr12(&sideeffect)); +// }, "death"); +// +# ifdef NDEBUG + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + do { statement; } while (::testing::internal::AlwaysFalse()) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + do { statement; } while (::testing::internal::AlwaysFalse()) + +# else + +# define EXPECT_DEBUG_DEATH(statement, regex) \ + EXPECT_DEATH(statement, regex) + +# define ASSERT_DEBUG_DEATH(statement, regex) \ + ASSERT_DEATH(statement, regex) + +# endif // NDEBUG for EXPECT_DEBUG_DEATH +#endif // GTEST_HAS_DEATH_TEST + +// EXPECT_DEATH_IF_SUPPORTED(statement, regex) and +// ASSERT_DEATH_IF_SUPPORTED(statement, regex) expand to real death tests if +// death tests are supported; otherwise they just issue a warning. This is +// useful when you are combining death test assertions with normal test +// assertions in one test. +#if GTEST_HAS_DEATH_TEST +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + EXPECT_DEATH(statement, regex) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + ASSERT_DEATH(statement, regex) +#else +# define EXPECT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, ) +# define ASSERT_DEATH_IF_SUPPORTED(statement, regex) \ + GTEST_UNSUPPORTED_DEATH_TEST_(statement, regex, return) +#endif + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_DEATH_TEST_H_ +// Copyright 2005, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
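//
// An illustrative usage sketch for the death-test macros defined above
// (daemon and Normalize() are hypothetical):
//
//   TEST(DaemonTest, ShutsDownCleanly) {
//     ASSERT_EXIT(daemon.Shutdown(), ::testing::ExitedWithCode(0), "bye");
//   }
//
// # if !GTEST_OS_WINDOWS
//   TEST(DaemonTest, DiesOnSigterm) {
//     EXPECT_EXIT(daemon.Run(), ::testing::KilledBySignal(SIGTERM), "");
//   }
// # endif
//
//   TEST(NormalizeTest, RejectsNull) {
//     // Degrades to a warning where death tests are unsupported.
//     EXPECT_DEATH_IF_SUPPORTED(Normalize(NULL), "null input");
//   }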
+// +// Author: wan@google.com (Zhanyong Wan) +// +// The Google C++ Testing Framework (Google Test) +// +// This header file defines the Message class. +// +// IMPORTANT NOTE: Due to limitation of the C++ language, we have to +// leave some internal implementation details in this header file. +// They are clearly marked by comments like this: +// +// // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +// +// Such code is NOT meant to be used by a user directly, and is subject +// to CHANGE WITHOUT NOTICE. Therefore DO NOT DEPEND ON IT in a user +// program! + +#ifndef GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ +#define GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ + +#include + + +namespace testing { + +// The Message class works like an ostream repeater. +// +// Typical usage: +// +// 1. You stream a bunch of values to a Message object. +// It will remember the text in a stringstream. +// 2. Then you stream the Message object to an ostream. +// This causes the text in the Message to be streamed +// to the ostream. +// +// For example; +// +// testing::Message foo; +// foo << 1 << " != " << 2; +// std::cout << foo; +// +// will print "1 != 2". +// +// Message is not intended to be inherited from. In particular, its +// destructor is not virtual. +// +// Note that stringstream behaves differently in gcc and in MSVC. You +// can stream a NULL char pointer to it in the former, but not in the +// latter (it causes an access violation if you do). The Message +// class hides this difference by treating a NULL char pointer as +// "(null)". +class GTEST_API_ Message { + private: + // The type of basic IO manipulators (endl, ends, and flush) for + // narrow streams. + typedef std::ostream& (*BasicNarrowIoManip)(std::ostream&); + + public: + // Constructs an empty Message. + // We allocate the stringstream separately because otherwise each use of + // ASSERT/EXPECT in a procedure adds over 200 bytes to the procedure's + // stack frame leading to huge stack frames in some cases; gcc does not reuse + // the stack space. + Message() : ss_(new ::std::stringstream) { + // By default, we want there to be enough precision when printing + // a double to a Message. + *ss_ << std::setprecision(std::numeric_limits::digits10 + 2); + } + + // Copy constructor. + Message(const Message& msg) : ss_(new ::std::stringstream) { // NOLINT + *ss_ << msg.GetString(); + } + + // Constructs a Message from a C-string. + explicit Message(const char* str) : ss_(new ::std::stringstream) { + *ss_ << str; + } + +#if GTEST_OS_SYMBIAN + // Streams a value (either a pointer or not) to this object. + template + inline Message& operator <<(const T& value) { + StreamHelper(typename internal::is_pointer::type(), value); + return *this; + } +#else + // Streams a non-pointer value to this object. + template + inline Message& operator <<(const T& val) { + ::GTestStreamToHelper(ss_.get(), val); + return *this; + } + + // Streams a pointer value to this object. + // + // This function is an overload of the previous one. When you + // stream a pointer to a Message, this definition will be used as it + // is more specialized. (The C++ Standard, section + // [temp.func.order].) If you stream a non-pointer, then the + // previous definition will be used. + // + // The reason for this overload is that streaming a NULL pointer to + // ostream is undefined behavior. Depending on the compiler, you + // may get "0", "(nil)", "(null)", or an access violation. To + // ensure consistent result across compilers, we always treat NULL + // as "(null)". 
+ template + inline Message& operator <<(T* const& pointer) { // NOLINT + if (pointer == NULL) { + *ss_ << "(null)"; + } else { + ::GTestStreamToHelper(ss_.get(), pointer); + } + return *this; + } +#endif // GTEST_OS_SYMBIAN + + // Since the basic IO manipulators are overloaded for both narrow + // and wide streams, we have to provide this specialized definition + // of operator <<, even though its body is the same as the + // templatized version above. Without this definition, streaming + // endl or other basic IO manipulators to Message will confuse the + // compiler. + Message& operator <<(BasicNarrowIoManip val) { + *ss_ << val; + return *this; + } + + // Instead of 1/0, we want to see true/false for bool values. + Message& operator <<(bool b) { + return *this << (b ? "true" : "false"); + } + + // These two overloads allow streaming a wide C string to a Message + // using the UTF-8 encoding. + Message& operator <<(const wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); + } + Message& operator <<(wchar_t* wide_c_str) { + return *this << internal::String::ShowWideCString(wide_c_str); + } + +#if GTEST_HAS_STD_WSTRING + // Converts the given wide string to a narrow string using the UTF-8 + // encoding, and streams the result to this Message object. + Message& operator <<(const ::std::wstring& wstr); +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_GLOBAL_WSTRING + // Converts the given wide string to a narrow string using the UTF-8 + // encoding, and streams the result to this Message object. + Message& operator <<(const ::wstring& wstr); +#endif // GTEST_HAS_GLOBAL_WSTRING + + // Gets the text streamed to this object so far as a String. + // Each '\0' character in the buffer is replaced with "\\0". + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + internal::String GetString() const { + return internal::StringStreamToString(ss_.get()); + } + + private: + +#if GTEST_OS_SYMBIAN + // These are needed as the Nokia Symbian Compiler cannot decide between + // const T& and const T* in a function template. The Nokia compiler _can_ + // decide between class template specializations for T and T*, so a + // tr1::type_traits-like is_pointer works, and we can overload on that. + template + inline void StreamHelper(internal::true_type /*dummy*/, T* pointer) { + if (pointer == NULL) { + *ss_ << "(null)"; + } else { + ::GTestStreamToHelper(ss_.get(), pointer); + } + } + template + inline void StreamHelper(internal::false_type /*dummy*/, const T& value) { + ::GTestStreamToHelper(ss_.get(), value); + } +#endif // GTEST_OS_SYMBIAN + + // We'll hold the text streamed to this object here. + const internal::scoped_ptr< ::std::stringstream> ss_; + + // We declare (but don't implement) this to prevent the compiler + // from implementing the assignment operator. + void operator=(const Message&); +}; + +// Streams a Message to an ostream. +inline std::ostream& operator <<(std::ostream& os, const Message& sb) { + return os << sb.GetString(); +} + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_MESSAGE_H_ +// This file was GENERATED by command: +// pump.py gtest-param-test.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008, Google Inc. +// All rights reserved. 
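//
// An illustrative sketch of using the Message class defined above directly
// ('count' is an assumed local variable):
//
//   ::testing::Message msg;
//   const char* name = NULL;
//   msg << "user=" << name     // a NULL char* is printed as "(null)"
//       << " ok=" << true      // bools are printed as true/false, not 1/0
//       << " count=" << count;
//   std::cout << msg;          // streams the accumulated text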
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: vladl@google.com (Vlad Losev) +// +// Macros and functions for implementing parameterized tests +// in Google C++ Testing Framework (Google Test) +// +// This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +#ifndef GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_ + + +// Value-parameterized tests allow you to test your code with different +// parameters without writing multiple copies of the same test. +// +// Here is how you use value-parameterized tests: + +#if 0 + +// To write value-parameterized tests, first you should define a fixture +// class. It is usually derived from testing::TestWithParam (see below for +// another inheritance scheme that's sometimes useful in more complicated +// class hierarchies), where the type of your parameter values. +// TestWithParam is itself derived from testing::Test. T can be any +// copyable type. If it's a raw pointer, you are responsible for managing the +// lifespan of the pointed values. + +class FooTest : public ::testing::TestWithParam { + // You can implement all the usual class fixture members here. +}; + +// Then, use the TEST_P macro to define as many parameterized tests +// for this fixture as you want. The _P suffix is for "parameterized" +// or "pattern", whichever you prefer to think. + +TEST_P(FooTest, DoesBlah) { + // Inside a test, access the test parameter with the GetParam() method + // of the TestWithParam class: + EXPECT_TRUE(foo.Blah(GetParam())); + ... +} + +TEST_P(FooTest, HasBlahBlah) { + ... +} + +// Finally, you can use INSTANTIATE_TEST_CASE_P to instantiate the test +// case with any set of parameters you want. Google Test defines a number +// of functions for generating test parameters. They return what we call +// (surprise!) parameter generators. Here is a summary of them, which +// are all in the testing namespace: +// +// +// Range(begin, end [, step]) - Yields values {begin, begin+step, +// begin+step+step, ...}. The values do not +// include end. 
step defaults to 1. +// Values(v1, v2, ..., vN) - Yields values {v1, v2, ..., vN}. +// ValuesIn(container) - Yields values from a C-style array, an STL +// ValuesIn(begin,end) container, or an iterator range [begin, end). +// Bool() - Yields sequence {false, true}. +// Combine(g1, g2, ..., gN) - Yields all combinations (the Cartesian product +// for the math savvy) of the values generated +// by the N generators. +// +// For more details, see comments at the definitions of these functions below +// in this file. +// +// The following statement will instantiate tests from the FooTest test case +// each with parameter values "meeny", "miny", and "moe". + +INSTANTIATE_TEST_CASE_P(InstantiationName, + FooTest, + Values("meeny", "miny", "moe")); + +// To distinguish different instances of the pattern, (yes, you +// can instantiate it more then once) the first argument to the +// INSTANTIATE_TEST_CASE_P macro is a prefix that will be added to the +// actual test case name. Remember to pick unique prefixes for different +// instantiations. The tests from the instantiation above will have +// these names: +// +// * InstantiationName/FooTest.DoesBlah/0 for "meeny" +// * InstantiationName/FooTest.DoesBlah/1 for "miny" +// * InstantiationName/FooTest.DoesBlah/2 for "moe" +// * InstantiationName/FooTest.HasBlahBlah/0 for "meeny" +// * InstantiationName/FooTest.HasBlahBlah/1 for "miny" +// * InstantiationName/FooTest.HasBlahBlah/2 for "moe" +// +// You can use these names in --gtest_filter. +// +// This statement will instantiate all tests from FooTest again, each +// with parameter values "cat" and "dog": + +const char* pets[] = {"cat", "dog"}; +INSTANTIATE_TEST_CASE_P(AnotherInstantiationName, FooTest, ValuesIn(pets)); + +// The tests from the instantiation above will have these names: +// +// * AnotherInstantiationName/FooTest.DoesBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.DoesBlah/1 for "dog" +// * AnotherInstantiationName/FooTest.HasBlahBlah/0 for "cat" +// * AnotherInstantiationName/FooTest.HasBlahBlah/1 for "dog" +// +// Please note that INSTANTIATE_TEST_CASE_P will instantiate all tests +// in the given test case, whether their definitions come before or +// AFTER the INSTANTIATE_TEST_CASE_P statement. +// +// Please also note that generator expressions (including parameters to the +// generators) are evaluated in InitGoogleTest(), after main() has started. +// This allows the user on one hand, to adjust generator parameters in order +// to dynamically determine a set of tests to run and on the other hand, +// give the user a chance to inspect the generated tests with Google Test +// reflection API before RUN_ALL_TESTS() is executed. +// +// You can see samples/sample7_unittest.cc and samples/sample8_unittest.cc +// for more examples. +// +// In the future, we plan to publish the API for defining new parameter +// generators. But for now this interface remains part of the internal +// implementation and is subject to change. +// +// +// A parameterized test fixture must be derived from testing::Test and from +// testing::WithParamInterface, where T is the type of the parameter +// values. Inheriting from TestWithParam satisfies that requirement because +// TestWithParam inherits from both Test and WithParamInterface. In more +// complicated hierarchies, however, it is occasionally useful to inherit +// separately from Test and WithParamInterface. 
For example: + +class BaseTest : public ::testing::Test { + // You can inherit all the usual members for a non-parameterized test + // fixture here. +}; + +class DerivedTest : public BaseTest, public ::testing::WithParamInterface { + // The usual test fixture members go here too. +}; + +TEST_F(BaseTest, HasFoo) { + // This is an ordinary non-parameterized test. +} + +TEST_P(DerivedTest, DoesBlah) { + // GetParam works just the same here as if you inherit from TestWithParam. + EXPECT_TRUE(foo.Blah(GetParam())); +} + +#endif // 0 + + +#if !GTEST_OS_SYMBIAN +# include +#endif + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: vladl@google.com (Vlad Losev) + +// Type and function utilities for implementing parameterized tests. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ + +#include +#include +#include + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. +// Copyright 2003 Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. 
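//
// An illustrative sketch of the generators documented above that the
// examples do not show -- Range() and Bool() (QueueTest, MakeQueue() and
// Check() are hypothetical):
//
//   class QueueTest : public ::testing::TestWithParam<int> {};
//
//   TEST_P(QueueTest, HandlesSize) {
//     EXPECT_TRUE(MakeQueue(GetParam()) != NULL);
//   }
//
//   // Instantiates HandlesSize for sizes 0, 8, 16, ..., 56 (64 is excluded).
//   INSTANTIATE_TEST_CASE_P(Sizes, QueueTest, ::testing::Range(0, 64, 8));
//
//   class FlagTest : public ::testing::TestWithParam<bool> {};
//   TEST_P(FlagTest, Works) { EXPECT_TRUE(Check(GetParam())); }
//   INSTANTIATE_TEST_CASE_P(Flags, FlagTest, ::testing::Bool());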
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Authors: Dan Egnor (egnor@google.com) +// +// A "smart" pointer type with reference tracking. Every pointer to a +// particular object is kept on a circular linked list. When the last pointer +// to an object is destroyed or reassigned, the object is deleted. +// +// Used properly, this deletes the object when the last reference goes away. +// There are several caveats: +// - Like all reference counting schemes, cycles lead to leaks. +// - Each smart pointer is actually two pointers (8 bytes instead of 4). +// - Every time a pointer is assigned, the entire list of pointers to that +// object is traversed. This class is therefore NOT SUITABLE when there +// will often be more than two or three pointers to a particular object. +// - References are only tracked as long as linked_ptr<> objects are copied. +// If a linked_ptr<> is converted to a raw pointer and back, BAD THINGS +// will happen (double deletion). +// +// A good use of this class is storing object references in STL containers. +// You can safely put linked_ptr<> in a vector<>. +// Other uses may not be as good. +// +// Note: If you use an incomplete type with linked_ptr<>, the class +// *containing* linked_ptr<> must have a constructor and destructor (even +// if they do nothing!). +// +// Bill Gibbons suggested we use something like this. +// +// Thread Safety: +// Unlike other linked_ptr implementations, in this implementation +// a linked_ptr object is thread-safe in the sense that: +// - it's safe to copy linked_ptr objects concurrently, +// - it's safe to copy *from* a linked_ptr and read its underlying +// raw pointer (e.g. via get()) concurrently, and +// - it's safe to write to two linked_ptrs that point to the same +// shared object concurrently. +// TODO(wan@google.com): rename this to safe_linked_ptr to avoid +// confusion with normal linked_ptr. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ + +#include +#include + + +namespace testing { +namespace internal { + +// Protects copying of all linked_ptr objects. +GTEST_API_ GTEST_DECLARE_STATIC_MUTEX_(g_linked_ptr_mutex); + +// This is used internally by all instances of linked_ptr<>. It needs to be +// a non-template class because different types of linked_ptr<> can refer to +// the same object (linked_ptr(obj) vs linked_ptr(obj)). +// So, it needs to be possible for different types of linked_ptr to participate +// in the same circular linked list, so we need a single class type here. +// +// DO NOT USE THIS CLASS DIRECTLY YOURSELF. Use linked_ptr. +class linked_ptr_internal { + public: + // Create a new circle that includes only this instance. 
+ void join_new() { + next_ = this; + } + + // Many linked_ptr operations may change p.link_ for some linked_ptr + // variable p in the same circle as this object. Therefore we need + // to prevent two such operations from occurring concurrently. + // + // Note that different types of linked_ptr objects can coexist in a + // circle (e.g. linked_ptr, linked_ptr, and + // linked_ptr). Therefore we must use a single mutex to + // protect all linked_ptr objects. This can create serious + // contention in production code, but is acceptable in a testing + // framework. + + // Join an existing circle. + // L < g_linked_ptr_mutex + void join(linked_ptr_internal const* ptr) { + MutexLock lock(&g_linked_ptr_mutex); + + linked_ptr_internal const* p = ptr; + while (p->next_ != ptr) p = p->next_; + p->next_ = this; + next_ = ptr; + } + + // Leave whatever circle we're part of. Returns true if we were the + // last member of the circle. Once this is done, you can join() another. + // L < g_linked_ptr_mutex + bool depart() { + MutexLock lock(&g_linked_ptr_mutex); + + if (next_ == this) return true; + linked_ptr_internal const* p = next_; + while (p->next_ != this) p = p->next_; + p->next_ = next_; + return false; + } + + private: + mutable linked_ptr_internal const* next_; +}; + +template +class linked_ptr { + public: + typedef T element_type; + + // Take over ownership of a raw pointer. This should happen as soon as + // possible after the object is created. + explicit linked_ptr(T* ptr = NULL) { capture(ptr); } + ~linked_ptr() { depart(); } + + // Copy an existing linked_ptr<>, adding ourselves to the list of references. + template linked_ptr(linked_ptr const& ptr) { copy(&ptr); } + linked_ptr(linked_ptr const& ptr) { // NOLINT + assert(&ptr != this); + copy(&ptr); + } + + // Assignment releases the old value and acquires the new. + template linked_ptr& operator=(linked_ptr const& ptr) { + depart(); + copy(&ptr); + return *this; + } + + linked_ptr& operator=(linked_ptr const& ptr) { + if (&ptr != this) { + depart(); + copy(&ptr); + } + return *this; + } + + // Smart pointer members. + void reset(T* ptr = NULL) { + depart(); + capture(ptr); + } + T* get() const { return value_; } + T* operator->() const { return value_; } + T& operator*() const { return *value_; } + + bool operator==(T* p) const { return value_ == p; } + bool operator!=(T* p) const { return value_ != p; } + template + bool operator==(linked_ptr const& ptr) const { + return value_ == ptr.get(); + } + template + bool operator!=(linked_ptr const& ptr) const { + return value_ != ptr.get(); + } + + private: + template + friend class linked_ptr; + + T* value_; + linked_ptr_internal link_; + + void depart() { + if (link_.depart()) delete value_; + } + + void capture(T* ptr) { + value_ = ptr; + link_.join_new(); + } + + template void copy(linked_ptr const* ptr) { + value_ = ptr->get(); + if (value_) + link_.join(&ptr->link_); + else + link_.join_new(); + } +}; + +template inline +bool operator==(T* ptr, const linked_ptr& x) { + return ptr == x.get(); +} + +template inline +bool operator!=(T* ptr, const linked_ptr& x) { + return ptr != x.get(); +} + +// A function to convert T* into linked_ptr +// Doing e.g. make_linked_ptr(new FooBarBaz(arg)) is a shorter notation +// for linked_ptr >(new FooBarBaz(arg)) +template +linked_ptr make_linked_ptr(T* ptr) { + return linked_ptr(ptr); +} + +} // namespace internal +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_LINKED_PTR_H_ +// Copyright 2007, Google Inc. 
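//
// An illustrative sketch of the intended use of linked_ptr<> described
// above -- holding owned objects in an STL container (Widget is a
// hypothetical type):
//
//   std::vector< testing::internal::linked_ptr<Widget> > widgets;
//   widgets.push_back(testing::internal::make_linked_ptr(new Widget));
//   widgets.push_back(widgets[0]);  // copies share ownership of the Widget
//   // The Widget is deleted exactly once, when the last copy is destroyed.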
+// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +// Google Test - The Google C++ Testing Framework +// +// This file implements a universal value printer that can print a +// value of any type T: +// +// void ::testing::internal::UniversalPrinter::Print(value, ostream_ptr); +// +// A user can teach this function how to print a class type T by +// defining either operator<<() or PrintTo() in the namespace that +// defines T. More specifically, the FIRST defined function in the +// following list will be used (assuming T is defined in namespace +// foo): +// +// 1. foo::PrintTo(const T&, ostream*) +// 2. operator<<(ostream&, const T&) defined in either foo or the +// global namespace. +// +// If none of the above is defined, it will print the debug string of +// the value if it is a protocol buffer, or print the raw bytes in the +// value otherwise. +// +// To aid debugging: when T is a reference type, the address of the +// value is also printed; when T is a (const) char pointer, both the +// pointer value and the NUL-terminated string it points to are +// printed. +// +// We also provide some convenient wrappers: +// +// // Prints a value to a string. For a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// std::string ::testing::PrintToString(const T& value); +// +// // Prints a value tersely: for a reference type, the referenced +// // value (but not the address) is printed; for a (const or not) char +// // pointer, the NUL-terminated string (but not the pointer) is +// // printed. +// void ::testing::internal::UniversalTersePrint(const T& value, ostream*); +// +// // Prints value using the type inferred by the compiler. The difference +// // from UniversalTersePrint() is that this function prints both the +// // pointer and the NUL-terminated string for a (const or not) char pointer. 
+// void ::testing::internal::UniversalPrint(const T& value, ostream*); +// +// // Prints the fields of a tuple tersely to a string vector, one +// // element for each field. Tuple support must be enabled in +// // gtest-port.h. +// std::vector UniversalTersePrintTupleFieldsToStrings( +// const Tuple& value); +// +// Known limitation: +// +// The print primitives print the elements of an STL-style container +// using the compiler-inferred type of *iter where iter is a +// const_iterator of the container. When const_iterator is an input +// iterator but not a forward iterator, this inferred type may not +// match value_type, and the print output may be incorrect. In +// practice, this is rarely a problem as for most containers +// const_iterator is a forward iterator. We'll fix this if there's an +// actual need for it. Note that this fix cannot rely on value_type +// being defined as many user-defined container types don't have +// value_type. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ + +#include // NOLINT +#include +#include +#include +#include + +namespace testing { + +// Definitions in the 'internal' and 'internal2' name spaces are +// subject to change without notice. DO NOT USE THEM IN USER CODE! +namespace internal2 { + +// Prints the given number of bytes in the given object to the given +// ostream. +GTEST_API_ void PrintBytesInObjectTo(const unsigned char* obj_bytes, + size_t count, + ::std::ostream* os); + +// For selecting which printer to use when a given type has neither << +// nor PrintTo(). +enum TypeKind { + kProtobuf, // a protobuf type + kConvertibleToInteger, // a type implicitly convertible to BiggestInt + // (e.g. a named or unnamed enum type) + kOtherType // anything else +}; + +// TypeWithoutFormatter::PrintValue(value, os) is called +// by the universal printer to print a value of type T when neither +// operator<< nor PrintTo() is defined for T, where kTypeKind is the +// "kind" of T as defined by enum TypeKind. +template +class TypeWithoutFormatter { + public: + // This default version is called when kTypeKind is kOtherType. + static void PrintValue(const T& value, ::std::ostream* os) { + PrintBytesInObjectTo(reinterpret_cast(&value), + sizeof(value), os); + } +}; + +// We print a protobuf using its ShortDebugString() when the string +// doesn't exceed this many characters; otherwise we print it using +// DebugString() for better readability. +const size_t kProtobufOneLinerMaxLength = 50; + +template +class TypeWithoutFormatter { + public: + static void PrintValue(const T& value, ::std::ostream* os) { + const ::testing::internal::string short_str = value.ShortDebugString(); + const ::testing::internal::string pretty_str = + short_str.length() <= kProtobufOneLinerMaxLength ? + short_str : ("\n" + value.DebugString()); + *os << ("<" + pretty_str + ">"); + } +}; + +template +class TypeWithoutFormatter { + public: + // Since T has no << operator or PrintTo() but can be implicitly + // converted to BiggestInt, we print it as a BiggestInt. + // + // Most likely T is an enum type (either named or unnamed), in which + // case printing it as an integer is the desired behavior. In case + // T is not an enum, printing it as an integer is the best we can do + // given that it has no user-defined printer. + static void PrintValue(const T& value, ::std::ostream* os) { + const internal::BiggestInt kBigInt = value; + *os << kBigInt; + } +}; + +// Prints the given value to the given ostream. 
If the value is a +// protocol message, its debug string is printed; if it's an enum or +// of a type implicitly convertible to BiggestInt, it's printed as an +// integer; otherwise the bytes in the value are printed. This is +// what UniversalPrinter::Print() does when it knows nothing about +// type T and T has neither << operator nor PrintTo(). +// +// A user can override this behavior for a class type Foo by defining +// a << operator in the namespace where Foo is defined. +// +// We put this operator in namespace 'internal2' instead of 'internal' +// to simplify the implementation, as much code in 'internal' needs to +// use << in STL, which would conflict with our own << were it defined +// in 'internal'. +// +// Note that this operator<< takes a generic std::basic_ostream type instead of the more restricted std::ostream. If +// we define it to take an std::ostream instead, we'll get an +// "ambiguous overloads" compiler error when trying to print a type +// Foo that supports streaming to std::basic_ostream, as the compiler cannot tell whether +// operator<<(std::ostream&, const T&) or +// operator<<(std::basic_stream, const Foo&) is more +// specific. +template +::std::basic_ostream& operator<<( + ::std::basic_ostream& os, const T& x) { + TypeWithoutFormatter::value ? kProtobuf : + internal::ImplicitlyConvertible::value ? + kConvertibleToInteger : kOtherType)>::PrintValue(x, &os); + return os; +} + +} // namespace internal2 +} // namespace testing + +// This namespace MUST NOT BE NESTED IN ::testing, or the name look-up +// magic needed for implementing UniversalPrinter won't work. +namespace testing_internal { + +// Used to print a value that is not an STL-style container when the +// user doesn't define PrintTo() for it. +template +void DefaultPrintNonContainerTo(const T& value, ::std::ostream* os) { + // With the following statement, during unqualified name lookup, + // testing::internal2::operator<< appears as if it was declared in + // the nearest enclosing namespace that contains both + // ::testing_internal and ::testing::internal2, i.e. the global + // namespace. For more details, refer to the C++ Standard section + // 7.3.4-1 [namespace.udir]. This allows us to fall back onto + // testing::internal2::operator<< in case T doesn't come with a << + // operator. + // + // We cannot write 'using ::testing::internal2::operator<<;', which + // gcc 3.3 fails to compile due to a compiler bug. + using namespace ::testing::internal2; // NOLINT + + // Assuming T is defined in namespace foo, in the next statement, + // the compiler will consider all of: + // + // 1. foo::operator<< (thanks to Koenig look-up), + // 2. ::operator<< (as the current namespace is enclosed in ::), + // 3. testing::internal2::operator<< (thanks to the using statement above). + // + // The operator<< whose type matches T best will be picked. + // + // We deliberately allow #2 to be a candidate, as sometimes it's + // impossible to define #1 (e.g. when foo is ::std, defining + // anything in it is undefined behavior unless you are a compiler + // vendor.). + *os << value; +} + +} // namespace testing_internal + +namespace testing { +namespace internal { + +// UniversalPrinter::Print(value, ostream_ptr) prints the given +// value to the given ostream. The caller must ensure that +// 'ostream_ptr' is not NULL, or the behavior is undefined. 
+// +// We define UniversalPrinter as a class template (as opposed to a +// function template), as we need to partially specialize it for +// reference types, which cannot be done with function templates. +template +class UniversalPrinter; + +template +void UniversalPrint(const T& value, ::std::ostream* os); + +// Used to print an STL-style container when the user doesn't define +// a PrintTo() for it. +template +void DefaultPrintTo(IsContainer /* dummy */, + false_type /* is not a pointer */, + const C& container, ::std::ostream* os) { + const size_t kMaxCount = 32; // The maximum number of elements to print. + *os << '{'; + size_t count = 0; + for (typename C::const_iterator it = container.begin(); + it != container.end(); ++it, ++count) { + if (count > 0) { + *os << ','; + if (count == kMaxCount) { // Enough has been printed. + *os << " ..."; + break; + } + } + *os << ' '; + // We cannot call PrintTo(*it, os) here as PrintTo() doesn't + // handle *it being a native array. + internal::UniversalPrint(*it, os); + } + + if (count > 0) { + *os << ' '; + } + *os << '}'; +} + +// Used to print a pointer that is neither a char pointer nor a member +// pointer, when the user doesn't define PrintTo() for it. (A member +// variable pointer or member function pointer doesn't really point to +// a location in the address space. Their representation is +// implementation-defined. Therefore they will be printed as raw +// bytes.) +template +void DefaultPrintTo(IsNotContainer /* dummy */, + true_type /* is a pointer */, + T* p, ::std::ostream* os) { + if (p == NULL) { + *os << "NULL"; + } else { + // C++ doesn't allow casting from a function pointer to any object + // pointer. + // + // IsTrue() silences warnings: "Condition is always true", + // "unreachable code". + if (IsTrue(ImplicitlyConvertible::value)) { + // T is not a function type. We just call << to print p, + // relying on ADL to pick up user-defined << for their pointer + // types, if any. + *os << p; + } else { + // T is a function type, so '*os << p' doesn't do what we want + // (it just prints p as bool). We want to print p as a const + // void*. However, we cannot cast it to const void* directly, + // even using reinterpret_cast, as earlier versions of gcc + // (e.g. 3.4.5) cannot compile the cast when p is a function + // pointer. Casting to UInt64 first solves the problem. + *os << reinterpret_cast( + reinterpret_cast(p)); + } + } +} + +// Used to print a non-container, non-pointer value when the user +// doesn't define PrintTo() for it. +template +void DefaultPrintTo(IsNotContainer /* dummy */, + false_type /* is not a pointer */, + const T& value, ::std::ostream* os) { + ::testing_internal::DefaultPrintNonContainerTo(value, os); +} + +// Prints the given value using the << operator if it has one; +// otherwise prints the bytes in it. This is what +// UniversalPrinter::Print() does when PrintTo() is not specialized +// or overloaded for type T. +// +// A user can override this behavior for a class type Foo by defining +// an overload of PrintTo() in the namespace where Foo is defined. We +// give the user this option as sometimes defining a << operator for +// Foo is not desirable (e.g. the coding style may prevent doing it, +// or there is already a << operator but it doesn't do what the user +// wants). +template +void PrintTo(const T& value, ::std::ostream* os) { + // DefaultPrintTo() is overloaded. The type of its first two + // arguments determine which version will be picked. 
If T is an + // STL-style container, the version for container will be called; if + // T is a pointer, the pointer version will be called; otherwise the + // generic version will be called. + // + // Note that we check for container types here, prior to we check + // for protocol message types in our operator<<. The rationale is: + // + // For protocol messages, we want to give people a chance to + // override Google Mock's format by defining a PrintTo() or + // operator<<. For STL containers, other formats can be + // incompatible with Google Mock's format for the container + // elements; therefore we check for container types here to ensure + // that our format is used. + // + // The second argument of DefaultPrintTo() is needed to bypass a bug + // in Symbian's C++ compiler that prevents it from picking the right + // overload between: + // + // PrintTo(const T& x, ...); + // PrintTo(T* x, ...); + DefaultPrintTo(IsContainerTest(0), is_pointer(), value, os); +} + +// The following list of PrintTo() overloads tells +// UniversalPrinter::Print() how to print standard types (built-in +// types, strings, plain arrays, and pointers). + +// Overloads for various char types. +GTEST_API_ void PrintTo(unsigned char c, ::std::ostream* os); +GTEST_API_ void PrintTo(signed char c, ::std::ostream* os); +inline void PrintTo(char c, ::std::ostream* os) { + // When printing a plain char, we always treat it as unsigned. This + // way, the output won't be affected by whether the compiler thinks + // char is signed or not. + PrintTo(static_cast(c), os); +} + +// Overloads for other simple built-in types. +inline void PrintTo(bool x, ::std::ostream* os) { + *os << (x ? "true" : "false"); +} + +// Overload for wchar_t type. +// Prints a wchar_t as a symbol if it is printable or as its internal +// code otherwise and also as its decimal code (except for L'\0'). +// The L'\0' char is printed as "L'\\0'". The decimal code is printed +// as signed integer when wchar_t is implemented by the compiler +// as a signed type and is printed as an unsigned integer when wchar_t +// is implemented as an unsigned type. +GTEST_API_ void PrintTo(wchar_t wc, ::std::ostream* os); + +// Overloads for C strings. +GTEST_API_ void PrintTo(const char* s, ::std::ostream* os); +inline void PrintTo(char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// signed/unsigned char is often used for representing binary data, so +// we print pointers to it as void* to be safe. +inline void PrintTo(const signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(signed char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(const unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +inline void PrintTo(unsigned char* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} + +// MSVC can be configured to define wchar_t as a typedef of unsigned +// short. It defines _NATIVE_WCHAR_T_DEFINED when wchar_t is a native +// type. When wchar_t is a typedef, defining an overload for const +// wchar_t* would cause unsigned short* be printed as a wide string, +// possibly causing invalid memory accesses. +#if !defined(_MSC_VER) || defined(_NATIVE_WCHAR_T_DEFINED) +// Overloads for wide C strings +GTEST_API_ void PrintTo(const wchar_t* s, ::std::ostream* os); +inline void PrintTo(wchar_t* s, ::std::ostream* os) { + PrintTo(ImplicitCast_(s), os); +} +#endif + +// Overload for C arrays. Multi-dimensional arrays are printed +// properly. 
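//
// An illustrative sketch of teaching the printer about a user-defined type,
// as described at the top of this header (geo::Point is hypothetical):
//
//   namespace geo {
//     struct Point { int x, y; };
//     void PrintTo(const Point& p, ::std::ostream* os) {
//       *os << "(" << p.x << ", " << p.y << ")";
//     }
//   }  // namespace geo
//
//   // Assertion failures involving geo::Point values, as well as
//   // ::testing::PrintToString(pt), will now show "(3, 4)" instead of
//   // raw bytes.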
+ +// Prints the given number of elements in an array, without printing +// the curly braces. +template +void PrintRawArrayTo(const T a[], size_t count, ::std::ostream* os) { + UniversalPrint(a[0], os); + for (size_t i = 1; i != count; i++) { + *os << ", "; + UniversalPrint(a[i], os); + } +} + +// Overloads for ::string and ::std::string. +#if GTEST_HAS_GLOBAL_STRING +GTEST_API_ void PrintStringTo(const ::string&s, ::std::ostream* os); +inline void PrintTo(const ::string& s, ::std::ostream* os) { + PrintStringTo(s, os); +} +#endif // GTEST_HAS_GLOBAL_STRING + +GTEST_API_ void PrintStringTo(const ::std::string&s, ::std::ostream* os); +inline void PrintTo(const ::std::string& s, ::std::ostream* os) { + PrintStringTo(s, os); +} + +// Overloads for ::wstring and ::std::wstring. +#if GTEST_HAS_GLOBAL_WSTRING +GTEST_API_ void PrintWideStringTo(const ::wstring&s, ::std::ostream* os); +inline void PrintTo(const ::wstring& s, ::std::ostream* os) { + PrintWideStringTo(s, os); +} +#endif // GTEST_HAS_GLOBAL_WSTRING + +#if GTEST_HAS_STD_WSTRING +GTEST_API_ void PrintWideStringTo(const ::std::wstring&s, ::std::ostream* os); +inline void PrintTo(const ::std::wstring& s, ::std::ostream* os) { + PrintWideStringTo(s, os); +} +#endif // GTEST_HAS_STD_WSTRING + +#if GTEST_HAS_TR1_TUPLE +// Overload for ::std::tr1::tuple. Needed for printing function arguments, +// which are packed as tuples. + +// Helper function for printing a tuple. T must be instantiated with +// a tuple type. +template +void PrintTupleTo(const T& t, ::std::ostream* os); + +// Overloaded PrintTo() for tuples of various arities. We support +// tuples of up-to 10 fields. The following implementation works +// regardless of whether tr1::tuple is implemented using the +// non-standard variadic template feature or not. + +inline void PrintTo(const ::std::tr1::tuple<>& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo(const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} + +template +void PrintTo( + const ::std::tr1::tuple& t, + ::std::ostream* os) { + PrintTupleTo(t, os); +} +#endif // GTEST_HAS_TR1_TUPLE + +// Overload for std::pair. +template +void PrintTo(const ::std::pair& value, ::std::ostream* os) { + *os << '('; + // We cannot use UniversalPrint(value.first, os) here, as T1 may be + // a reference type. The same for printing value.second. + UniversalPrinter::Print(value.first, os); + *os << ", "; + UniversalPrinter::Print(value.second, os); + *os << ')'; +} + +// Implements printing a non-reference type T by letting the compiler +// pick the right overload of PrintTo() for T. 
+template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4180) // Temporarily disables warning 4180. +#endif // _MSC_VER + + // Note: we deliberately don't call this PrintTo(), as that name + // conflicts with ::testing::internal::PrintTo in the body of the + // function. + static void Print(const T& value, ::std::ostream* os) { + // By default, ::testing::internal::PrintTo() is used for printing + // the value. + // + // Thanks to Koenig look-up, if T is a class and has its own + // PrintTo() function defined in its namespace, that function will + // be visible here. Since it is more specific than the generic ones + // in ::testing::internal, it will be picked by the compiler in the + // following statement - exactly what we want. + PrintTo(value, os); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif // _MSC_VER +}; + +// UniversalPrintArray(begin, len, os) prints an array of 'len' +// elements, starting at address 'begin'. +template +void UniversalPrintArray(const T* begin, size_t len, ::std::ostream* os) { + if (len == 0) { + *os << "{}"; + } else { + *os << "{ "; + const size_t kThreshold = 18; + const size_t kChunkSize = 8; + // If the array has more than kThreshold elements, we'll have to + // omit some details by printing only the first and the last + // kChunkSize elements. + // TODO(wan@google.com): let the user control the threshold using a flag. + if (len <= kThreshold) { + PrintRawArrayTo(begin, len, os); + } else { + PrintRawArrayTo(begin, kChunkSize, os); + *os << ", ..., "; + PrintRawArrayTo(begin + len - kChunkSize, kChunkSize, os); + } + *os << " }"; + } +} +// This overload prints a (const) char array compactly. +GTEST_API_ void UniversalPrintArray(const char* begin, + size_t len, + ::std::ostream* os); + +// Implements printing an array type T[N]. +template +class UniversalPrinter { + public: + // Prints the given array, omitting some elements when there are too + // many. + static void Print(const T (&a)[N], ::std::ostream* os) { + UniversalPrintArray(a, N, os); + } +}; + +// Implements printing a reference type T&. +template +class UniversalPrinter { + public: + // MSVC warns about adding const to a function type, so we want to + // disable the warning. +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4180) // Temporarily disables warning 4180. +#endif // _MSC_VER + + static void Print(const T& value, ::std::ostream* os) { + // Prints the address of the value. We use reinterpret_cast here + // as static_cast doesn't compile when T is a function type. + *os << "@" << reinterpret_cast(&value) << " "; + + // Then prints the value itself. + UniversalPrint(value, os); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif // _MSC_VER +}; + +// Prints a value tersely: for a reference type, the referenced value +// (but not the address) is printed; for a (const) char pointer, the +// NUL-terminated string (but not the pointer) is printed. 
+template +void UniversalTersePrint(const T& value, ::std::ostream* os) { + UniversalPrint(value, os); +} +inline void UniversalTersePrint(const char* str, ::std::ostream* os) { + if (str == NULL) { + *os << "NULL"; + } else { + UniversalPrint(string(str), os); + } +} +inline void UniversalTersePrint(char* str, ::std::ostream* os) { + UniversalTersePrint(static_cast(str), os); +} + +// Prints a value using the type inferred by the compiler. The +// difference between this and UniversalTersePrint() is that for a +// (const) char pointer, this prints both the pointer and the +// NUL-terminated string. +template +void UniversalPrint(const T& value, ::std::ostream* os) { + UniversalPrinter::Print(value, os); +} + +#if GTEST_HAS_TR1_TUPLE +typedef ::std::vector Strings; + +// This helper template allows PrintTo() for tuples and +// UniversalTersePrintTupleFieldsToStrings() to be defined by +// induction on the number of tuple fields. The idea is that +// TuplePrefixPrinter::PrintPrefixTo(t, os) prints the first N +// fields in tuple t, and can be defined in terms of +// TuplePrefixPrinter. + +// The inductive case. +template +struct TuplePrefixPrinter { + // Prints the first N fields of a tuple. + template + static void PrintPrefixTo(const Tuple& t, ::std::ostream* os) { + TuplePrefixPrinter::PrintPrefixTo(t, os); + *os << ", "; + UniversalPrinter::type> + ::Print(::std::tr1::get(t), os); + } + + // Tersely prints the first N fields of a tuple to a string vector, + // one element for each field. + template + static void TersePrintPrefixToStrings(const Tuple& t, Strings* strings) { + TuplePrefixPrinter::TersePrintPrefixToStrings(t, strings); + ::std::stringstream ss; + UniversalTersePrint(::std::tr1::get(t), &ss); + strings->push_back(ss.str()); + } +}; + +// Base cases. +template <> +struct TuplePrefixPrinter<0> { + template + static void PrintPrefixTo(const Tuple&, ::std::ostream*) {} + + template + static void TersePrintPrefixToStrings(const Tuple&, Strings*) {} +}; +// We have to specialize the entire TuplePrefixPrinter<> class +// template here, even though the definition of +// TersePrintPrefixToStrings() is the same as the generic version, as +// Embarcadero (formerly CodeGear, formerly Borland) C++ doesn't +// support specializing a method template of a class template. +template <> +struct TuplePrefixPrinter<1> { + template + static void PrintPrefixTo(const Tuple& t, ::std::ostream* os) { + UniversalPrinter::type>:: + Print(::std::tr1::get<0>(t), os); + } + + template + static void TersePrintPrefixToStrings(const Tuple& t, Strings* strings) { + ::std::stringstream ss; + UniversalTersePrint(::std::tr1::get<0>(t), &ss); + strings->push_back(ss.str()); + } +}; + +// Helper function for printing a tuple. T must be instantiated with +// a tuple type. +template +void PrintTupleTo(const T& t, ::std::ostream* os) { + *os << "("; + TuplePrefixPrinter< ::std::tr1::tuple_size::value>:: + PrintPrefixTo(t, os); + *os << ")"; +} + +// Prints the fields of a tuple tersely to a string vector, one +// element for each field. See the comment before +// UniversalTersePrint() for how we define "tersely". 
+template +Strings UniversalTersePrintTupleFieldsToStrings(const Tuple& value) { + Strings result; + TuplePrefixPrinter< ::std::tr1::tuple_size::value>:: + TersePrintPrefixToStrings(value, &result); + return result; +} +#endif // GTEST_HAS_TR1_TUPLE + +} // namespace internal + +template +::std::string PrintToString(const T& value) { + ::std::stringstream ss; + internal::UniversalTersePrint(value, &ss); + return ss.str(); +} + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_PRINTERS_H_ + +#if GTEST_HAS_PARAM_TEST + +namespace testing { +namespace internal { + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Outputs a message explaining invalid registration of different +// fixture class for the same test case. This may happen when +// TEST_P macro is used to define two tests with the same name +// but in different namespaces. +GTEST_API_ void ReportInvalidTestCaseType(const char* test_case_name, + const char* file, int line); + +template class ParamGeneratorInterface; +template class ParamGenerator; + +// Interface for iterating over elements provided by an implementation +// of ParamGeneratorInterface. +template +class ParamIteratorInterface { + public: + virtual ~ParamIteratorInterface() {} + // A pointer to the base generator instance. + // Used only for the purposes of iterator comparison + // to make sure that two iterators belong to the same generator. + virtual const ParamGeneratorInterface* BaseGenerator() const = 0; + // Advances iterator to point to the next element + // provided by the generator. The caller is responsible + // for not calling Advance() on an iterator equal to + // BaseGenerator()->End(). + virtual void Advance() = 0; + // Clones the iterator object. Used for implementing copy semantics + // of ParamIterator. + virtual ParamIteratorInterface* Clone() const = 0; + // Dereferences the current iterator and provides (read-only) access + // to the pointed value. It is the caller's responsibility not to call + // Current() on an iterator equal to BaseGenerator()->End(). + // Used for implementing ParamGenerator::operator*(). + virtual const T* Current() const = 0; + // Determines whether the given iterator and other point to the same + // element in the sequence generated by the generator. + // Used for implementing ParamGenerator::operator==(). + virtual bool Equals(const ParamIteratorInterface& other) const = 0; +}; + +// Class iterating over elements provided by an implementation of +// ParamGeneratorInterface. It wraps ParamIteratorInterface +// and implements the const forward iterator concept. +template +class ParamIterator { + public: + typedef T value_type; + typedef const T& reference; + typedef ptrdiff_t difference_type; + + // ParamIterator assumes ownership of the impl_ pointer. + ParamIterator(const ParamIterator& other) : impl_(other.impl_->Clone()) {} + ParamIterator& operator=(const ParamIterator& other) { + if (this != &other) + impl_.reset(other.impl_->Clone()); + return *this; + } + + const T& operator*() const { return *impl_->Current(); } + const T* operator->() const { return impl_->Current(); } + // Prefix version of operator++. + ParamIterator& operator++() { + impl_->Advance(); + return *this; + } + // Postfix version of operator++. 
+ ParamIterator operator++(int /*unused*/) { + ParamIteratorInterface* clone = impl_->Clone(); + impl_->Advance(); + return ParamIterator(clone); + } + bool operator==(const ParamIterator& other) const { + return impl_.get() == other.impl_.get() || impl_->Equals(*other.impl_); + } + bool operator!=(const ParamIterator& other) const { + return !(*this == other); + } + + private: + friend class ParamGenerator; + explicit ParamIterator(ParamIteratorInterface* impl) : impl_(impl) {} + scoped_ptr > impl_; +}; + +// ParamGeneratorInterface is the binary interface to access generators +// defined in other translation units. +template +class ParamGeneratorInterface { + public: + typedef T ParamType; + + virtual ~ParamGeneratorInterface() {} + + // Generator interface definition + virtual ParamIteratorInterface* Begin() const = 0; + virtual ParamIteratorInterface* End() const = 0; +}; + +// Wraps ParamGeneratorInterface and provides general generator syntax +// compatible with the STL Container concept. +// This class implements copy initialization semantics and the contained +// ParamGeneratorInterface instance is shared among all copies +// of the original object. This is possible because that instance is immutable. +template +class ParamGenerator { + public: + typedef ParamIterator iterator; + + explicit ParamGenerator(ParamGeneratorInterface* impl) : impl_(impl) {} + ParamGenerator(const ParamGenerator& other) : impl_(other.impl_) {} + + ParamGenerator& operator=(const ParamGenerator& other) { + impl_ = other.impl_; + return *this; + } + + iterator begin() const { return iterator(impl_->Begin()); } + iterator end() const { return iterator(impl_->End()); } + + private: + linked_ptr > impl_; +}; + +// Generates values from a range of two comparable values. Can be used to +// generate sequences of user-defined types that implement operator+() and +// operator<(). +// This class is used in the Range() function. +template +class RangeGenerator : public ParamGeneratorInterface { + public: + RangeGenerator(T begin, T end, IncrementT step) + : begin_(begin), end_(end), + step_(step), end_index_(CalculateEndIndex(begin, end, step)) {} + virtual ~RangeGenerator() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, begin_, 0, step_); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, end_, end_index_, step_); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, T value, int index, + IncrementT step) + : base_(base), value_(value), index_(index), step_(step) {} + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + virtual void Advance() { + value_ = value_ + step_; + index_++; + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const T* Current() const { return &value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." 
<< std::endl; + const int other_index = + CheckedDowncastToActualType(&other)->index_; + return index_ == other_index; + } + + private: + Iterator(const Iterator& other) + : ParamIteratorInterface(), + base_(other.base_), value_(other.value_), index_(other.index_), + step_(other.step_) {} + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + T value_; + int index_; + const IncrementT step_; + }; // class RangeGenerator::Iterator + + static int CalculateEndIndex(const T& begin, + const T& end, + const IncrementT& step) { + int end_index = 0; + for (T i = begin; i < end; i = i + step) + end_index++; + return end_index; + } + + // No implementation - assignment is unsupported. + void operator=(const RangeGenerator& other); + + const T begin_; + const T end_; + const IncrementT step_; + // The index for the end() iterator. All the elements in the generated + // sequence are indexed (0-based) to aid iterator comparison. + const int end_index_; +}; // class RangeGenerator + + +// Generates values from a pair of STL-style iterators. Used in the +// ValuesIn() function. The elements are copied from the source range +// since the source can be located on the stack, and the generator +// is likely to persist beyond that stack frame. +template +class ValuesInIteratorRangeGenerator : public ParamGeneratorInterface { + public: + template + ValuesInIteratorRangeGenerator(ForwardIterator begin, ForwardIterator end) + : container_(begin, end) {} + virtual ~ValuesInIteratorRangeGenerator() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, container_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, container_.end()); + } + + private: + typedef typename ::std::vector ContainerType; + + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + typename ContainerType::const_iterator iterator) + : base_(base), iterator_(iterator) {} + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + virtual void Advance() { + ++iterator_; + value_.reset(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + // We need to use cached value referenced by iterator_ because *iterator_ + // can return a temporary object (and of type other then T), so just + // having "return &*iterator_;" doesn't work. + // value_ is updated here and not in Advance() because Advance() + // can advance iterator_ beyond the end of the range, and we cannot + // detect that fact. The client code, on the other hand, is + // responsible for not calling Current() on an out-of-range iterator. + virtual const T* Current() const { + if (value_.get() == NULL) + value_.reset(new T(*iterator_)); + return value_.get(); + } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + return iterator_ == + CheckedDowncastToActualType(&other)->iterator_; + } + + private: + Iterator(const Iterator& other) + // The explicit constructor call suppresses a false warning + // emitted by gcc when supplied with the -Wextra option. 
+      : ParamIteratorInterface<T>(),
+        base_(other.base_),
+        iterator_(other.iterator_) {}
+
+    const ParamGeneratorInterface<T>* const base_;
+    typename ContainerType::const_iterator iterator_;
+    // A cached value of *iterator_. We keep it here to allow access by
+    // pointer in the wrapping iterator's operator->().
+    // value_ needs to be mutable to be accessed in Current().
+    // Use of scoped_ptr helps manage cached value's lifetime,
+    // which is bound by the lifespan of the iterator itself.
+    mutable scoped_ptr<const T> value_;
+  };  // class ValuesInIteratorRangeGenerator::Iterator
+
+  // No implementation - assignment is unsupported.
+  void operator=(const ValuesInIteratorRangeGenerator& other);
+
+  const ContainerType container_;
+};  // class ValuesInIteratorRangeGenerator
+
+// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE.
+//
+// Stores a parameter value and later creates tests parameterized with that
+// value.
+template <class TestClass>
+class ParameterizedTestFactory : public TestFactoryBase {
+ public:
+  typedef typename TestClass::ParamType ParamType;
+  explicit ParameterizedTestFactory(ParamType parameter) :
+      parameter_(parameter) {}
+  virtual Test* CreateTest() {
+    TestClass::SetParam(&parameter_);
+    return new TestClass();
+  }
+
+ private:
+  const ParamType parameter_;
+
+  GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestFactory);
+};
+
+// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE.
+//
+// TestMetaFactoryBase is a base class for meta-factories that create
+// test factories for passing into MakeAndRegisterTestInfo function.
+template <class ParamType>
+class TestMetaFactoryBase {
+ public:
+  virtual ~TestMetaFactoryBase() {}
+
+  virtual TestFactoryBase* CreateTestFactory(ParamType parameter) = 0;
+};
+
+// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE.
+//
+// TestMetaFactory creates test factories for passing into
+// MakeAndRegisterTestInfo function. Since MakeAndRegisterTestInfo receives
+// ownership of test factory pointer, same factory object cannot be passed
+// into that method twice. But ParameterizedTestCaseInfo is going to call
+// it for each Test/Parameter value combination. Thus it needs meta factory
+// creator class.
+template <class TestCase>
+class TestMetaFactory
+    : public TestMetaFactoryBase<typename TestCase::ParamType> {
+ public:
+  typedef typename TestCase::ParamType ParamType;
+
+  TestMetaFactory() {}
+
+  virtual TestFactoryBase* CreateTestFactory(ParamType parameter) {
+    return new ParameterizedTestFactory<TestCase>(parameter);
+  }
+
+ private:
+  GTEST_DISALLOW_COPY_AND_ASSIGN_(TestMetaFactory);
+};
+
+// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE.
+//
+// ParameterizedTestCaseInfoBase is a generic interface
+// to ParameterizedTestCaseInfo classes. ParameterizedTestCaseInfoBase
+// accumulates test information provided by TEST_P macro invocations
+// and generators provided by INSTANTIATE_TEST_CASE_P macro invocations
+// and uses that information to register all resulting test instances
+// in RegisterTests method. The ParameterizedTestCaseRegistry class holds
+// a collection of pointers to the ParameterizedTestCaseInfo objects
+// and calls RegisterTests() on each of them when asked.
+class ParameterizedTestCaseInfoBase {
+ public:
+  virtual ~ParameterizedTestCaseInfoBase() {}
+
+  // Base part of test case name for display purposes.
+  virtual const string& GetTestCaseName() const = 0;
+  // Test case id to verify identity.
+  virtual TypeId GetTestCaseTypeId() const = 0;
+  // UnitTest class invokes this method to register tests in this
+  // test case right before running them in RUN_ALL_TESTS macro.
+ // This method should not be called more then once on any single + // instance of a ParameterizedTestCaseInfoBase derived class. + virtual void RegisterTests() = 0; + + protected: + ParameterizedTestCaseInfoBase() {} + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseInfoBase); +}; + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestCaseInfo accumulates tests obtained from TEST_P +// macro invocations for a particular test case and generators +// obtained from INSTANTIATE_TEST_CASE_P macro invocations for that +// test case. It registers tests with all values generated by all +// generators when asked. +template +class ParameterizedTestCaseInfo : public ParameterizedTestCaseInfoBase { + public: + // ParamType and GeneratorCreationFunc are private types but are required + // for declarations of public methods AddTestPattern() and + // AddTestCaseInstantiation(). + typedef typename TestCase::ParamType ParamType; + // A function that returns an instance of appropriate generator type. + typedef ParamGenerator(GeneratorCreationFunc)(); + + explicit ParameterizedTestCaseInfo(const char* name) + : test_case_name_(name) {} + + // Test case base name for display purposes. + virtual const string& GetTestCaseName() const { return test_case_name_; } + // Test case id to verify identity. + virtual TypeId GetTestCaseTypeId() const { return GetTypeId(); } + // TEST_P macro uses AddTestPattern() to record information + // about a single test in a LocalTestInfo structure. + // test_case_name is the base name of the test case (without invocation + // prefix). test_base_name is the name of an individual test without + // parameter index. For the test SequenceA/FooTest.DoBar/1 FooTest is + // test case base name and DoBar is test base name. + void AddTestPattern(const char* test_case_name, + const char* test_base_name, + TestMetaFactoryBase* meta_factory) { + tests_.push_back(linked_ptr(new TestInfo(test_case_name, + test_base_name, + meta_factory))); + } + // INSTANTIATE_TEST_CASE_P macro uses AddGenerator() to record information + // about a generator. + int AddTestCaseInstantiation(const string& instantiation_name, + GeneratorCreationFunc* func, + const char* /* file */, + int /* line */) { + instantiations_.push_back(::std::make_pair(instantiation_name, func)); + return 0; // Return value used only to run this method in namespace scope. + } + // UnitTest class invokes this method to register tests in this test case + // test cases right before running tests in RUN_ALL_TESTS macro. + // This method should not be called more then once on any single + // instance of a ParameterizedTestCaseInfoBase derived class. + // UnitTest has a guard to prevent from calling this method more then once. 
+ virtual void RegisterTests() { + for (typename TestInfoContainer::iterator test_it = tests_.begin(); + test_it != tests_.end(); ++test_it) { + linked_ptr test_info = *test_it; + for (typename InstantiationContainer::iterator gen_it = + instantiations_.begin(); gen_it != instantiations_.end(); + ++gen_it) { + const string& instantiation_name = gen_it->first; + ParamGenerator generator((*gen_it->second)()); + + Message test_case_name_stream; + if ( !instantiation_name.empty() ) + test_case_name_stream << instantiation_name << "/"; + test_case_name_stream << test_info->test_case_base_name; + + int i = 0; + for (typename ParamGenerator::iterator param_it = + generator.begin(); + param_it != generator.end(); ++param_it, ++i) { + Message test_name_stream; + test_name_stream << test_info->test_base_name << "/" << i; + MakeAndRegisterTestInfo( + test_case_name_stream.GetString().c_str(), + test_name_stream.GetString().c_str(), + NULL, // No type parameter. + PrintToString(*param_it).c_str(), + GetTestCaseTypeId(), + TestCase::SetUpTestCase, + TestCase::TearDownTestCase, + test_info->test_meta_factory->CreateTestFactory(*param_it)); + } // for param_it + } // for gen_it + } // for test_it + } // RegisterTests + + private: + // LocalTestInfo structure keeps information about a single test registered + // with TEST_P macro. + struct TestInfo { + TestInfo(const char* a_test_case_base_name, + const char* a_test_base_name, + TestMetaFactoryBase* a_test_meta_factory) : + test_case_base_name(a_test_case_base_name), + test_base_name(a_test_base_name), + test_meta_factory(a_test_meta_factory) {} + + const string test_case_base_name; + const string test_base_name; + const scoped_ptr > test_meta_factory; + }; + typedef ::std::vector > TestInfoContainer; + // Keeps pairs of + // received from INSTANTIATE_TEST_CASE_P macros. + typedef ::std::vector > + InstantiationContainer; + + const string test_case_name_; + TestInfoContainer tests_; + InstantiationContainer instantiations_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseInfo); +}; // class ParameterizedTestCaseInfo + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// ParameterizedTestCaseRegistry contains a map of ParameterizedTestCaseInfoBase +// classes accessed by test case names. TEST_P and INSTANTIATE_TEST_CASE_P +// macros use it to locate their corresponding ParameterizedTestCaseInfo +// descriptors. +class ParameterizedTestCaseRegistry { + public: + ParameterizedTestCaseRegistry() {} + ~ParameterizedTestCaseRegistry() { + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + delete *it; + } + } + + // Looks up or creates and returns a structure containing information about + // tests and instantiations of a particular test case. + template + ParameterizedTestCaseInfo* GetTestCasePatternHolder( + const char* test_case_name, + const char* file, + int line) { + ParameterizedTestCaseInfo* typed_test_info = NULL; + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + if ((*it)->GetTestCaseName() == test_case_name) { + if ((*it)->GetTestCaseTypeId() != GetTypeId()) { + // Complain about incorrect usage of Google Test facilities + // and terminate the program since we cannot guaranty correct + // test case setup and tear-down in this case. 
+ ReportInvalidTestCaseType(test_case_name, file, line); + posix::Abort(); + } else { + // At this point we are sure that the object we found is of the same + // type we are looking for, so we downcast it to that type + // without further checks. + typed_test_info = CheckedDowncastToActualType< + ParameterizedTestCaseInfo >(*it); + } + break; + } + } + if (typed_test_info == NULL) { + typed_test_info = new ParameterizedTestCaseInfo(test_case_name); + test_case_infos_.push_back(typed_test_info); + } + return typed_test_info; + } + void RegisterTests() { + for (TestCaseInfoContainer::iterator it = test_case_infos_.begin(); + it != test_case_infos_.end(); ++it) { + (*it)->RegisterTests(); + } + } + + private: + typedef ::std::vector TestCaseInfoContainer; + + TestCaseInfoContainer test_case_infos_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(ParameterizedTestCaseRegistry); +}; + +} // namespace internal +} // namespace testing + +#endif // GTEST_HAS_PARAM_TEST + +#endif // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_H_ +// This file was GENERATED by command: +// pump.py gtest-param-util-generated.h.pump +// DO NOT EDIT BY HAND!!! + +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: vladl@google.com (Vlad Losev) + +// Type and function utilities for implementing parameterized tests. +// This file is generated by a SCRIPT. DO NOT EDIT BY HAND! +// +// Currently Google Test supports at most 50 arguments in Values, +// and at most 10 arguments in Combine. Please contact +// googletestframework@googlegroups.com if you need more. +// Please note that the number of arguments to Combine is limited +// by the maximum arity of the implementation of tr1::tuple which is +// currently set at 10. + +#ifndef GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_ +#define GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_ + +// scripts/fuse_gtest.py depends on gtest's own header being #included +// *unconditionally*. Therefore these #includes cannot be moved +// inside #if GTEST_HAS_PARAM_TEST. 
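+// Illustrative sketch only (FooTest, DoesBlah and the "Sequence"
+// instantiation name are hypothetical, not part of this header): the
+// ValueArrayN helpers below back the Values() overloads and convert
+// implicitly to ParamGenerator<T> once the parameter type T is known from
+// the fixture.
+//
+//   class FooTest : public ::testing::TestWithParam<int> {};
+//
+//   TEST_P(FooTest, DoesBlah) {
+//     EXPECT_LE(0, GetParam());
+//   }
+//
+//   // Values(1, 2, 3) yields a ValueArray3<int, int, int>; the registry
+//   // machinery above expands it into Sequence/FooTest.DoesBlah/0..2.
+//   INSTANTIATE_TEST_CASE_P(Sequence, FooTest, ::testing::Values(1, 2, 3));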
+ +#if GTEST_HAS_PARAM_TEST + +namespace testing { + +// Forward declarations of ValuesIn(), which is implemented in +// include/gtest/gtest-param-test.h. +template +internal::ParamGenerator< + typename ::testing::internal::IteratorTraits::value_type> +ValuesIn(ForwardIterator begin, ForwardIterator end); + +template +internal::ParamGenerator ValuesIn(const T (&array)[N]); + +template +internal::ParamGenerator ValuesIn( + const Container& container); + +namespace internal { + +// Used in the Values() function to provide polymorphic capabilities. +template +class ValueArray1 { + public: + explicit ValueArray1(T1 v1) : v1_(v1) {} + + template + operator ParamGenerator() const { return ValuesIn(&v1_, &v1_ + 1); } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray1& other); + + const T1 v1_; +}; + +template +class ValueArray2 { + public: + ValueArray2(T1 v1, T2 v2) : v1_(v1), v2_(v2) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray2& other); + + const T1 v1_; + const T2 v2_; +}; + +template +class ValueArray3 { + public: + ValueArray3(T1 v1, T2 v2, T3 v3) : v1_(v1), v2_(v2), v3_(v3) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray3& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; +}; + +template +class ValueArray4 { + public: + ValueArray4(T1 v1, T2 v2, T3 v3, T4 v4) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray4& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; +}; + +template +class ValueArray5 { + public: + ValueArray5(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray5& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; +}; + +template +class ValueArray6 { + public: + ValueArray6(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray6& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; +}; + +template +class ValueArray7 { + public: + ValueArray7(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray7& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; +}; + +template +class ValueArray8 { + public: + ValueArray8(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray8& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; +}; + +template +class ValueArray9 { + public: + ValueArray9(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray9& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; +}; + +template +class ValueArray10 { + public: + ValueArray10(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray10& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; +}; + +template +class ValueArray11 { + public: + ValueArray11(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray11& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; +}; + +template +class ValueArray12 { + public: + ValueArray12(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray12& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; +}; + +template +class ValueArray13 { + public: + ValueArray13(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray13& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; +}; + +template +class ValueArray14 { + public: + ValueArray14(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray14& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; +}; + +template +class ValueArray15 { + public: + ValueArray15(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray15& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; +}; + +template +class ValueArray16 { + public: + ValueArray16(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray16& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; +}; + +template +class ValueArray17 { + public: + ValueArray17(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray17& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; +}; + +template +class ValueArray18 { + public: + ValueArray18(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray18& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; +}; + +template +class ValueArray19 { + public: + ValueArray19(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray19& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; +}; + +template +class ValueArray20 { + public: + ValueArray20(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray20& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; +}; + +template +class ValueArray21 { + public: + ValueArray21(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray21& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; +}; + +template +class ValueArray22 { + public: + ValueArray22(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray22& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; +}; + +template +class ValueArray23 { + public: + ValueArray23(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, + v23_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray23& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; +}; + +template +class ValueArray24 { + public: + ValueArray24(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray24& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; +}; + +template +class ValueArray25 { + public: + ValueArray25(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray25& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; +}; + +template +class ValueArray26 { + public: + ValueArray26(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray26& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; +}; + +template +class ValueArray27 { + public: + ValueArray27(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray27& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; +}; + +template +class ValueArray28 { + public: + ValueArray28(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray28& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; +}; + +template +class ValueArray29 { + public: + ValueArray29(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray29& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; +}; + +template +class ValueArray30 { + public: + ValueArray30(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray30& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; +}; + +template +class ValueArray31 { + public: + ValueArray31(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray31& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; +}; + +template +class ValueArray32 { + public: + ValueArray32(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray32& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; +}; + +template +class ValueArray33 { + public: + ValueArray33(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, + T33 v33) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray33& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; +}; + +template +class ValueArray34 { + public: + ValueArray34(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray34& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; +}; + +template +class ValueArray35 { + public: + ValueArray35(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), + v32_(v32), v33_(v33), v34_(v34), v35_(v35) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, + v35_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray35& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; +}; + +template +class ValueArray36 { + public: + ValueArray36(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), + v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment 
is unsupported. + void operator=(const ValueArray36& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; +}; + +template +class ValueArray37 { + public: + ValueArray37(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), + v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), + v36_(v36), v37_(v37) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray37& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; +}; + +template +class ValueArray38 { + public: + ValueArray38(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray38& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; +}; + +template +class ValueArray39 { + public: + ValueArray39(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray39& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; +}; + +template +class ValueArray40 { + public: + ValueArray40(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), + v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), + v40_(v40) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray40& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; +}; + +template +class ValueArray41 { + public: + ValueArray41(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, + T41 v41) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray41& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; +}; + +template +class ValueArray42 { + public: + ValueArray42(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray42& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; +}; + +template +class ValueArray43 { + public: + ValueArray43(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), + v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), + v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), + v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), + v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), + v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), + v38_(v38), v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray43& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; +}; + +template +class ValueArray44 { + public: + ValueArray44(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), + v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), + v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), v18_(v18), + v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), v24_(v24), + v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), v30_(v30), + v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), v36_(v36), + v37_(v37), v38_(v38), v39_(v39), v40_(v40), v41_(v41), v42_(v42), + v43_(v43), v44_(v44) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray44& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; +}; + +template +class ValueArray45 { + public: + ValueArray45(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), + v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), v11_(v11), + v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), v17_(v17), + v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), v23_(v23), + v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), v29_(v29), + v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), v35_(v35), + v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), v41_(v41), + v42_(v42), v43_(v43), v44_(v44), v45_(v45) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray45& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; +}; + +template +class ValueArray46 { + public: + ValueArray46(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46) : v1_(v1), v2_(v2), v3_(v3), + v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), + v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), v46_(v46) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray46& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; +}; + +template +class ValueArray47 { + public: + ValueArray47(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47) : v1_(v1), v2_(v2), + v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), v10_(v10), + v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), v16_(v16), + v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), v22_(v22), + v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), v28_(v28), + v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), v34_(v34), + v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), v40_(v40), + v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), v46_(v46), + v47_(v47) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, + v47_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray47& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; +}; + +template +class ValueArray48 { + public: + ValueArray48(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48) : v1_(v1), + v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), v8_(v8), v9_(v9), + v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), v15_(v15), + v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), v21_(v21), + v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), v27_(v27), + v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), v33_(v33), + v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), v39_(v39), + v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), v45_(v45), + v46_(v46), v47_(v47), v48_(v48) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray48& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; +}; + +template +class ValueArray49 { + public: + ValueArray49(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48, + T49 v49) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), + v45_(v45), v46_(v46), v47_(v47), v48_(v48), v49_(v49) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_, v49_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const ValueArray49& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; + const T49 v49_; +}; + +template +class ValueArray50 { + public: + ValueArray50(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, T48 v48, T49 v49, + T50 v50) : v1_(v1), v2_(v2), v3_(v3), v4_(v4), v5_(v5), v6_(v6), v7_(v7), + v8_(v8), v9_(v9), v10_(v10), v11_(v11), v12_(v12), v13_(v13), v14_(v14), + v15_(v15), v16_(v16), v17_(v17), v18_(v18), v19_(v19), v20_(v20), + v21_(v21), v22_(v22), v23_(v23), v24_(v24), v25_(v25), v26_(v26), + v27_(v27), v28_(v28), v29_(v29), v30_(v30), v31_(v31), v32_(v32), + v33_(v33), v34_(v34), v35_(v35), v36_(v36), v37_(v37), v38_(v38), + v39_(v39), v40_(v40), v41_(v41), v42_(v42), v43_(v43), v44_(v44), + v45_(v45), v46_(v46), v47_(v47), v48_(v48), v49_(v49), v50_(v50) {} + + template + operator ParamGenerator() const { + const T array[] = {v1_, v2_, v3_, v4_, v5_, v6_, v7_, v8_, v9_, v10_, v11_, + v12_, v13_, v14_, v15_, v16_, v17_, v18_, v19_, v20_, v21_, v22_, v23_, + v24_, v25_, v26_, v27_, v28_, v29_, v30_, v31_, v32_, v33_, v34_, v35_, + v36_, v37_, v38_, v39_, v40_, v41_, v42_, v43_, v44_, v45_, v46_, v47_, + v48_, v49_, v50_}; + return ValuesIn(array); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const ValueArray50& other); + + const T1 v1_; + const T2 v2_; + const T3 v3_; + const T4 v4_; + const T5 v5_; + const T6 v6_; + const T7 v7_; + const T8 v8_; + const T9 v9_; + const T10 v10_; + const T11 v11_; + const T12 v12_; + const T13 v13_; + const T14 v14_; + const T15 v15_; + const T16 v16_; + const T17 v17_; + const T18 v18_; + const T19 v19_; + const T20 v20_; + const T21 v21_; + const T22 v22_; + const T23 v23_; + const T24 v24_; + const T25 v25_; + const T26 v26_; + const T27 v27_; + const T28 v28_; + const T29 v29_; + const T30 v30_; + const T31 v31_; + const T32 v32_; + const T33 v33_; + const T34 v34_; + const T35 v35_; + const T36 v36_; + const T37 v37_; + const T38 v38_; + const T39 v39_; + const T40 v40_; + const T41 v41_; + const T42 v42_; + const T43 v43_; + const T44 v44_; + const T45 v45_; + const T46 v46_; + const T47 v47_; + const T48 v48_; + const T49 v49_; + const T50 v50_; +}; + +# if GTEST_HAS_COMBINE +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Generates values from the Cartesian product of values produced +// by the argument generators. 
+//
+template <typename T1, typename T2>
+class CartesianProductGenerator2
+    : public ParamGeneratorInterface< ::std::tr1::tuple<T1, T2> > {
+ public:
+  typedef ::std::tr1::tuple<T1, T2> ParamType;
+
+  CartesianProductGenerator2(const ParamGenerator<T1>& g1,
+      const ParamGenerator<T2>& g2)
+      : g1_(g1), g2_(g2) {}
+  virtual ~CartesianProductGenerator2() {}
+
+  virtual ParamIteratorInterface<ParamType>* Begin() const {
+    return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin());
+  }
+  virtual ParamIteratorInterface<ParamType>* End() const {
+    return new Iterator(this, g1_, g1_.end(), g2_, g2_.end());
+  }
+
+ private:
+  class Iterator : public ParamIteratorInterface<ParamType> {
+   public:
+    Iterator(const ParamGeneratorInterface<ParamType>* base,
+      const ParamGenerator<T1>& g1,
+      const typename ParamGenerator<T1>::iterator& current1,
+      const ParamGenerator<T2>& g2,
+      const typename ParamGenerator<T2>::iterator& current2)
+        : base_(base),
+          begin1_(g1.begin()), end1_(g1.end()), current1_(current1),
+          begin2_(g2.begin()), end2_(g2.end()), current2_(current2) {
+      ComputeCurrentValue();
+    }
+    virtual ~Iterator() {}
+
+    virtual const ParamGeneratorInterface<ParamType>* BaseGenerator() const {
+      return base_;
+    }
+    // Advance should not be called on beyond-of-range iterators
+    // so no component iterators must be beyond end of range, either.
+    virtual void Advance() {
+      assert(!AtEnd());
+      ++current2_;
+      if (current2_ == end2_) {
+        current2_ = begin2_;
+        ++current1_;
+      }
+      ComputeCurrentValue();
+    }
+    virtual ParamIteratorInterface<ParamType>* Clone() const {
+      return new Iterator(*this);
+    }
+    virtual const ParamType* Current() const { return &current_value_; }
+    virtual bool Equals(const ParamIteratorInterface<ParamType>& other) const {
+      // Having the same base generator guarantees that the other
+      // iterator is of the same type and we can downcast.
+      GTEST_CHECK_(BaseGenerator() == other.BaseGenerator())
+          << "The program attempted to compare iterators "
+          << "from different generators." << std::endl;
+      const Iterator* typed_other =
+          CheckedDowncastToActualType<const Iterator>(&other);
+      // We must report iterators equal if they both point beyond their
+      // respective ranges. That can happen in a variety of fashions,
+      // so we have to consult AtEnd().
+      return (AtEnd() && typed_other->AtEnd()) ||
+         (
+          current1_ == typed_other->current1_ &&
+          current2_ == typed_other->current2_);
+    }
+
+   private:
+    Iterator(const Iterator& other)
+        : base_(other.base_),
+        begin1_(other.begin1_),
+        end1_(other.end1_),
+        current1_(other.current1_),
+        begin2_(other.begin2_),
+        end2_(other.end2_),
+        current2_(other.current2_) {
+      ComputeCurrentValue();
+    }
+
+    void ComputeCurrentValue() {
+      if (!AtEnd())
+        current_value_ = ParamType(*current1_, *current2_);
+    }
+    bool AtEnd() const {
+      // We must report iterator past the end of the range when either of the
+      // component iterators has reached the end of its range.
+      return
+          current1_ == end1_ ||
+          current2_ == end2_;
+    }
+
+    // No implementation - assignment is unsupported.
+    void operator=(const Iterator& other);
+
+    const ParamGeneratorInterface<ParamType>* const base_;
+    // begin[i]_ and end[i]_ define the i-th range that Iterator traverses.
+    // current[i]_ is the actual traversing iterator.
+    const typename ParamGenerator<T1>::iterator begin1_;
+    const typename ParamGenerator<T1>::iterator end1_;
+    typename ParamGenerator<T1>::iterator current1_;
+    const typename ParamGenerator<T2>::iterator begin2_;
+    const typename ParamGenerator<T2>::iterator end2_;
+    typename ParamGenerator<T2>::iterator current2_;
+    ParamType current_value_;
+  };  // class CartesianProductGenerator2::Iterator
+
+  // No implementation - assignment is unsupported.
+  void operator=(const CartesianProductGenerator2& other);
+
+  const ParamGenerator<T1> g1_;
+  const ParamGenerator<T2> g2_;
+};  // class CartesianProductGenerator2
+
+
+template <typename T1, typename T2, typename T3>
+class CartesianProductGenerator3
+    : public ParamGeneratorInterface< ::std::tr1::tuple<T1, T2, T3> > {
+ public:
+  typedef ::std::tr1::tuple<T1, T2, T3> ParamType;
+
+  CartesianProductGenerator3(const ParamGenerator<T1>& g1,
+      const ParamGenerator<T2>& g2, const ParamGenerator<T3>& g3)
+      : g1_(g1), g2_(g2), g3_(g3) {}
+  virtual ~CartesianProductGenerator3() {}
+
+  virtual ParamIteratorInterface<ParamType>* Begin() const {
+    return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_,
+        g3_.begin());
+  }
+  virtual ParamIteratorInterface<ParamType>* End() const {
+    return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end());
+  }
+
+ private:
+  class Iterator : public ParamIteratorInterface<ParamType> {
+   public:
+    Iterator(const ParamGeneratorInterface<ParamType>* base,
+      const ParamGenerator<T1>& g1,
+      const typename ParamGenerator<T1>::iterator& current1,
+      const ParamGenerator<T2>& g2,
+      const typename ParamGenerator<T2>::iterator& current2,
+      const ParamGenerator<T3>& g3,
+      const typename ParamGenerator<T3>::iterator& current3)
+        : base_(base),
+          begin1_(g1.begin()), end1_(g1.end()), current1_(current1),
+          begin2_(g2.begin()), end2_(g2.end()), current2_(current2),
+          begin3_(g3.begin()), end3_(g3.end()), current3_(current3) {
+      ComputeCurrentValue();
+    }
+    virtual ~Iterator() {}
+
+    virtual const ParamGeneratorInterface<ParamType>* BaseGenerator() const {
+      return base_;
+    }
+    // Advance should not be called on beyond-of-range iterators
+    // so no component iterators must be beyond end of range, either.
+    virtual void Advance() {
+      assert(!AtEnd());
+      ++current3_;
+      if (current3_ == end3_) {
+        current3_ = begin3_;
+        ++current2_;
+      }
+      if (current2_ == end2_) {
+        current2_ = begin2_;
+        ++current1_;
+      }
+      ComputeCurrentValue();
+    }
+    virtual ParamIteratorInterface<ParamType>* Clone() const {
+      return new Iterator(*this);
+    }
+    virtual const ParamType* Current() const { return &current_value_; }
+    virtual bool Equals(const ParamIteratorInterface<ParamType>& other) const {
+      // Having the same base generator guarantees that the other
+      // iterator is of the same type and we can downcast.
+      GTEST_CHECK_(BaseGenerator() == other.BaseGenerator())
+          << "The program attempted to compare iterators "
+          << "from different generators." << std::endl;
+      const Iterator* typed_other =
+          CheckedDowncastToActualType<const Iterator>(&other);
+      // We must report iterators equal if they both point beyond their
+      // respective ranges. That can happen in a variety of fashions,
+      // so we have to consult AtEnd().
+      return (AtEnd() && typed_other->AtEnd()) ||
+         (
+          current1_ == typed_other->current1_ &&
+          current2_ == typed_other->current2_ &&
+          current3_ == typed_other->current3_);
+    }
+
+   private:
+    Iterator(const Iterator& other)
+        : base_(other.base_),
+        begin1_(other.begin1_),
+        end1_(other.end1_),
+        current1_(other.current1_),
+        begin2_(other.begin2_),
+        end2_(other.end2_),
+        current2_(other.current2_),
+        begin3_(other.begin3_),
+        end3_(other.end3_),
+        current3_(other.current3_) {
+      ComputeCurrentValue();
+    }
+
+    void ComputeCurrentValue() {
+      if (!AtEnd())
+        current_value_ = ParamType(*current1_, *current2_, *current3_);
+    }
+    bool AtEnd() const {
+      // We must report iterator past the end of the range when either of the
+      // component iterators has reached the end of its range.
+      return
+          current1_ == end1_ ||
+          current2_ == end2_ ||
+          current3_ == end3_;
+    }
+
+    // No implementation - assignment is unsupported.
+ void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + ParamType current_value_; + }; // class CartesianProductGenerator3::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator3& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; +}; // class CartesianProductGenerator3 + + +template +class CartesianProductGenerator4 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator4(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4) {} + virtual ~CartesianProductGenerator4() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current4_; + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." 
<< std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + ParamType current_value_; + }; // class CartesianProductGenerator4::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator4& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; +}; // class CartesianProductGenerator4 + + +template +class CartesianProductGenerator5 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator5(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5) {} + virtual ~CartesianProductGenerator5() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current5_; + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + ParamType current_value_; + }; // class CartesianProductGenerator5::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator5& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; +}; // class CartesianProductGenerator5 + + +template +class CartesianProductGenerator6 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator6(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6) {} + virtual ~CartesianProductGenerator6() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current6_; + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + ParamType current_value_; + }; // class CartesianProductGenerator6::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator6& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; +}; // class CartesianProductGenerator6 + + +template +class CartesianProductGenerator7 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator7(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7) {} + virtual ~CartesianProductGenerator7() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current7_; + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. 
+ GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + ParamType current_value_; + }; // class CartesianProductGenerator7::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator7& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; +}; // class CartesianProductGenerator7 + + +template +class CartesianProductGenerator8 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator8(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), + g8_(g8) {} + virtual ~CartesianProductGenerator8() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const typename ParamGenerator::iterator& current8) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. 
+ virtual void Advance() { + assert(!AtEnd()); + ++current8_; + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. 
+ const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + ParamType current_value_; + }; // class CartesianProductGenerator8::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator8& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; + const ParamGenerator g8_; +}; // class CartesianProductGenerator8 + + +template +class CartesianProductGenerator9 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator9(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8, const ParamGenerator& g9) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9) {} + virtual ~CartesianProductGenerator9() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin(), g9_, g9_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end(), g9_, g9_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const typename ParamGenerator::iterator& current8, + const ParamGenerator& g9, + const typename ParamGenerator::iterator& current9) + : base_(base), + 
begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8), + begin9_(g9.begin()), end9_(g9.end()), current9_(current9) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. + virtual void Advance() { + assert(!AtEnd()); + ++current9_; + if (current9_ == end9_) { + current9_ = begin9_; + ++current8_; + } + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). 
+ return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_ && + current9_ == typed_other->current9_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_), + begin9_(other.begin9_), + end9_(other.end9_), + current9_(other.current9_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_, + *current9_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_ || + current9_ == end9_; + } + + // No implementation - assignment is unsupported. + void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + const typename ParamGenerator::iterator begin9_; + const typename ParamGenerator::iterator end9_; + typename ParamGenerator::iterator current9_; + ParamType current_value_; + }; // class CartesianProductGenerator9::Iterator + + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductGenerator9& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; + const ParamGenerator g8_; + const ParamGenerator g9_; +}; // class CartesianProductGenerator9 + + +template +class CartesianProductGenerator10 + : public ParamGeneratorInterface< ::std::tr1::tuple > { + public: + typedef ::std::tr1::tuple ParamType; + + CartesianProductGenerator10(const ParamGenerator& g1, + const ParamGenerator& g2, const ParamGenerator& g3, + const ParamGenerator& g4, const ParamGenerator& g5, + const ParamGenerator& g6, const ParamGenerator& g7, + const ParamGenerator& g8, const ParamGenerator& g9, + const ParamGenerator& g10) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9), g10_(g10) {} + virtual ~CartesianProductGenerator10() {} + + virtual ParamIteratorInterface* Begin() const { + return new Iterator(this, g1_, g1_.begin(), g2_, g2_.begin(), g3_, + g3_.begin(), g4_, g4_.begin(), g5_, g5_.begin(), g6_, g6_.begin(), g7_, + g7_.begin(), g8_, g8_.begin(), g9_, g9_.begin(), g10_, g10_.begin()); + } + virtual ParamIteratorInterface* End() const { + return new Iterator(this, g1_, g1_.end(), g2_, g2_.end(), g3_, g3_.end(), + g4_, g4_.end(), g5_, g5_.end(), g6_, g6_.end(), g7_, g7_.end(), g8_, + g8_.end(), g9_, g9_.end(), g10_, g10_.end()); + } + + private: + class Iterator : public ParamIteratorInterface { + public: + Iterator(const ParamGeneratorInterface* base, + const ParamGenerator& g1, + const typename ParamGenerator::iterator& current1, + const ParamGenerator& g2, + const typename ParamGenerator::iterator& current2, + const ParamGenerator& g3, + const typename ParamGenerator::iterator& current3, + const ParamGenerator& g4, + const typename ParamGenerator::iterator& current4, + const ParamGenerator& g5, + const typename ParamGenerator::iterator& current5, + const ParamGenerator& g6, + const typename ParamGenerator::iterator& current6, + const ParamGenerator& g7, + const typename ParamGenerator::iterator& current7, + const ParamGenerator& g8, + const typename ParamGenerator::iterator& current8, + const ParamGenerator& g9, + const typename ParamGenerator::iterator& current9, + const ParamGenerator& g10, + const typename ParamGenerator::iterator& current10) + : base_(base), + begin1_(g1.begin()), end1_(g1.end()), current1_(current1), + begin2_(g2.begin()), end2_(g2.end()), current2_(current2), + begin3_(g3.begin()), end3_(g3.end()), current3_(current3), + begin4_(g4.begin()), end4_(g4.end()), current4_(current4), + begin5_(g5.begin()), end5_(g5.end()), current5_(current5), + begin6_(g6.begin()), end6_(g6.end()), current6_(current6), + begin7_(g7.begin()), end7_(g7.end()), current7_(current7), + begin8_(g8.begin()), end8_(g8.end()), current8_(current8), + begin9_(g9.begin()), end9_(g9.end()), current9_(current9), + begin10_(g10.begin()), end10_(g10.end()), current10_(current10) { + ComputeCurrentValue(); + } + virtual ~Iterator() {} + + virtual const ParamGeneratorInterface* BaseGenerator() const { + return base_; + } + // Advance should not be called on beyond-of-range iterators + // so no component iterators must be beyond end of range, either. 
+ virtual void Advance() { + assert(!AtEnd()); + ++current10_; + if (current10_ == end10_) { + current10_ = begin10_; + ++current9_; + } + if (current9_ == end9_) { + current9_ = begin9_; + ++current8_; + } + if (current8_ == end8_) { + current8_ = begin8_; + ++current7_; + } + if (current7_ == end7_) { + current7_ = begin7_; + ++current6_; + } + if (current6_ == end6_) { + current6_ = begin6_; + ++current5_; + } + if (current5_ == end5_) { + current5_ = begin5_; + ++current4_; + } + if (current4_ == end4_) { + current4_ = begin4_; + ++current3_; + } + if (current3_ == end3_) { + current3_ = begin3_; + ++current2_; + } + if (current2_ == end2_) { + current2_ = begin2_; + ++current1_; + } + ComputeCurrentValue(); + } + virtual ParamIteratorInterface* Clone() const { + return new Iterator(*this); + } + virtual const ParamType* Current() const { return ¤t_value_; } + virtual bool Equals(const ParamIteratorInterface& other) const { + // Having the same base generator guarantees that the other + // iterator is of the same type and we can downcast. + GTEST_CHECK_(BaseGenerator() == other.BaseGenerator()) + << "The program attempted to compare iterators " + << "from different generators." << std::endl; + const Iterator* typed_other = + CheckedDowncastToActualType(&other); + // We must report iterators equal if they both point beyond their + // respective ranges. That can happen in a variety of fashions, + // so we have to consult AtEnd(). + return (AtEnd() && typed_other->AtEnd()) || + ( + current1_ == typed_other->current1_ && + current2_ == typed_other->current2_ && + current3_ == typed_other->current3_ && + current4_ == typed_other->current4_ && + current5_ == typed_other->current5_ && + current6_ == typed_other->current6_ && + current7_ == typed_other->current7_ && + current8_ == typed_other->current8_ && + current9_ == typed_other->current9_ && + current10_ == typed_other->current10_); + } + + private: + Iterator(const Iterator& other) + : base_(other.base_), + begin1_(other.begin1_), + end1_(other.end1_), + current1_(other.current1_), + begin2_(other.begin2_), + end2_(other.end2_), + current2_(other.current2_), + begin3_(other.begin3_), + end3_(other.end3_), + current3_(other.current3_), + begin4_(other.begin4_), + end4_(other.end4_), + current4_(other.current4_), + begin5_(other.begin5_), + end5_(other.end5_), + current5_(other.current5_), + begin6_(other.begin6_), + end6_(other.end6_), + current6_(other.current6_), + begin7_(other.begin7_), + end7_(other.end7_), + current7_(other.current7_), + begin8_(other.begin8_), + end8_(other.end8_), + current8_(other.current8_), + begin9_(other.begin9_), + end9_(other.end9_), + current9_(other.current9_), + begin10_(other.begin10_), + end10_(other.end10_), + current10_(other.current10_) { + ComputeCurrentValue(); + } + + void ComputeCurrentValue() { + if (!AtEnd()) + current_value_ = ParamType(*current1_, *current2_, *current3_, + *current4_, *current5_, *current6_, *current7_, *current8_, + *current9_, *current10_); + } + bool AtEnd() const { + // We must report iterator past the end of the range when either of the + // component iterators has reached the end of its range. + return + current1_ == end1_ || + current2_ == end2_ || + current3_ == end3_ || + current4_ == end4_ || + current5_ == end5_ || + current6_ == end6_ || + current7_ == end7_ || + current8_ == end8_ || + current9_ == end9_ || + current10_ == end10_; + } + + // No implementation - assignment is unsupported. 
+ void operator=(const Iterator& other); + + const ParamGeneratorInterface* const base_; + // begin[i]_ and end[i]_ define the i-th range that Iterator traverses. + // current[i]_ is the actual traversing iterator. + const typename ParamGenerator::iterator begin1_; + const typename ParamGenerator::iterator end1_; + typename ParamGenerator::iterator current1_; + const typename ParamGenerator::iterator begin2_; + const typename ParamGenerator::iterator end2_; + typename ParamGenerator::iterator current2_; + const typename ParamGenerator::iterator begin3_; + const typename ParamGenerator::iterator end3_; + typename ParamGenerator::iterator current3_; + const typename ParamGenerator::iterator begin4_; + const typename ParamGenerator::iterator end4_; + typename ParamGenerator::iterator current4_; + const typename ParamGenerator::iterator begin5_; + const typename ParamGenerator::iterator end5_; + typename ParamGenerator::iterator current5_; + const typename ParamGenerator::iterator begin6_; + const typename ParamGenerator::iterator end6_; + typename ParamGenerator::iterator current6_; + const typename ParamGenerator::iterator begin7_; + const typename ParamGenerator::iterator end7_; + typename ParamGenerator::iterator current7_; + const typename ParamGenerator::iterator begin8_; + const typename ParamGenerator::iterator end8_; + typename ParamGenerator::iterator current8_; + const typename ParamGenerator::iterator begin9_; + const typename ParamGenerator::iterator end9_; + typename ParamGenerator::iterator current9_; + const typename ParamGenerator::iterator begin10_; + const typename ParamGenerator::iterator end10_; + typename ParamGenerator::iterator current10_; + ParamType current_value_; + }; // class CartesianProductGenerator10::Iterator + + // No implementation - assignment is unsupported. + void operator=(const CartesianProductGenerator10& other); + + const ParamGenerator g1_; + const ParamGenerator g2_; + const ParamGenerator g3_; + const ParamGenerator g4_; + const ParamGenerator g5_; + const ParamGenerator g6_; + const ParamGenerator g7_; + const ParamGenerator g8_; + const ParamGenerator g9_; + const ParamGenerator g10_; +}; // class CartesianProductGenerator10 + + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Helper classes providing Combine() with polymorphic features. They allow +// casting CartesianProductGeneratorN to ParamGenerator if T is +// convertible to U. +// +template +class CartesianProductHolder2 { + public: +CartesianProductHolder2(const Generator1& g1, const Generator2& g2) + : g1_(g1), g2_(g2) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator2( + static_cast >(g1_), + static_cast >(g2_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder2& other); + + const Generator1 g1_; + const Generator2 g2_; +}; // class CartesianProductHolder2 + +template +class CartesianProductHolder3 { + public: +CartesianProductHolder3(const Generator1& g1, const Generator2& g2, + const Generator3& g3) + : g1_(g1), g2_(g2), g3_(g3) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator3( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_))); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductHolder3& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; +}; // class CartesianProductHolder3 + +template +class CartesianProductHolder4 { + public: +CartesianProductHolder4(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator4( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder4& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; +}; // class CartesianProductHolder4 + +template +class CartesianProductHolder5 { + public: +CartesianProductHolder5(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator5( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder5& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; +}; // class CartesianProductHolder5 + +template +class CartesianProductHolder6 { + public: +CartesianProductHolder6(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator6( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder6& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; +}; // class CartesianProductHolder6 + +template +class CartesianProductHolder7 { + public: +CartesianProductHolder7(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator7( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_), + static_cast >(g7_))); + } + + private: + // No implementation - assignment is unsupported. 
+ void operator=(const CartesianProductHolder7& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; +}; // class CartesianProductHolder7 + +template +class CartesianProductHolder8 { + public: +CartesianProductHolder8(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), + g8_(g8) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator8( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_), + static_cast >(g7_), + static_cast >(g8_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder8& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; + const Generator8 g8_; +}; // class CartesianProductHolder8 + +template +class CartesianProductHolder9 { + public: +CartesianProductHolder9(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8, + const Generator9& g9) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator9( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_), + static_cast >(g7_), + static_cast >(g8_), + static_cast >(g9_))); + } + + private: + // No implementation - assignment is unsupported. + void operator=(const CartesianProductHolder9& other); + + const Generator1 g1_; + const Generator2 g2_; + const Generator3 g3_; + const Generator4 g4_; + const Generator5 g5_; + const Generator6 g6_; + const Generator7 g7_; + const Generator8 g8_; + const Generator9 g9_; +}; // class CartesianProductHolder9 + +template +class CartesianProductHolder10 { + public: +CartesianProductHolder10(const Generator1& g1, const Generator2& g2, + const Generator3& g3, const Generator4& g4, const Generator5& g5, + const Generator6& g6, const Generator7& g7, const Generator8& g8, + const Generator9& g9, const Generator10& g10) + : g1_(g1), g2_(g2), g3_(g3), g4_(g4), g5_(g5), g6_(g6), g7_(g7), g8_(g8), + g9_(g9), g10_(g10) {} + template + operator ParamGenerator< ::std::tr1::tuple >() const { + return ParamGenerator< ::std::tr1::tuple >( + new CartesianProductGenerator10( + static_cast >(g1_), + static_cast >(g2_), + static_cast >(g3_), + static_cast >(g4_), + static_cast >(g5_), + static_cast >(g6_), + static_cast >(g7_), + static_cast >(g8_), + static_cast >(g9_), + static_cast >(g10_))); + } + + private: + // No implementation - assignment is unsupported. 
+  void operator=(const CartesianProductHolder10& other);
+
+  const Generator1 g1_;
+  const Generator2 g2_;
+  const Generator3 g3_;
+  const Generator4 g4_;
+  const Generator5 g5_;
+  const Generator6 g6_;
+  const Generator7 g7_;
+  const Generator8 g8_;
+  const Generator9 g9_;
+  const Generator10 g10_;
+};  // class CartesianProductHolder10
+
+# endif  // GTEST_HAS_COMBINE
+
+}  // namespace internal
+}  // namespace testing
+
+#endif  // GTEST_HAS_PARAM_TEST
+
+#endif  // GTEST_INCLUDE_GTEST_INTERNAL_GTEST_PARAM_UTIL_GENERATED_H_
+
+#if GTEST_HAS_PARAM_TEST
+
+namespace testing {
+
+// Functions producing parameter generators.
+//
+// Google Test uses these generators to produce parameters for value-
+// parameterized tests. When a parameterized test case is instantiated
+// with a particular generator, Google Test creates and runs tests
+// for each element in the sequence produced by the generator.
+//
+// In the following sample, tests from test case FooTest are instantiated
+// each three times with parameter values 3, 5, and 8:
+//
+// class FooTest : public TestWithParam<int> { ... };
+//
+// TEST_P(FooTest, TestThis) {
+// }
+// TEST_P(FooTest, TestThat) {
+// }
+// INSTANTIATE_TEST_CASE_P(TestSequence, FooTest, Values(3, 5, 8));
+//
+
+// Range() returns generators providing sequences of values in a range.
+//
+// Synopsis:
+// Range(start, end)
+//   - returns a generator producing a sequence of values {start, start+1,
+//     start+2, ..., }.
+// Range(start, end, step)
+//   - returns a generator producing a sequence of values {start, start+step,
+//     start+step+step, ..., }.
+// Notes:
+//   * The generated sequences never include end. For example, Range(1, 5)
+//     returns a generator producing a sequence {1, 2, 3, 4}. Range(1, 9, 2)
+//     returns a generator producing {1, 3, 5, 7}.
+//   * start and end must have the same type. That type may be any integral or
+//     floating-point type or a user defined type satisfying these conditions:
+//     * It must be assignable (have operator=() defined).
+//     * It must have operator+() (operator+(int-compatible type) for
+//       two-operand version).
+//     * It must have operator<() defined.
+//   Elements in the resulting sequences will also have that type.
+//   * Condition start < end must be satisfied in order for resulting sequences
+//     to contain any elements.
+//
+template <typename T, typename IncrementT>
+internal::ParamGenerator<T> Range(T start, T end, IncrementT step) {
+  return internal::ParamGenerator<T>(
+      new internal::RangeGenerator<T, IncrementT>(start, end, step));
+}
+
+template <typename T>
+internal::ParamGenerator<T> Range(T start, T end) {
+  return Range(start, end, 1);
+}
+
+// ValuesIn() function allows generation of tests with parameters coming from
+// a container.
+//
+// Synopsis:
+// ValuesIn(const T (&array)[N])
+//   - returns a generator producing sequences with elements from
+//     a C-style array.
+// ValuesIn(const Container& container)
+//   - returns a generator producing sequences with elements from
+//     an STL-style container.
+// ValuesIn(Iterator begin, Iterator end)
+//   - returns a generator producing sequences with elements from
+//     a range [begin, end) defined by a pair of STL-style iterators. These
+//     iterators can also be plain C pointers.
+//
+// Please note that ValuesIn copies the values from the containers
+// passed in and keeps them to generate tests in RUN_ALL_TESTS().
+//
+// Examples:
+//
+// This instantiates tests from test case StringTest
+// each with C-string values of "foo", "bar", and "baz":
+//
+// const char* strings[] = {"foo", "bar", "baz"};
+// INSTANTIATE_TEST_CASE_P(StringSequence, StringTest, ValuesIn(strings));
+//
+// This instantiates tests from test case StlStringTest
+// each with STL strings with values "a" and "b":
+//
+// ::std::vector< ::std::string> GetParameterStrings() {
+//   ::std::vector< ::std::string> v;
+//   v.push_back("a");
+//   v.push_back("b");
+//   return v;
+// }
+//
+// INSTANTIATE_TEST_CASE_P(CharSequence,
+//                         StlStringTest,
+//                         ValuesIn(GetParameterStrings()));
+//
+//
+// This will also instantiate tests from CharTest
+// each with parameter values 'a' and 'b':
+//
+// ::std::list<char> GetParameterChars() {
+//   ::std::list<char> list;
+//   list.push_back('a');
+//   list.push_back('b');
+//   return list;
+// }
+// ::std::list<char> l = GetParameterChars();
+// INSTANTIATE_TEST_CASE_P(CharSequence2,
+//                         CharTest,
+//                         ValuesIn(l.begin(), l.end()));
+//
+template <typename ForwardIterator>
+internal::ParamGenerator<
+  typename ::testing::internal::IteratorTraits<ForwardIterator>::value_type>
+ValuesIn(ForwardIterator begin, ForwardIterator end) {
+  typedef typename ::testing::internal::IteratorTraits<ForwardIterator>
+      ::value_type ParamType;
+  return internal::ParamGenerator<ParamType>(
+      new internal::ValuesInIteratorRangeGenerator<ParamType>(begin, end));
+}
+
+template <typename T, size_t N>
+internal::ParamGenerator<T> ValuesIn(const T (&array)[N]) {
+  return ValuesIn(array, array + N);
+}
+
+template <class Container>
+internal::ParamGenerator<typename Container::value_type> ValuesIn(
+    const Container& container) {
+  return ValuesIn(container.begin(), container.end());
+}
+
+// Values() allows generating tests from explicitly specified list of
+// parameters.
+//
+// Synopsis:
+// Values(T v1, T v2, ..., T vN)
+//   - returns a generator producing sequences with elements v1, v2, ..., vN.
+//
+// For example, this instantiates tests from test case BarTest each
+// with values "one", "two", and "three":
+//
+// INSTANTIATE_TEST_CASE_P(NumSequence, BarTest, Values("one", "two", "three"));
+//
+// This instantiates tests from test case BazTest each with values 1, 2, 3.5.
+// The exact type of values will depend on the type of parameter in BazTest.
+//
+// INSTANTIATE_TEST_CASE_P(FloatingNumbers, BazTest, Values(1, 2, 3.5));
+//
+// Currently, Values() supports from 1 to 50 parameters.
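+//
+// A minimal end-to-end sketch, purely illustrative (the SizeTest fixture and
+// the parameter values below are hypothetical and not defined anywhere in
+// this header): Range(), Values(), and ValuesIn() can each drive the same
+// test case through separate INSTANTIATE_TEST_CASE_P calls, provided every
+// call uses a distinct prefix.
+//
+// class SizeTest : public ::testing::TestWithParam<int> {};
+//
+// TEST_P(SizeTest, IsNonNegative) {
+//   EXPECT_GE(GetParam(), 0);
+// }
+//
+// const int kSizes[] = {16, 32, 64};
+// INSTANTIATE_TEST_CASE_P(SmallRange, SizeTest, ::testing::Range(0, 4));
+// INSTANTIATE_TEST_CASE_P(Picked, SizeTest, ::testing::Values(10, 20));
+// INSTANTIATE_TEST_CASE_P(FromArray, SizeTest, ::testing::ValuesIn(kSizes));
+//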
+// +template +internal::ValueArray1 Values(T1 v1) { + return internal::ValueArray1(v1); +} + +template +internal::ValueArray2 Values(T1 v1, T2 v2) { + return internal::ValueArray2(v1, v2); +} + +template +internal::ValueArray3 Values(T1 v1, T2 v2, T3 v3) { + return internal::ValueArray3(v1, v2, v3); +} + +template +internal::ValueArray4 Values(T1 v1, T2 v2, T3 v3, T4 v4) { + return internal::ValueArray4(v1, v2, v3, v4); +} + +template +internal::ValueArray5 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5) { + return internal::ValueArray5(v1, v2, v3, v4, v5); +} + +template +internal::ValueArray6 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6) { + return internal::ValueArray6(v1, v2, v3, v4, v5, v6); +} + +template +internal::ValueArray7 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7) { + return internal::ValueArray7(v1, v2, v3, v4, v5, + v6, v7); +} + +template +internal::ValueArray8 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8) { + return internal::ValueArray8(v1, v2, v3, v4, + v5, v6, v7, v8); +} + +template +internal::ValueArray9 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9) { + return internal::ValueArray9(v1, v2, v3, + v4, v5, v6, v7, v8, v9); +} + +template +internal::ValueArray10 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10) { + return internal::ValueArray10(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10); +} + +template +internal::ValueArray11 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11) { + return internal::ValueArray11(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11); +} + +template +internal::ValueArray12 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12) { + return internal::ValueArray12(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12); +} + +template +internal::ValueArray13 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13) { + return internal::ValueArray13(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13); +} + +template +internal::ValueArray14 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14) { + return internal::ValueArray14(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14); +} + +template +internal::ValueArray15 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15) { + return internal::ValueArray15(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15); +} + +template +internal::ValueArray16 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16) { + return internal::ValueArray16(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16); +} + +template +internal::ValueArray17 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17) { + return internal::ValueArray17(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17); +} + +template +internal::ValueArray18 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18) { + return internal::ValueArray18(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18); +} + +template +internal::ValueArray19 
Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19) { + return internal::ValueArray19(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19); +} + +template +internal::ValueArray20 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20) { + return internal::ValueArray20(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20); +} + +template +internal::ValueArray21 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21) { + return internal::ValueArray21(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21); +} + +template +internal::ValueArray22 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22) { + return internal::ValueArray22(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22); +} + +template +internal::ValueArray23 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23) { + return internal::ValueArray23(v1, v2, v3, + v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23); +} + +template +internal::ValueArray24 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24) { + return internal::ValueArray24(v1, v2, + v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, + v19, v20, v21, v22, v23, v24); +} + +template +internal::ValueArray25 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, + T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, + T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25) { + return internal::ValueArray25(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, + v18, v19, v20, v21, v22, v23, v24, v25); +} + +template +internal::ValueArray26 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26) { + return internal::ValueArray26(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26); +} + +template +internal::ValueArray27 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27) { + return internal::ValueArray27(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, + v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27); +} + +template +internal::ValueArray28 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 
v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28) { + return internal::ValueArray28(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, + v28); +} + +template +internal::ValueArray29 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29) { + return internal::ValueArray29(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, + v27, v28, v29); +} + +template +internal::ValueArray30 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30) { + return internal::ValueArray30(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, + v26, v27, v28, v29, v30); +} + +template +internal::ValueArray31 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31) { + return internal::ValueArray31(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, + v25, v26, v27, v28, v29, v30, v31); +} + +template +internal::ValueArray32 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32) { + return internal::ValueArray32(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32); +} + +template +internal::ValueArray33 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33) { + return internal::ValueArray33(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33); +} + +template +internal::ValueArray34 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, + T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, + T31 v31, T32 v32, T33 v33, T34 v34) { + return internal::ValueArray34(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, + v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34); +} + +template +internal::ValueArray35 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, 
T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35) { + return internal::ValueArray35(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, + v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35); +} + +template +internal::ValueArray36 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36) { + return internal::ValueArray36(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, v35, v36); +} + +template +internal::ValueArray37 Values(T1 v1, T2 v2, T3 v3, + T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37) { + return internal::ValueArray37(v1, v2, v3, + v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, v35, v36, v37); +} + +template +internal::ValueArray38 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37, T38 v38) { + return internal::ValueArray38(v1, v2, + v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, + v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, + v33, v34, v35, v36, v37, v38); +} + +template +internal::ValueArray39 Values(T1 v1, T2 v2, + T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, + T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, + T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, + T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, + T37 v37, T38 v38, T39 v39) { + return internal::ValueArray39(v1, + v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, + v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, + v32, v33, v34, v35, v36, v37, v38, v39); +} + +template +internal::ValueArray40 Values(T1 v1, + T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, + T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, + T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, + T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, + T36 v36, T37 v37, T38 v38, T39 v39, T40 v40) { + return internal::ValueArray40(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, + v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, + v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40); +} + +template +internal::ValueArray41 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + 
T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41) { + return internal::ValueArray41(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, + v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, v28, + v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41); +} + +template +internal::ValueArray42 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42) { + return internal::ValueArray42(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, v13, + v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, v27, + v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, v41, + v42); +} + +template +internal::ValueArray43 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43) { + return internal::ValueArray43(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, v12, + v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, v26, + v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, v40, + v41, v42, v43); +} + +template +internal::ValueArray44 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, + T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, T17 v17, + T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, T25 v25, + T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, T33 v33, + T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, T41 v41, + T42 v42, T43 v43, T44 v44) { + return internal::ValueArray44(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, v11, + v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, v25, + v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, v39, + v40, v41, v42, v43, v44); +} + +template +internal::ValueArray45 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, T8 v8, + T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, T16 v16, + T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, T24 v24, + T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, T32 v32, + T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, T40 v40, + T41 v41, T42 v42, T43 v43, T44 v44, T45 v45) { + return internal::ValueArray45(v1, v2, v3, v4, v5, v6, v7, v8, v9, v10, + v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, v24, + v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, v38, + v39, v40, v41, v42, v43, v44, v45); +} + +template +internal::ValueArray46 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 
v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46) { + return internal::ValueArray46(v1, v2, v3, v4, v5, v6, v7, v8, v9, + v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, + v38, v39, v40, v41, v42, v43, v44, v45, v46); +} + +template +internal::ValueArray47 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, T7 v7, + T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47) { + return internal::ValueArray47(v1, v2, v3, v4, v5, v6, v7, v8, + v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, v23, + v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, v37, + v38, v39, v40, v41, v42, v43, v44, v45, v46, v47); +} + +template +internal::ValueArray48 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, T6 v6, + T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, T15 v15, + T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, T23 v23, + T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, T31 v31, + T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, T39 v39, + T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, T47 v47, + T48 v48) { + return internal::ValueArray48(v1, v2, v3, v4, v5, v6, v7, + v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, v22, + v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, v36, + v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48); +} + +template +internal::ValueArray49 Values(T1 v1, T2 v2, T3 v3, T4 v4, T5 v5, + T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, T14 v14, + T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, T22 v22, + T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, T30 v30, + T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, T38 v38, + T39 v39, T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, T46 v46, + T47 v47, T48 v48, T49 v49) { + return internal::ValueArray49(v1, v2, v3, v4, v5, v6, + v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, v20, v21, + v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, v34, v35, + v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, v48, v49); +} + +template +internal::ValueArray50 Values(T1 v1, T2 v2, T3 v3, T4 v4, + T5 v5, T6 v6, T7 v7, T8 v8, T9 v9, T10 v10, T11 v11, T12 v12, T13 v13, + T14 v14, T15 v15, T16 v16, T17 v17, T18 v18, T19 v19, T20 v20, T21 v21, + T22 v22, T23 v23, T24 v24, T25 v25, T26 v26, T27 v27, T28 v28, T29 v29, + T30 v30, T31 v31, T32 v32, T33 v33, T34 v34, T35 v35, T36 v36, T37 v37, + T38 v38, T39 v39, T40 v40, T41 v41, T42 v42, T43 v43, T44 v44, T45 v45, + T46 v46, T47 v47, T48 v48, T49 v49, T50 v50) { + return internal::ValueArray50(v1, v2, v3, v4, + v5, v6, v7, v8, v9, v10, v11, v12, v13, v14, v15, v16, v17, v18, v19, + v20, v21, v22, v23, v24, v25, v26, v27, v28, v29, v30, v31, v32, v33, + v34, v35, v36, v37, v38, v39, v40, v41, v42, v43, v44, v45, v46, v47, + v48, v49, v50); +} + +// Bool() allows generating tests with parameters in a set of (false, true). +// +// Synopsis: +// Bool() +// - returns a generator producing sequences with elements {false, true}. 
+//
+// It is useful when testing code that depends on Boolean flags. Combinations
+// of multiple flags can be tested when several Bool()'s are combined using
+// Combine() function.
+//
+// In the following example all tests in the test case FlagDependentTest
+// will be instantiated twice with parameters false and true.
+//
+// class FlagDependentTest : public testing::TestWithParam<bool> {
+//   virtual void SetUp() {
+//     external_flag = GetParam();
+//   }
+// }
+// INSTANTIATE_TEST_CASE_P(BoolSequence, FlagDependentTest, Bool());
+//
+inline internal::ParamGenerator<bool> Bool() {
+  return Values(false, true);
+}
+
+# if GTEST_HAS_COMBINE
+// Combine() allows the user to combine two or more sequences to produce
+// values of a Cartesian product of those sequences' elements.
+//
+// Synopsis:
+// Combine(gen1, gen2, ..., genN)
+//   - returns a generator producing sequences with elements coming from
+//     the Cartesian product of elements from the sequences generated by
+//     gen1, gen2, ..., genN. The sequence elements will have a type of
+//     tuple<T1, T2, ..., TN> where T1, T2, ..., TN are the types
+//     of elements from sequences produced by gen1, gen2, ..., genN.
+//
+// Combine can have up to 10 arguments. This number is currently limited
+// by the maximum number of elements in the tuple implementation used by Google
+// Test.
+//
+// Example:
+//
+// This will instantiate tests in test case AnimalTest each one with
+// the parameter values tuple("cat", BLACK), tuple("cat", WHITE),
+// tuple("dog", BLACK), and tuple("dog", WHITE):
+//
+// enum Color { BLACK, GRAY, WHITE };
+// class AnimalTest
+//     : public testing::TestWithParam<tuple<const char*, Color> > {...};
+//
+// TEST_P(AnimalTest, AnimalLooksNice) {...}
+//
+// INSTANTIATE_TEST_CASE_P(AnimalVariations, AnimalTest,
+//                         Combine(Values("cat", "dog"),
+//                                 Values(BLACK, WHITE)));
+//
+// This will instantiate tests in FlagDependentTest with all variations of two
+// Boolean flags:
+//
+// class FlagDependentTest
+//     : public testing::TestWithParam<tuple<bool, bool> > {
+//   virtual void SetUp() {
+//     // Assigns external_flag_1 and external_flag_2 values from the tuple.
+//     tie(external_flag_1, external_flag_2) = GetParam();
+//   }
+// };
+//
+// TEST_P(FlagDependentTest, TestFeature1) {
+//   // Test your code using external_flag_1 and external_flag_2 here.
+// }
+// INSTANTIATE_TEST_CASE_P(TwoBoolSequence, FlagDependentTest,
+//                         Combine(Bool(), Bool()));
+//
+template <typename Generator1, typename Generator2>
+internal::CartesianProductHolder2<Generator1, Generator2> Combine(
+    const Generator1& g1, const Generator2& g2) {
+  return internal::CartesianProductHolder2<Generator1, Generator2>(
+      g1, g2);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3>
+internal::CartesianProductHolder3<Generator1, Generator2, Generator3> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3) {
+  return internal::CartesianProductHolder3<Generator1, Generator2, Generator3>(
+      g1, g2, g3);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4>
+internal::CartesianProductHolder4<Generator1, Generator2, Generator3,
+    Generator4> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4) {
+  return internal::CartesianProductHolder4<Generator1, Generator2, Generator3,
+      Generator4>(
+      g1, g2, g3, g4);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5>
+internal::CartesianProductHolder5<Generator1, Generator2, Generator3,
+    Generator4, Generator5> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5) {
+  return internal::CartesianProductHolder5<Generator1, Generator2, Generator3,
+      Generator4, Generator5>(
+      g1, g2, g3, g4, g5);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5, typename Generator6>
+internal::CartesianProductHolder6<Generator1, Generator2, Generator3,
+    Generator4, Generator5, Generator6> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5, const Generator6& g6) {
+  return internal::CartesianProductHolder6<Generator1, Generator2, Generator3,
+      Generator4, Generator5, Generator6>(
+      g1, g2, g3, g4, g5, g6);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5, typename Generator6,
+    typename Generator7>
+internal::CartesianProductHolder7<Generator1, Generator2, Generator3,
+    Generator4, Generator5, Generator6, Generator7> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5, const Generator6& g6,
+    const Generator7& g7) {
+  return internal::CartesianProductHolder7<Generator1, Generator2, Generator3,
+      Generator4, Generator5, Generator6, Generator7>(
+      g1, g2, g3, g4, g5, g6, g7);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5, typename Generator6,
+    typename Generator7, typename Generator8>
+internal::CartesianProductHolder8<Generator1, Generator2, Generator3,
+    Generator4, Generator5, Generator6, Generator7, Generator8> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5, const Generator6& g6,
+    const Generator7& g7, const Generator8& g8) {
+  return internal::CartesianProductHolder8<Generator1, Generator2, Generator3,
+      Generator4, Generator5, Generator6, Generator7, Generator8>(
+      g1, g2, g3, g4, g5, g6, g7, g8);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5, typename Generator6,
+    typename Generator7, typename Generator8, typename Generator9>
+internal::CartesianProductHolder9<Generator1, Generator2, Generator3,
+    Generator4, Generator5, Generator6, Generator7, Generator8,
+    Generator9> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5, const Generator6& g6,
+    const Generator7& g7, const Generator8& g8, const Generator9& g9) {
+  return internal::CartesianProductHolder9<Generator1, Generator2, Generator3,
+      Generator4, Generator5, Generator6, Generator7, Generator8, Generator9>(
+      g1, g2, g3, g4, g5, g6, g7, g8, g9);
+}
+
+template <typename Generator1, typename Generator2, typename Generator3,
+    typename Generator4, typename Generator5, typename Generator6,
+    typename Generator7, typename Generator8, typename Generator9,
+    typename Generator10>
+internal::CartesianProductHolder10<Generator1, Generator2, Generator3,
+    Generator4, Generator5, Generator6, Generator7, Generator8, Generator9,
+    Generator10> Combine(
+    const Generator1& g1, const Generator2& g2, const Generator3& g3,
+    const Generator4& g4, const Generator5& g5, const Generator6& g6,
+    const Generator7& g7, const Generator8& g8, const Generator9& g9,
+    const Generator10& g10) {
+  return internal::CartesianProductHolder10<Generator1, Generator2,
+      Generator3, Generator4, Generator5, Generator6, Generator7, Generator8,
+      Generator9, Generator10>(
+      g1, g2, g3, g4, g5, g6, g7, g8, g9, g10);
+}
+# endif  // GTEST_HAS_COMBINE
+
+
+
+# define TEST_P(test_case_name, test_name) \
+  class GTEST_TEST_CLASS_NAME_(test_case_name, test_name) \
+      : public test_case_name { \
+   public: \
+    GTEST_TEST_CLASS_NAME_(test_case_name, test_name)() {} \
+    virtual void TestBody(); \
+   private: \
+    static int AddToRegistry() { \
+      ::testing::UnitTest::GetInstance()->parameterized_test_registry(). \
+          GetTestCasePatternHolder<test_case_name>(\
+              #test_case_name, __FILE__, __LINE__)->AddTestPattern(\
+                  #test_case_name, \
+                  #test_name, \
+                  new ::testing::internal::TestMetaFactory< \
+                      GTEST_TEST_CLASS_NAME_(test_case_name, test_name)>()); \
+      return 0; \
+    } \
+    static int gtest_registering_dummy_; \
+    GTEST_DISALLOW_COPY_AND_ASSIGN_(\
+        GTEST_TEST_CLASS_NAME_(test_case_name, test_name)); \
+  }; \
+  int GTEST_TEST_CLASS_NAME_(test_case_name, \
+                             test_name)::gtest_registering_dummy_ = \
+      GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::AddToRegistry(); \
+  void GTEST_TEST_CLASS_NAME_(test_case_name, test_name)::TestBody()
+
+# define INSTANTIATE_TEST_CASE_P(prefix, test_case_name, generator) \
+  ::testing::internal::ParamGenerator<test_case_name::ParamType> \
+      gtest_##prefix##test_case_name##_EvalGenerator_() { return generator; } \
+  int gtest_##prefix##test_case_name##_dummy_ = \
+      ::testing::UnitTest::GetInstance()->parameterized_test_registry(). \
+          GetTestCasePatternHolder<test_case_name>(\
+              #test_case_name, __FILE__, __LINE__)->AddTestCaseInstantiation(\
+                  #prefix, \
+                  &gtest_##prefix##test_case_name##_EvalGenerator_, \
+                  __FILE__, __LINE__)
+
+}  // namespace testing
+
+#endif  // GTEST_HAS_PARAM_TEST
+
+#endif  // GTEST_INCLUDE_GTEST_GTEST_PARAM_TEST_H_
+// Copyright 2006, Google Inc.
+// All rights reserved.
+//
+// Redistribution and use in source and binary forms, with or without
+// modification, are permitted provided that the following conditions are
+// met:
+//
+//     * Redistributions of source code must retain the above copyright
+// notice, this list of conditions and the following disclaimer.
+//     * Redistributions in binary form must reproduce the above
+// copyright notice, this list of conditions and the following disclaimer
+// in the documentation and/or other materials provided with the
+// distribution.
+//     * Neither the name of Google Inc. nor the names of its
+// contributors may be used to endorse or promote products derived from
+// this software without specific prior written permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
+// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
+// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
+// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
+// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
+// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
+// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
+// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
+// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+//
+// Author: wan@google.com (Zhanyong Wan)
+//
+// Google C++ Testing Framework definitions useful in production code.
+
+#ifndef GTEST_INCLUDE_GTEST_GTEST_PROD_H_
+#define GTEST_INCLUDE_GTEST_GTEST_PROD_H_
+
+// When you need to test the private or protected members of a class,
+// use the FRIEND_TEST macro to declare your tests as friends of the
+// class. For example:
+//
+// class MyClass {
+//  private:
+//   void MyMethod();
+//   FRIEND_TEST(MyClassTest, MyMethod);
+// };
+//
+// class MyClassTest : public testing::Test {
+//   // ...
+// };
+//
+// TEST_F(MyClassTest, MyMethod) {
+//   // Can call MyClass::MyMethod() here.
+// } + +#define FRIEND_TEST(test_case_name, test_name)\ +friend class test_case_name##_##test_name##_Test + +#endif // GTEST_INCLUDE_GTEST_GTEST_PROD_H_ +// Copyright 2008, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: mheule@google.com (Markus Heule) +// + +#ifndef GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ +#define GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ + +#include +#include + +namespace testing { + +// A copyable object representing the result of a test part (i.e. an +// assertion or an explicit FAIL(), ADD_FAILURE(), or SUCCESS()). +// +// Don't inherit from TestPartResult as its destructor is not virtual. +class GTEST_API_ TestPartResult { + public: + // The possible outcomes of a test part (i.e. an assertion or an + // explicit SUCCEED(), FAIL(), or ADD_FAILURE()). + enum Type { + kSuccess, // Succeeded. + kNonFatalFailure, // Failed but the test can continue. + kFatalFailure // Failed and the test should be terminated. + }; + + // C'tor. TestPartResult does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestPartResult object. + TestPartResult(Type a_type, + const char* a_file_name, + int a_line_number, + const char* a_message) + : type_(a_type), + file_name_(a_file_name), + line_number_(a_line_number), + summary_(ExtractSummary(a_message)), + message_(a_message) { + } + + // Gets the outcome of the test part. + Type type() const { return type_; } + + // Gets the name of the source file where the test part took place, or + // NULL if it's unknown. + const char* file_name() const { return file_name_.c_str(); } + + // Gets the line in the source file where the test part took place, + // or -1 if it's unknown. + int line_number() const { return line_number_; } + + // Gets the summary of the failure message. + const char* summary() const { return summary_.c_str(); } + + // Gets the message associated with the test part. + const char* message() const { return message_.c_str(); } + + // Returns true iff the test part passed. 
+ bool passed() const { return type_ == kSuccess; } + + // Returns true iff the test part failed. + bool failed() const { return type_ != kSuccess; } + + // Returns true iff the test part non-fatally failed. + bool nonfatally_failed() const { return type_ == kNonFatalFailure; } + + // Returns true iff the test part fatally failed. + bool fatally_failed() const { return type_ == kFatalFailure; } + private: + Type type_; + + // Gets the summary of the failure message by omitting the stack + // trace in it. + static internal::String ExtractSummary(const char* message); + + // The name of the source file where the test part took place, or + // NULL if the source file is unknown. + internal::String file_name_; + // The line in the source file where the test part took place, or -1 + // if the line number is unknown. + int line_number_; + internal::String summary_; // The test failure summary. + internal::String message_; // The test failure message. +}; + +// Prints a TestPartResult object. +std::ostream& operator<<(std::ostream& os, const TestPartResult& result); + +// An array of TestPartResult objects. +// +// Don't inherit from TestPartResultArray as its destructor is not +// virtual. +class GTEST_API_ TestPartResultArray { + public: + TestPartResultArray() {} + + // Appends the given TestPartResult to the array. + void Append(const TestPartResult& result); + + // Returns the TestPartResult at the given index (0-based). + const TestPartResult& GetTestPartResult(int index) const; + + // Returns the number of TestPartResult objects in the array. + int size() const; + + private: + std::vector array_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestPartResultArray); +}; + +// This interface knows how to report a test part result. +class TestPartResultReporterInterface { + public: + virtual ~TestPartResultReporterInterface() {} + + virtual void ReportTestPartResult(const TestPartResult& result) = 0; +}; + +namespace internal { + +// This helper class is used by {ASSERT|EXPECT}_NO_FATAL_FAILURE to check if a +// statement generates new fatal failures. To do so it registers itself as the +// current test part result reporter. Besides checking if fatal failures were +// reported, it only delegates the reporting to the former result reporter. +// The original result reporter is restored in the destructor. +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +class GTEST_API_ HasNewFatalFailureHelper + : public TestPartResultReporterInterface { + public: + HasNewFatalFailureHelper(); + virtual ~HasNewFatalFailureHelper(); + virtual void ReportTestPartResult(const TestPartResult& result); + bool has_new_fatal_failure() const { return has_new_fatal_failure_; } + private: + bool has_new_fatal_failure_; + TestPartResultReporterInterface* original_reporter_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(HasNewFatalFailureHelper); +}; + +} // namespace internal + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_TEST_PART_H_ +// Copyright 2008 Google Inc. +// All Rights Reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. 
+// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// +// Author: wan@google.com (Zhanyong Wan) + +#ifndef GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ +#define GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ + +// This header implements typed tests and type-parameterized tests. + +// Typed (aka type-driven) tests repeat the same test for types in a +// list. You must know which types you want to test with when writing +// typed tests. Here's how you do it: + +#if 0 + +// First, define a fixture class template. It should be parameterized +// by a type. Remember to derive it from testing::Test. +template +class FooTest : public testing::Test { + public: + ... + typedef std::list List; + static T shared_; + T value_; +}; + +// Next, associate a list of types with the test case, which will be +// repeated for each type in the list. The typedef is necessary for +// the macro to parse correctly. +typedef testing::Types MyTypes; +TYPED_TEST_CASE(FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// TYPED_TEST_CASE(FooTest, int); + +// Then, use TYPED_TEST() instead of TEST_F() to define as many typed +// tests for this test case as you want. +TYPED_TEST(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + // Since we are inside a derived class template, C++ requires use to + // visit the members of FooTest via 'this'. + TypeParam n = this->value_; + + // To visit static members of the fixture, add the TestFixture:: + // prefix. + n += TestFixture::shared_; + + // To refer to typedefs in the fixture, add the "typename + // TestFixture::" prefix. + typename TestFixture::List values; + values.push_back(n); + ... +} + +TYPED_TEST(FooTest, HasPropertyA) { ... } + +#endif // 0 + +// Type-parameterized tests are abstract test patterns parameterized +// by a type. Compared with typed tests, type-parameterized tests +// allow you to define the test pattern without knowing what the type +// parameters are. The defined pattern can be instantiated with +// different types any number of times, in any number of translation +// units. +// +// If you are designing an interface or concept, you can define a +// suite of type-parameterized tests to verify properties that any +// valid implementation of the interface/concept should have. Then, +// each implementation can easily instantiate the test suite to verify +// that it conforms to the requirements, without having to write +// similar tests repeatedly. Here's an example: + +#if 0 + +// First, define a fixture class template. 
It should be parameterized +// by a type. Remember to derive it from testing::Test. +template +class FooTest : public testing::Test { + ... +}; + +// Next, declare that you will define a type-parameterized test case +// (the _P suffix is for "parameterized" or "pattern", whichever you +// prefer): +TYPED_TEST_CASE_P(FooTest); + +// Then, use TYPED_TEST_P() to define as many type-parameterized tests +// for this type-parameterized test case as you want. +TYPED_TEST_P(FooTest, DoesBlah) { + // Inside a test, refer to TypeParam to get the type parameter. + TypeParam n = 0; + ... +} + +TYPED_TEST_P(FooTest, HasPropertyA) { ... } + +// Now the tricky part: you need to register all test patterns before +// you can instantiate them. The first argument of the macro is the +// test case name; the rest are the names of the tests in this test +// case. +REGISTER_TYPED_TEST_CASE_P(FooTest, + DoesBlah, HasPropertyA); + +// Finally, you are free to instantiate the pattern with the types you +// want. If you put the above code in a header file, you can #include +// it in multiple C++ source files and instantiate it multiple times. +// +// To distinguish different instances of the pattern, the first +// argument to the INSTANTIATE_* macro is a prefix that will be added +// to the actual test case name. Remember to pick unique prefixes for +// different instances. +typedef testing::Types MyTypes; +INSTANTIATE_TYPED_TEST_CASE_P(My, FooTest, MyTypes); + +// If the type list contains only one type, you can write that type +// directly without Types<...>: +// INSTANTIATE_TYPED_TEST_CASE_P(My, FooTest, int); + +#endif // 0 + + +// Implements typed tests. + +#if GTEST_HAS_TYPED_TEST + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the typedef for the type parameters of the +// given test case. +# define GTEST_TYPE_PARAMS_(TestCaseName) gtest_type_params_##TestCaseName##_ + +// The 'Types' template argument below must have spaces around it +// since some compilers may choke on '>>' when passing a template +// instance (e.g. Types) +# define TYPED_TEST_CASE(CaseName, Types) \ + typedef ::testing::internal::TypeList< Types >::type \ + GTEST_TYPE_PARAMS_(CaseName) + +# define TYPED_TEST(CaseName, TestName) \ + template \ + class GTEST_TEST_CLASS_NAME_(CaseName, TestName) \ + : public CaseName { \ + private: \ + typedef CaseName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + bool gtest_##CaseName##_##TestName##_registered_ GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTest< \ + CaseName, \ + ::testing::internal::TemplateSel< \ + GTEST_TEST_CLASS_NAME_(CaseName, TestName)>, \ + GTEST_TYPE_PARAMS_(CaseName)>::Register(\ + "", #CaseName, #TestName, 0); \ + template \ + void GTEST_TEST_CLASS_NAME_(CaseName, TestName)::TestBody() + +#endif // GTEST_HAS_TYPED_TEST + +// Implements type-parameterized tests. + +#if GTEST_HAS_TYPED_TEST_P + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the namespace name that the type-parameterized tests for +// the given type-parameterized test case are defined in. The exact +// name of the namespace is subject to change without notice. +# define GTEST_CASE_NAMESPACE_(TestCaseName) \ + gtest_case_##TestCaseName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// +// Expands to the name of the variable used to remember the names of +// the defined tests in the given test case. 
+# define GTEST_TYPED_TEST_CASE_P_STATE_(TestCaseName) \ + gtest_typed_test_case_p_state_##TestCaseName##_ + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE DIRECTLY. +// +// Expands to the name of the variable used to remember the names of +// the registered tests in the given test case. +# define GTEST_REGISTERED_TEST_NAMES_(TestCaseName) \ + gtest_registered_test_names_##TestCaseName##_ + +// The variables defined in the type-parameterized test macros are +// static as typically these macros are used in a .h file that can be +// #included in multiple translation units linked together. +# define TYPED_TEST_CASE_P(CaseName) \ + static ::testing::internal::TypedTestCasePState \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName) + +# define TYPED_TEST_P(CaseName, TestName) \ + namespace GTEST_CASE_NAMESPACE_(CaseName) { \ + template \ + class TestName : public CaseName { \ + private: \ + typedef CaseName TestFixture; \ + typedef gtest_TypeParam_ TypeParam; \ + virtual void TestBody(); \ + }; \ + static bool gtest_##TestName##_defined_ GTEST_ATTRIBUTE_UNUSED_ = \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName).AddTestName(\ + __FILE__, __LINE__, #CaseName, #TestName); \ + } \ + template \ + void GTEST_CASE_NAMESPACE_(CaseName)::TestName::TestBody() + +# define REGISTER_TYPED_TEST_CASE_P(CaseName, ...) \ + namespace GTEST_CASE_NAMESPACE_(CaseName) { \ + typedef ::testing::internal::Templates<__VA_ARGS__>::type gtest_AllTests_; \ + } \ + static const char* const GTEST_REGISTERED_TEST_NAMES_(CaseName) = \ + GTEST_TYPED_TEST_CASE_P_STATE_(CaseName).VerifyRegisteredTestNames(\ + __FILE__, __LINE__, #__VA_ARGS__) + +// The 'Types' template argument below must have spaces around it +// since some compilers may choke on '>>' when passing a template +// instance (e.g. Types) +# define INSTANTIATE_TYPED_TEST_CASE_P(Prefix, CaseName, Types) \ + bool gtest_##Prefix##_##CaseName GTEST_ATTRIBUTE_UNUSED_ = \ + ::testing::internal::TypeParameterizedTestCase::type>::Register(\ + #Prefix, #CaseName, GTEST_REGISTERED_TEST_NAMES_(CaseName)) + +#endif // GTEST_HAS_TYPED_TEST_P + +#endif // GTEST_INCLUDE_GTEST_GTEST_TYPED_TEST_H_ + +// Depending on the platform, different string classes are available. +// On Linux, in addition to ::std::string, Google also makes use of +// class ::string, which has the same interface as ::std::string, but +// has a different implementation. +// +// The user can define GTEST_HAS_GLOBAL_STRING to 1 to indicate that +// ::string is available AND is a distinct type to ::std::string, or +// define it to 0 to indicate otherwise. +// +// If the user's ::std::string and ::string are the same class due to +// aliasing, he should define GTEST_HAS_GLOBAL_STRING to 0. +// +// If the user doesn't define GTEST_HAS_GLOBAL_STRING, it is defined +// heuristically. + +namespace testing { + +// Declares the flags. + +// This flag temporary enables the disabled tests. +GTEST_DECLARE_bool_(also_run_disabled_tests); + +// This flag brings the debugger on an assertion failure. +GTEST_DECLARE_bool_(break_on_failure); + +// This flag controls whether Google Test catches all test-thrown exceptions +// and logs them as failures. +GTEST_DECLARE_bool_(catch_exceptions); + +// This flag enables using colors in terminal output. Available values are +// "yes" to enable colors, "no" (disable colors), or "auto" (the default) +// to let Google Test decide. +GTEST_DECLARE_string_(color); + +// This flag sets up the filter to select by name using a glob pattern +// the tests to run. 
If the filter is not given all tests are executed. +GTEST_DECLARE_string_(filter); + +// This flag causes the Google Test to list tests. None of the tests listed +// are actually run if the flag is provided. +GTEST_DECLARE_bool_(list_tests); + +// This flag controls whether Google Test emits a detailed XML report to a file +// in addition to its normal textual output. +GTEST_DECLARE_string_(output); + +// This flags control whether Google Test prints the elapsed time for each +// test. +GTEST_DECLARE_bool_(print_time); + +// This flag specifies the random number seed. +GTEST_DECLARE_int32_(random_seed); + +// This flag sets how many times the tests are repeated. The default value +// is 1. If the value is -1 the tests are repeating forever. +GTEST_DECLARE_int32_(repeat); + +// This flag controls whether Google Test includes Google Test internal +// stack frames in failure stack traces. +GTEST_DECLARE_bool_(show_internal_stack_frames); + +// When this flag is specified, tests' order is randomized on every iteration. +GTEST_DECLARE_bool_(shuffle); + +// This flag specifies the maximum number of stack frames to be +// printed in a failure message. +GTEST_DECLARE_int32_(stack_trace_depth); + +// When this flag is specified, a failed assertion will throw an +// exception if exceptions are enabled, or exit the program with a +// non-zero code otherwise. +GTEST_DECLARE_bool_(throw_on_failure); + +// When this flag is set with a "host:port" string, on supported +// platforms test results are streamed to the specified port on +// the specified host machine. +GTEST_DECLARE_string_(stream_result_to); + +// The upper limit for valid stack trace depths. +const int kMaxStackTraceDepth = 100; + +namespace internal { + +class AssertHelper; +class DefaultGlobalTestPartResultReporter; +class ExecDeathTest; +class NoExecDeathTest; +class FinalSuccessChecker; +class GTestFlagSaver; +class TestResultAccessor; +class TestEventListenersAccessor; +class TestEventRepeater; +class WindowsDeathTest; +class UnitTestImpl* GetUnitTestImpl(); +void ReportFailureInUnknownLocation(TestPartResult::Type result_type, + const String& message); + +// Converts a streamable value to a String. A NULL pointer is +// converted to "(null)". When the input value is a ::string, +// ::std::string, ::wstring, or ::std::wstring object, each NUL +// character in it is replaced with "\\0". +// Declared in gtest-internal.h but defined here, so that it has access +// to the definition of the Message class, required by the ARM +// compiler. +template +String StreamableToString(const T& streamable) { + return (Message() << streamable).GetString(); +} + +} // namespace internal + +// The friend relationship of some of these classes is cyclic. +// If we don't forward declare them the compiler might confuse the classes +// in friendship clauses with same named classes on the scope. +class Test; +class TestCase; +class TestInfo; +class UnitTest; + +// A class for indicating whether an assertion was successful. When +// the assertion wasn't successful, the AssertionResult object +// remembers a non-empty message that describes how it failed. +// +// To create an instance of this class, use one of the factory functions +// (AssertionSuccess() and AssertionFailure()). +// +// This class is useful for two purposes: +// 1. Defining predicate functions to be used with Boolean test assertions +// EXPECT_TRUE/EXPECT_FALSE and their ASSERT_ counterparts +// 2. 
Defining predicate-format functions to be +// used with predicate assertions (ASSERT_PRED_FORMAT*, etc). +// +// For example, if you define IsEven predicate: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then the failed expectation EXPECT_TRUE(IsEven(Fib(5))) +// will print the message +// +// Value of: IsEven(Fib(5)) +// Actual: false (5 is odd) +// Expected: true +// +// instead of a more opaque +// +// Value of: IsEven(Fib(5)) +// Actual: false +// Expected: true +// +// in case IsEven is a simple Boolean predicate. +// +// If you expect your predicate to be reused and want to support informative +// messages in EXPECT_FALSE and ASSERT_FALSE (negative assertions show up +// about half as often as positive ones in our tests), supply messages for +// both success and failure cases: +// +// testing::AssertionResult IsEven(int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess() << n << " is even"; +// else +// return testing::AssertionFailure() << n << " is odd"; +// } +// +// Then a statement EXPECT_FALSE(IsEven(Fib(6))) will print +// +// Value of: IsEven(Fib(6)) +// Actual: true (8 is even) +// Expected: false +// +// NB: Predicates that support negative Boolean assertions have reduced +// performance in positive ones so be careful not to use them in tests +// that have lots (tens of thousands) of positive Boolean assertions. +// +// To use this class with EXPECT_PRED_FORMAT assertions such as: +// +// // Verifies that Foo() returns an even number. +// EXPECT_PRED_FORMAT1(IsEven, Foo()); +// +// you need to define: +// +// testing::AssertionResult IsEven(const char* expr, int n) { +// if ((n % 2) == 0) +// return testing::AssertionSuccess(); +// else +// return testing::AssertionFailure() +// << "Expected: " << expr << " is even\n Actual: it's " << n; +// } +// +// If Foo() returns 5, you will see the following message: +// +// Expected: Foo() is even +// Actual: it's 5 +// +class GTEST_API_ AssertionResult { + public: + // Copy constructor. + // Used in EXPECT_TRUE/FALSE(assertion_result). + AssertionResult(const AssertionResult& other); + // Used in the EXPECT_TRUE/FALSE(bool_expression). + explicit AssertionResult(bool success) : success_(success) {} + + // Returns true iff the assertion succeeded. + operator bool() const { return success_; } // NOLINT + + // Returns the assertion's negation. Used with EXPECT/ASSERT_FALSE. + AssertionResult operator!() const; + + // Returns the text streamed into this AssertionResult. Test assertions + // use it when they fail (i.e., the predicate's outcome doesn't match the + // assertion's expectation). When nothing has been streamed into the + // object, returns an empty string. + const char* message() const { + return message_.get() != NULL ? message_->c_str() : ""; + } + // TODO(vladl@google.com): Remove this after making sure no clients use it. + // Deprecated; please use message() instead. + const char* failure_message() const { return message(); } + + // Streams a custom failure message into this object. + template AssertionResult& operator<<(const T& value) { + AppendMessage(Message() << value); + return *this; + } + + // Allows streaming basic output manipulators such as endl or flush into + // this object. 
+ AssertionResult& operator<<( + ::std::ostream& (*basic_manipulator)(::std::ostream& stream)) { + AppendMessage(Message() << basic_manipulator); + return *this; + } + + private: + // Appends the contents of message to message_. + void AppendMessage(const Message& a_message) { + if (message_.get() == NULL) + message_.reset(new ::std::string); + message_->append(a_message.GetString().c_str()); + } + + // Stores result of the assertion predicate. + bool success_; + // Stores the message describing the condition in case the expectation + // construct is not satisfied with the predicate's outcome. + // Referenced via a pointer to avoid taking too much stack frame space + // with test assertions. + internal::scoped_ptr< ::std::string> message_; + + GTEST_DISALLOW_ASSIGN_(AssertionResult); +}; + +// Makes a successful assertion result. +GTEST_API_ AssertionResult AssertionSuccess(); + +// Makes a failed assertion result. +GTEST_API_ AssertionResult AssertionFailure(); + +// Makes a failed assertion result with the given failure message. +// Deprecated; use AssertionFailure() << msg. +GTEST_API_ AssertionResult AssertionFailure(const Message& msg); + +// The abstract class that all tests inherit from. +// +// In Google Test, a unit test program contains one or many TestCases, and +// each TestCase contains one or many Tests. +// +// When you define a test using the TEST macro, you don't need to +// explicitly derive from Test - the TEST macro automatically does +// this for you. +// +// The only time you derive from Test is when defining a test fixture +// to be used a TEST_F. For example: +// +// class FooTest : public testing::Test { +// protected: +// virtual void SetUp() { ... } +// virtual void TearDown() { ... } +// ... +// }; +// +// TEST_F(FooTest, Bar) { ... } +// TEST_F(FooTest, Baz) { ... } +// +// Test is not copyable. +class GTEST_API_ Test { + public: + friend class TestInfo; + + // Defines types for pointers to functions that set up and tear down + // a test case. + typedef internal::SetUpTestCaseFunc SetUpTestCaseFunc; + typedef internal::TearDownTestCaseFunc TearDownTestCaseFunc; + + // The d'tor is virtual as we intend to inherit from Test. + virtual ~Test(); + + // Sets up the stuff shared by all tests in this test case. + // + // Google Test will call Foo::SetUpTestCase() before running the first + // test in test case Foo. Hence a sub-class can define its own + // SetUpTestCase() method to shadow the one defined in the super + // class. + static void SetUpTestCase() {} + + // Tears down the stuff shared by all tests in this test case. + // + // Google Test will call Foo::TearDownTestCase() after running the last + // test in test case Foo. Hence a sub-class can define its own + // TearDownTestCase() method to shadow the one defined in the super + // class. + static void TearDownTestCase() {} + + // Returns true iff the current test has a fatal failure. + static bool HasFatalFailure(); + + // Returns true iff the current test has a non-fatal failure. + static bool HasNonfatalFailure(); + + // Returns true iff the current test has a (either fatal or + // non-fatal) failure. + static bool HasFailure() { return HasFatalFailure() || HasNonfatalFailure(); } + + // Logs a property for the current test. Only the last value for a given + // key is remembered. + // These are public static so they can be called from utility functions + // that are not members of the test fixture. 
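+  //
+  // A minimal usage sketch (the fixture, test name, and property values below
+  // are illustrative only, not part of Google Test itself):
+  //
+  //   TEST_F(WidgetTest, HandlesManyWidgets) {
+  //     ...
+  //     RecordProperty("MaximumWidgets", 12);  // recorded in the XML report
+  //   }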
+ // The arguments are const char* instead strings, as Google Test is used + // on platforms where string doesn't compile. + // + // Note that a driving consideration for these RecordProperty methods + // was to produce xml output suited to the Greenspan charting utility, + // which at present will only chart values that fit in a 32-bit int. It + // is the user's responsibility to restrict their values to 32-bit ints + // if they intend them to be used with Greenspan. + static void RecordProperty(const char* key, const char* value); + static void RecordProperty(const char* key, int value); + + protected: + // Creates a Test object. + Test(); + + // Sets up the test fixture. + virtual void SetUp(); + + // Tears down the test fixture. + virtual void TearDown(); + + private: + // Returns true iff the current test has the same fixture class as + // the first test in the current test case. + static bool HasSameFixtureClass(); + + // Runs the test after the test fixture has been set up. + // + // A sub-class must implement this to define the test logic. + // + // DO NOT OVERRIDE THIS FUNCTION DIRECTLY IN A USER PROGRAM. + // Instead, use the TEST or TEST_F macro. + virtual void TestBody() = 0; + + // Sets up, executes, and tears down the test. + void Run(); + + // Deletes self. We deliberately pick an unusual name for this + // internal method to avoid clashing with names used in user TESTs. + void DeleteSelf_() { delete this; } + + // Uses a GTestFlagSaver to save and restore all Google Test flags. + const internal::GTestFlagSaver* const gtest_flag_saver_; + + // Often a user mis-spells SetUp() as Setup() and spends a long time + // wondering why it is never called by Google Test. The declaration of + // the following method is solely for catching such an error at + // compile time: + // + // - The return type is deliberately chosen to be not void, so it + // will be a conflict if a user declares void Setup() in his test + // fixture. + // + // - This method is private, so it will be another compiler error + // if a user calls it from his test fixture. + // + // DO NOT OVERRIDE THIS FUNCTION. + // + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return NULL; } + + // We disallow copying Tests. + GTEST_DISALLOW_COPY_AND_ASSIGN_(Test); +}; + +typedef internal::TimeInMillis TimeInMillis; + +// A copyable object representing a user specified test property which can be +// output as a key/value string pair. +// +// Don't inherit from TestProperty as its destructor is not virtual. +class TestProperty { + public: + // C'tor. TestProperty does NOT have a default constructor. + // Always use this constructor (with parameters) to create a + // TestProperty object. + TestProperty(const char* a_key, const char* a_value) : + key_(a_key), value_(a_value) { + } + + // Gets the user supplied key. + const char* key() const { + return key_.c_str(); + } + + // Gets the user supplied value. + const char* value() const { + return value_.c_str(); + } + + // Sets a new value, overriding the one supplied in the constructor. + void SetValue(const char* new_value) { + value_ = new_value; + } + + private: + // The key supplied by the user. + internal::String key_; + // The value supplied by the user. + internal::String value_; +}; + +// The result of a single Test. 
This includes a list of +// TestPartResults, a list of TestProperties, a count of how many +// death tests there are in the Test, and how much time it took to run +// the Test. +// +// TestResult is not copyable. +class GTEST_API_ TestResult { + public: + // Creates an empty TestResult. + TestResult(); + + // D'tor. Do not inherit from TestResult. + ~TestResult(); + + // Gets the number of all test parts. This is the sum of the number + // of successful test parts and the number of failed test parts. + int total_part_count() const; + + // Returns the number of the test properties. + int test_property_count() const; + + // Returns true iff the test passed (i.e. no test part failed). + bool Passed() const { return !Failed(); } + + // Returns true iff the test failed. + bool Failed() const; + + // Returns true iff the test fatally failed. + bool HasFatalFailure() const; + + // Returns true iff the test has a non-fatal failure. + bool HasNonfatalFailure() const; + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns the i-th test part result among all the results. i can range + // from 0 to test_property_count() - 1. If i is not in that range, aborts + // the program. + const TestPartResult& GetTestPartResult(int i) const; + + // Returns the i-th test property. i can range from 0 to + // test_property_count() - 1. If i is not in that range, aborts the + // program. + const TestProperty& GetTestProperty(int i) const; + + private: + friend class TestInfo; + friend class UnitTest; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::ExecDeathTest; + friend class internal::TestResultAccessor; + friend class internal::UnitTestImpl; + friend class internal::WindowsDeathTest; + + // Gets the vector of TestPartResults. + const std::vector& test_part_results() const { + return test_part_results_; + } + + // Gets the vector of TestProperties. + const std::vector& test_properties() const { + return test_properties_; + } + + // Sets the elapsed time. + void set_elapsed_time(TimeInMillis elapsed) { elapsed_time_ = elapsed; } + + // Adds a test property to the list. The property is validated and may add + // a non-fatal failure if invalid (e.g., if it conflicts with reserved + // key names). If a property is already recorded for the same key, the + // value will be updated, rather than storing multiple values for the same + // key. + void RecordProperty(const TestProperty& test_property); + + // Adds a failure if the key is a reserved attribute of Google Test + // testcase tags. Returns true if the property is valid. + // TODO(russr): Validate attribute names are legal and human readable. + static bool ValidateTestProperty(const TestProperty& test_property); + + // Adds a test part result to the list. + void AddTestPartResult(const TestPartResult& test_part_result); + + // Returns the death test count. + int death_test_count() const { return death_test_count_; } + + // Increments the death test count, returning the new count. + int increment_death_test_count() { return ++death_test_count_; } + + // Clears the test part results. + void ClearTestPartResults(); + + // Clears the object. + void Clear(); + + // Protects mutable state of the property vector and of owned + // properties, whose values may be updated. 
+ internal::Mutex test_properites_mutex_; + + // The vector of TestPartResults + std::vector test_part_results_; + // The vector of TestProperties + std::vector test_properties_; + // Running count of death tests. + int death_test_count_; + // The elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + + // We disallow copying TestResult. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestResult); +}; // class TestResult + +// A TestInfo object stores the following information about a test: +// +// Test case name +// Test name +// Whether the test should be run +// A function pointer that creates the test object when invoked +// Test result +// +// The constructor of TestInfo registers itself with the UnitTest +// singleton such that the RUN_ALL_TESTS() macro knows which tests to +// run. +class GTEST_API_ TestInfo { + public: + // Destructs a TestInfo object. This function is not virtual, so + // don't inherit from TestInfo. + ~TestInfo(); + + // Returns the test case name. + const char* test_case_name() const { return test_case_name_.c_str(); } + + // Returns the test name. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a typed + // or a type-parameterized test. + const char* type_param() const { + if (type_param_.get() != NULL) + return type_param_->c_str(); + return NULL; + } + + // Returns the text representation of the value parameter, or NULL if this + // is not a value-parameterized test. + const char* value_param() const { + if (value_param_.get() != NULL) + return value_param_->c_str(); + return NULL; + } + + // Returns true if this test should run, that is if the test is not disabled + // (or it is disabled but the also_run_disabled_tests flag has been specified) + // and its full name matches the user-specified filter. + // + // Google Test allows the user to filter the tests by their full names. + // The full name of a test Bar in test case Foo is defined as + // "Foo.Bar". Only the tests that match the filter will run. + // + // A filter is a colon-separated list of glob (not regex) patterns, + // optionally followed by a '-' and a colon-separated list of + // negative patterns (tests to exclude). A test is run if it + // matches one of the positive patterns and does not match any of + // the negative patterns. + // + // For example, *A*:Foo.* is a filter that matches any string that + // contains the character 'A' or starts with "Foo.". + bool should_run() const { return should_run_; } + + // Returns the result of the test. + const TestResult* result() const { return &result_; } + + private: + +#if GTEST_HAS_DEATH_TEST + friend class internal::DefaultDeathTestFactory; +#endif // GTEST_HAS_DEATH_TEST + friend class Test; + friend class TestCase; + friend class internal::UnitTestImpl; + friend TestInfo* internal::MakeAndRegisterTestInfo( + const char* test_case_name, const char* name, + const char* type_param, + const char* value_param, + internal::TypeId fixture_class_id, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc, + internal::TestFactoryBase* factory); + + // Constructs a TestInfo object. The newly constructed instance assumes + // ownership of the factory object. + TestInfo(const char* test_case_name, const char* name, + const char* a_type_param, + const char* a_value_param, + internal::TypeId fixture_class_id, + internal::TestFactoryBase* factory); + + // Increments the number of death tests encountered in this test so + // far. 
+ int increment_death_test_count() { + return result_.increment_death_test_count(); + } + + // Creates the test object, runs it, records its result, and then + // deletes it. + void Run(); + + static void ClearTestResult(TestInfo* test_info) { + test_info->result_.Clear(); + } + + // These fields are immutable properties of the test. + const std::string test_case_name_; // Test case name + const std::string name_; // Test name + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const internal::scoped_ptr type_param_; + // Text representation of the value parameter, or NULL if this is not a + // value-parameterized test. + const internal::scoped_ptr value_param_; + const internal::TypeId fixture_class_id_; // ID of the test fixture class + bool should_run_; // True iff this test should run + bool is_disabled_; // True iff this test is disabled + bool matches_filter_; // True if this test matches the + // user-specified filter. + internal::TestFactoryBase* const factory_; // The factory that creates + // the test object + + // This field is mutable and needs to be reset before running the + // test for the second time. + TestResult result_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestInfo); +}; + +// A test case, which consists of a vector of TestInfos. +// +// TestCase is not copyable. +class GTEST_API_ TestCase { + public: + // Creates a TestCase with the given name. + // + // TestCase does NOT have a default constructor. Always use this + // constructor to create a TestCase object. + // + // Arguments: + // + // name: name of the test case + // a_type_param: the name of the test's type parameter, or NULL if + // this is not a type-parameterized test. + // set_up_tc: pointer to the function that sets up the test case + // tear_down_tc: pointer to the function that tears down the test case + TestCase(const char* name, const char* a_type_param, + Test::SetUpTestCaseFunc set_up_tc, + Test::TearDownTestCaseFunc tear_down_tc); + + // Destructor of TestCase. + virtual ~TestCase(); + + // Gets the name of the TestCase. + const char* name() const { return name_.c_str(); } + + // Returns the name of the parameter type, or NULL if this is not a + // type-parameterized test case. + const char* type_param() const { + if (type_param_.get() != NULL) + return type_param_->c_str(); + return NULL; + } + + // Returns true if any test in this test case should run. + bool should_run() const { return should_run_; } + + // Gets the number of successful tests in this test case. + int successful_test_count() const; + + // Gets the number of failed tests in this test case. + int failed_test_count() const; + + // Gets the number of disabled tests in this test case. + int disabled_test_count() const; + + // Get the number of tests in this test case that should run. + int test_to_run_count() const; + + // Gets the number of all tests in this test case. + int total_test_count() const; + + // Returns true iff the test case passed. + bool Passed() const { return !Failed(); } + + // Returns true iff the test case failed. + bool Failed() const { return failed_test_count() > 0; } + + // Returns the elapsed time, in milliseconds. + TimeInMillis elapsed_time() const { return elapsed_time_; } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. 
+ const TestInfo* GetTestInfo(int i) const; + + private: + friend class Test; + friend class internal::UnitTestImpl; + + // Gets the (mutable) vector of TestInfos in this TestCase. + std::vector& test_info_list() { return test_info_list_; } + + // Gets the (immutable) vector of TestInfos in this TestCase. + const std::vector& test_info_list() const { + return test_info_list_; + } + + // Returns the i-th test among all the tests. i can range from 0 to + // total_test_count() - 1. If i is not in that range, returns NULL. + TestInfo* GetMutableTestInfo(int i); + + // Sets the should_run member. + void set_should_run(bool should) { should_run_ = should; } + + // Adds a TestInfo to this test case. Will delete the TestInfo upon + // destruction of the TestCase object. + void AddTestInfo(TestInfo * test_info); + + // Clears the results of all tests in this test case. + void ClearResult(); + + // Clears the results of all tests in the given test case. + static void ClearTestCaseResult(TestCase* test_case) { + test_case->ClearResult(); + } + + // Runs every test in this TestCase. + void Run(); + + // Runs SetUpTestCase() for this TestCase. This wrapper is needed + // for catching exceptions thrown from SetUpTestCase(). + void RunSetUpTestCase() { (*set_up_tc_)(); } + + // Runs TearDownTestCase() for this TestCase. This wrapper is + // needed for catching exceptions thrown from TearDownTestCase(). + void RunTearDownTestCase() { (*tear_down_tc_)(); } + + // Returns true iff test passed. + static bool TestPassed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Passed(); + } + + // Returns true iff test failed. + static bool TestFailed(const TestInfo* test_info) { + return test_info->should_run() && test_info->result()->Failed(); + } + + // Returns true iff test is disabled. + static bool TestDisabled(const TestInfo* test_info) { + return test_info->is_disabled_; + } + + // Returns true if the given test should run. + static bool ShouldRunTest(const TestInfo* test_info) { + return test_info->should_run(); + } + + // Shuffles the tests in this test case. + void ShuffleTests(internal::Random* random); + + // Restores the test order to before the first shuffle. + void UnshuffleTests(); + + // Name of the test case. + internal::String name_; + // Name of the parameter type, or NULL if this is not a typed or a + // type-parameterized test. + const internal::scoped_ptr type_param_; + // The vector of TestInfos in their original order. It owns the + // elements in the vector. + std::vector test_info_list_; + // Provides a level of indirection for the test list to allow easy + // shuffling and restoring the test order. The i-th element in this + // vector is the index of the i-th test in the shuffled test list. + std::vector test_indices_; + // Pointer to the function that sets up the test case. + Test::SetUpTestCaseFunc set_up_tc_; + // Pointer to the function that tears down the test case. + Test::TearDownTestCaseFunc tear_down_tc_; + // True iff any test in this test case should run. + bool should_run_; + // Elapsed time, in milliseconds. + TimeInMillis elapsed_time_; + + // We disallow copying TestCases. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestCase); +}; + +// An Environment object is capable of setting up and tearing down an +// environment. The user should subclass this to define his own +// environment(s). 
+// +// An Environment object does the set-up and tear-down in virtual +// methods SetUp() and TearDown() instead of the constructor and the +// destructor, as: +// +// 1. You cannot safely throw from a destructor. This is a problem +// as in some cases Google Test is used where exceptions are enabled, and +// we may want to implement ASSERT_* using exceptions where they are +// available. +// 2. You cannot use ASSERT_* directly in a constructor or +// destructor. +class Environment { + public: + // The d'tor is virtual as we need to subclass Environment. + virtual ~Environment() {} + + // Override this to define how to set up the environment. + virtual void SetUp() {} + + // Override this to define how to tear down the environment. + virtual void TearDown() {} + private: + // If you see an error about overriding the following function or + // about it being private, you have mis-spelled SetUp() as Setup(). + struct Setup_should_be_spelled_SetUp {}; + virtual Setup_should_be_spelled_SetUp* Setup() { return NULL; } +}; + +// The interface for tracing execution of tests. The methods are organized in +// the order the corresponding events are fired. +class TestEventListener { + public: + virtual ~TestEventListener() {} + + // Fired before any test activity starts. + virtual void OnTestProgramStart(const UnitTest& unit_test) = 0; + + // Fired before each iteration of tests starts. There may be more than + // one iteration if GTEST_FLAG(repeat) is set. iteration is the iteration + // index, starting from 0. + virtual void OnTestIterationStart(const UnitTest& unit_test, + int iteration) = 0; + + // Fired before environment set-up for each iteration of tests starts. + virtual void OnEnvironmentsSetUpStart(const UnitTest& unit_test) = 0; + + // Fired after environment set-up for each iteration of tests ends. + virtual void OnEnvironmentsSetUpEnd(const UnitTest& unit_test) = 0; + + // Fired before the test case starts. + virtual void OnTestCaseStart(const TestCase& test_case) = 0; + + // Fired before the test starts. + virtual void OnTestStart(const TestInfo& test_info) = 0; + + // Fired after a failed assertion or a SUCCEED() invocation. + virtual void OnTestPartResult(const TestPartResult& test_part_result) = 0; + + // Fired after the test ends. + virtual void OnTestEnd(const TestInfo& test_info) = 0; + + // Fired after the test case ends. + virtual void OnTestCaseEnd(const TestCase& test_case) = 0; + + // Fired before environment tear-down for each iteration of tests starts. + virtual void OnEnvironmentsTearDownStart(const UnitTest& unit_test) = 0; + + // Fired after environment tear-down for each iteration of tests ends. + virtual void OnEnvironmentsTearDownEnd(const UnitTest& unit_test) = 0; + + // Fired after each iteration of tests finishes. + virtual void OnTestIterationEnd(const UnitTest& unit_test, + int iteration) = 0; + + // Fired after all test activities have ended. + virtual void OnTestProgramEnd(const UnitTest& unit_test) = 0; +}; + +// The convenience class for users who need to override just one or two +// methods and are not concerned that a possible change to a signature of +// the methods they override will not be caught during the build. For +// comments about each method please see the definition of TestEventListener +// above. 
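+
+// For example, a bare-bones listener that only reports test boundaries could
+// look like the following sketch (TersePrinter and its printf output are
+// illustrative, not part of this header):
+//
+//   class TersePrinter : public ::testing::EmptyTestEventListener {
+//    private:
+//     virtual void OnTestStart(const ::testing::TestInfo& test_info) {
+//       printf("*** Test %s.%s starting.\n",
+//              test_info.test_case_name(), test_info.name());
+//     }
+//     virtual void OnTestEnd(const ::testing::TestInfo& test_info) {
+//       printf("*** Test %s.%s ending.\n",
+//              test_info.test_case_name(), test_info.name());
+//     }
+//   };
+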
+class EmptyTestEventListener : public TestEventListener { + public: + virtual void OnTestProgramStart(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationStart(const UnitTest& /*unit_test*/, + int /*iteration*/) {} + virtual void OnEnvironmentsSetUpStart(const UnitTest& /*unit_test*/) {} + virtual void OnEnvironmentsSetUpEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestCaseStart(const TestCase& /*test_case*/) {} + virtual void OnTestStart(const TestInfo& /*test_info*/) {} + virtual void OnTestPartResult(const TestPartResult& /*test_part_result*/) {} + virtual void OnTestEnd(const TestInfo& /*test_info*/) {} + virtual void OnTestCaseEnd(const TestCase& /*test_case*/) {} + virtual void OnEnvironmentsTearDownStart(const UnitTest& /*unit_test*/) {} + virtual void OnEnvironmentsTearDownEnd(const UnitTest& /*unit_test*/) {} + virtual void OnTestIterationEnd(const UnitTest& /*unit_test*/, + int /*iteration*/) {} + virtual void OnTestProgramEnd(const UnitTest& /*unit_test*/) {} +}; + +// TestEventListeners lets users add listeners to track events in Google Test. +class GTEST_API_ TestEventListeners { + public: + TestEventListeners(); + ~TestEventListeners(); + + // Appends an event listener to the end of the list. Google Test assumes + // the ownership of the listener (i.e. it will delete the listener when + // the test program finishes). + void Append(TestEventListener* listener); + + // Removes the given event listener from the list and returns it. It then + // becomes the caller's responsibility to delete the listener. Returns + // NULL if the listener is not found in the list. + TestEventListener* Release(TestEventListener* listener); + + // Returns the standard listener responsible for the default console + // output. Can be removed from the listeners list to shut down default + // console output. Note that removing this object from the listener list + // with Release transfers its ownership to the caller and makes this + // function return NULL the next time. + TestEventListener* default_result_printer() const { + return default_result_printer_; + } + + // Returns the standard listener responsible for the default XML output + // controlled by the --gtest_output=xml flag. Can be removed from the + // listeners list by users who want to shut down the default XML output + // controlled by this flag and substitute it with custom one. Note that + // removing this object from the listener list with Release transfers its + // ownership to the caller and makes this function return NULL the next + // time. + TestEventListener* default_xml_generator() const { + return default_xml_generator_; + } + + private: + friend class TestCase; + friend class TestInfo; + friend class internal::DefaultGlobalTestPartResultReporter; + friend class internal::NoExecDeathTest; + friend class internal::TestEventListenersAccessor; + friend class internal::UnitTestImpl; + + // Returns repeater that broadcasts the TestEventListener events to all + // subscribers. + TestEventListener* repeater(); + + // Sets the default_result_printer attribute to the provided listener. + // The listener is also added to the listener list and previous + // default_result_printer is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. + void SetDefaultResultPrinter(TestEventListener* listener); + + // Sets the default_xml_generator attribute to the provided listener. 
The + // listener is also added to the listener list and previous + // default_xml_generator is removed from it and deleted. The listener can + // also be NULL in which case it will not be added to the list. Does + // nothing if the previous and the current listener objects are the same. + void SetDefaultXmlGenerator(TestEventListener* listener); + + // Controls whether events will be forwarded by the repeater to the + // listeners in the list. + bool EventForwardingEnabled() const; + void SuppressEventForwarding(); + + // The actual list of listeners. + internal::TestEventRepeater* repeater_; + // Listener responsible for the standard result output. + TestEventListener* default_result_printer_; + // Listener responsible for the creation of the XML output file. + TestEventListener* default_xml_generator_; + + // We disallow copying TestEventListeners. + GTEST_DISALLOW_COPY_AND_ASSIGN_(TestEventListeners); +}; + +// A UnitTest consists of a vector of TestCases. +// +// This is a singleton class. The only instance of UnitTest is +// created when UnitTest::GetInstance() is first called. This +// instance is never deleted. +// +// UnitTest is not copyable. +// +// This class is thread-safe as long as the methods are called +// according to their specification. +class GTEST_API_ UnitTest { + public: + // Gets the singleton UnitTest object. The first time this method + // is called, a UnitTest object is constructed and returned. + // Consecutive calls will return the same object. + static UnitTest* GetInstance(); + + // Runs all tests in this UnitTest object and prints the result. + // Returns 0 if successful, or 1 otherwise. + // + // This method can only be called from the main thread. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + int Run() GTEST_MUST_USE_RESULT_; + + // Returns the working directory when the first TEST() or TEST_F() + // was executed. The UnitTest object owns the string. + const char* original_working_dir() const; + + // Returns the TestCase object for the test that's currently running, + // or NULL if no test is running. + const TestCase* current_test_case() const; + + // Returns the TestInfo object for the test that's currently running, + // or NULL if no test is running. + const TestInfo* current_test_info() const; + + // Returns the random seed used at the start of the current test run. + int random_seed() const; + +#if GTEST_HAS_PARAM_TEST + // Returns the ParameterizedTestCaseRegistry object used to keep track of + // value-parameterized tests and instantiate and register them. + // + // INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + internal::ParameterizedTestCaseRegistry& parameterized_test_registry(); +#endif // GTEST_HAS_PARAM_TEST + + // Gets the number of successful test cases. + int successful_test_case_count() const; + + // Gets the number of failed test cases. + int failed_test_case_count() const; + + // Gets the number of all test cases. + int total_test_case_count() const; + + // Gets the number of all test cases that contain at least one test + // that should run. + int test_case_to_run_count() const; + + // Gets the number of successful tests. + int successful_test_count() const; + + // Gets the number of failed tests. + int failed_test_count() const; + + // Gets the number of disabled tests. + int disabled_test_count() const; + + // Gets the number of all tests. + int total_test_count() const; + + // Gets the number of tests that should run. + int test_to_run_count() const; + + // Gets the elapsed time, in milliseconds. 
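+
+// A sketch of how such a listener is typically installed in place of the
+// default printer, using the Release()/Append()/default_result_printer()
+// members documented above and UnitTest::GetInstance()->listeners()
+// (declared further below); TersePrinter is the illustrative listener from
+// the earlier comment:
+//
+//   ::testing::TestEventListeners& listeners =
+//       ::testing::UnitTest::GetInstance()->listeners();
+//   delete listeners.Release(listeners.default_result_printer());
+//   listeners.Append(new TersePrinter);
+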
+ TimeInMillis elapsed_time() const; + + // Returns true iff the unit test passed (i.e. all test cases passed). + bool Passed() const; + + // Returns true iff the unit test failed (i.e. some test case failed + // or something outside of all tests failed). + bool Failed() const; + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + const TestCase* GetTestCase(int i) const; + + // Returns the list of event listeners that can be used to track events + // inside Google Test. + TestEventListeners& listeners(); + + private: + // Registers and returns a global test environment. When a test + // program is run, all global test environments will be set-up in + // the order they were registered. After all tests in the program + // have finished, all global test environments will be torn-down in + // the *reverse* order they were registered. + // + // The UnitTest object takes ownership of the given environment. + // + // This method can only be called from the main thread. + Environment* AddEnvironment(Environment* env); + + // Adds a TestPartResult to the current TestResult object. All + // Google Test assertion macros (e.g. ASSERT_TRUE, EXPECT_EQ, etc) + // eventually call this to report their results. The user code + // should use the assertion macros instead of calling this directly. + void AddTestPartResult(TestPartResult::Type result_type, + const char* file_name, + int line_number, + const internal::String& message, + const internal::String& os_stack_trace); + + // Adds a TestProperty to the current TestResult object. If the result already + // contains a property with the same key, the value will be updated. + void RecordPropertyForCurrentTest(const char* key, const char* value); + + // Gets the i-th test case among all the test cases. i can range from 0 to + // total_test_case_count() - 1. If i is not in that range, returns NULL. + TestCase* GetMutableTestCase(int i); + + // Accessors for the implementation object. + internal::UnitTestImpl* impl() { return impl_; } + const internal::UnitTestImpl* impl() const { return impl_; } + + // These classes and funcions are friends as they need to access private + // members of UnitTest. + friend class Test; + friend class internal::AssertHelper; + friend class internal::ScopedTrace; + friend Environment* AddGlobalTestEnvironment(Environment* env); + friend internal::UnitTestImpl* internal::GetUnitTestImpl(); + friend void internal::ReportFailureInUnknownLocation( + TestPartResult::Type result_type, + const internal::String& message); + + // Creates an empty UnitTest. + UnitTest(); + + // D'tor + virtual ~UnitTest(); + + // Pushes a trace defined by SCOPED_TRACE() on to the per-thread + // Google Test trace stack. + void PushGTestTrace(const internal::TraceInfo& trace); + + // Pops a trace from the per-thread Google Test trace stack. + void PopGTestTrace(); + + // Protects mutable state in *impl_. This is mutable as some const + // methods need to lock it too. + mutable internal::Mutex mutex_; + + // Opaque implementation object. This field is never changed once + // the object is constructed. We don't mark it as const here, as + // doing so will cause a warning in the constructor of UnitTest. + // Mutable state in *impl_ is protected by mutex_. + internal::UnitTestImpl* impl_; + + // We disallow copying UnitTest. + GTEST_DISALLOW_COPY_AND_ASSIGN_(UnitTest); +}; + +// A convenient wrapper for adding an environment for the test +// program. 
+// +// You should call this before RUN_ALL_TESTS() is called, probably in +// main(). If you use gtest_main, you need to call this before main() +// starts for it to take effect. For example, you can define a global +// variable like this: +// +// testing::Environment* const foo_env = +// testing::AddGlobalTestEnvironment(new FooEnvironment); +// +// However, we strongly recommend you to write your own main() and +// call AddGlobalTestEnvironment() there, as relying on initialization +// of global variables makes the code harder to read and may cause +// problems when you register multiple environments from different +// translation units and the environments have dependencies among them +// (remember that the compiler doesn't guarantee the order in which +// global variables from different translation units are initialized). +inline Environment* AddGlobalTestEnvironment(Environment* env) { + return UnitTest::GetInstance()->AddEnvironment(env); +} + +// Initializes Google Test. This must be called before calling +// RUN_ALL_TESTS(). In particular, it parses a command line for the +// flags that Google Test recognizes. Whenever a Google Test flag is +// seen, it is removed from argv, and *argc is decremented. +// +// No value is returned. Instead, the Google Test flag variables are +// updated. +// +// Calling the function for the second time has no user-visible effect. +GTEST_API_ void InitGoogleTest(int* argc, char** argv); + +// This overloaded version can be used in Windows programs compiled in +// UNICODE mode. +GTEST_API_ void InitGoogleTest(int* argc, wchar_t** argv); + +namespace internal { + +// Formats a comparison assertion (e.g. ASSERT_EQ, EXPECT_LT, and etc) +// operand to be used in a failure message. The type (but not value) +// of the other operand may affect the format. This allows us to +// print a char* as a raw pointer when it is compared against another +// char*, and print it as a C string when it is compared against an +// std::string object, for example. +// +// The default implementation ignores the type of the other operand. +// Some specialized versions are used to handle formatting wide or +// narrow C strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +template +String FormatForComparisonFailureMessage(const T1& value, + const T2& /* other_operand */) { + // C++Builder compiles this incorrectly if the namespace isn't explicitly + // given. + return ::testing::PrintToString(value); +} + +// The helper function for {ASSERT|EXPECT}_EQ. +template +AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual) { +#ifdef _MSC_VER +# pragma warning(push) // Saves the current warning state. +# pragma warning(disable:4389) // Temporarily disables warning on + // signed/unsigned mismatch. +#endif + + if (expected == actual) { + return AssertionSuccess(); + } + +#ifdef _MSC_VER +# pragma warning(pop) // Restores the warning state. +#endif + + return EqFailure(expected_expression, + actual_expression, + FormatForComparisonFailureMessage(expected, actual), + FormatForComparisonFailureMessage(actual, expected), + false); +} + +// With this overloaded version, we allow anonymous enums to be used +// in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous enums +// can be implicitly cast to BiggestInt. 
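+
+// A typical main() combining InitGoogleTest(), AddGlobalTestEnvironment(),
+// and RUN_ALL_TESTS() as recommended earlier in this header might look like
+// the sketch below (FooEnvironment is a hypothetical Environment subclass;
+// the include path may differ depending on how gtest is vendored):
+//
+//   #include "gtest/gtest.h"
+//
+//   class FooEnvironment : public ::testing::Environment {
+//     virtual void SetUp() { /* acquire shared resources */ }
+//     virtual void TearDown() { /* release them */ }
+//   };
+//
+//   int main(int argc, char** argv) {
+//     ::testing::InitGoogleTest(&argc, argv);
+//     ::testing::AddGlobalTestEnvironment(new FooEnvironment);
+//     return RUN_ALL_TESTS();
+//   }
+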
+GTEST_API_ AssertionResult CmpHelperEQ(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual); + +// The helper class for {ASSERT|EXPECT}_EQ. The template argument +// lhs_is_null_literal is true iff the first argument to ASSERT_EQ() +// is a null pointer literal. The following default implementation is +// for lhs_is_null_literal being false. +template +class EqHelper { + public: + // This templatized version is for the general case. + template + static AssertionResult Compare(const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } + + // With this overloaded version, we allow anonymous enums to be used + // in {ASSERT|EXPECT}_EQ when compiled with gcc 4, as anonymous + // enums can be implicitly cast to BiggestInt. + // + // Even though its body looks the same as the above version, we + // cannot merge the two, as it will make anonymous enums unhappy. + static AssertionResult Compare(const char* expected_expression, + const char* actual_expression, + BiggestInt expected, + BiggestInt actual) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } +}; + +// This specialization is used when the first argument to ASSERT_EQ() +// is a null pointer literal, like NULL, false, or 0. +template <> +class EqHelper { + public: + // We define two overloaded versions of Compare(). The first + // version will be picked when the second argument to ASSERT_EQ() is + // NOT a pointer, e.g. ASSERT_EQ(0, AnIntFunction()) or + // EXPECT_EQ(false, a_bool). + template + static AssertionResult Compare( + const char* expected_expression, + const char* actual_expression, + const T1& expected, + const T2& actual, + // The following line prevents this overload from being considered if T2 + // is not a pointer type. We need this because ASSERT_EQ(NULL, my_ptr) + // expands to Compare("", "", NULL, my_ptr), which requires a conversion + // to match the Secret* in the other overload, which would otherwise make + // this template match better. + typename EnableIf::value>::type* = 0) { + return CmpHelperEQ(expected_expression, actual_expression, expected, + actual); + } + + // This version will be picked when the second argument to ASSERT_EQ() is a + // pointer, e.g. ASSERT_EQ(NULL, a_pointer). + template + static AssertionResult Compare( + const char* expected_expression, + const char* actual_expression, + // We used to have a second template parameter instead of Secret*. That + // template parameter would deduce to 'long', making this a better match + // than the first overload even without the first overload's EnableIf. + // Unfortunately, gcc with -Wconversion-null warns when "passing NULL to + // non-pointer argument" (even a deduced integral argument), so the old + // implementation caused warnings in user code. + Secret* /* expected (NULL) */, + T* actual) { + // We already know that 'expected' is a null pointer. + return CmpHelperEQ(expected_expression, actual_expression, + static_cast(NULL), actual); + } +}; + +// A macro for implementing the helper functions needed to implement +// ASSERT_?? and EXPECT_??. It is here just to avoid copy-and-paste +// of similar code. +// +// For each templatized helper function, we also define an overloaded +// version for BiggestInt in order to reduce code bloat and allow +// anonymous enums to be used with {ASSERT|EXPECT}_?? 
when compiled +// with gcc 4. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +#define GTEST_IMPL_CMP_HELPER_(op_name, op)\ +template \ +AssertionResult CmpHelper##op_name(const char* expr1, const char* expr2, \ + const T1& val1, const T2& val2) {\ + if (val1 op val2) {\ + return AssertionSuccess();\ + } else {\ + return AssertionFailure() \ + << "Expected: (" << expr1 << ") " #op " (" << expr2\ + << "), actual: " << FormatForComparisonFailureMessage(val1, val2)\ + << " vs " << FormatForComparisonFailureMessage(val2, val1);\ + }\ +}\ +GTEST_API_ AssertionResult CmpHelper##op_name(\ + const char* expr1, const char* expr2, BiggestInt val1, BiggestInt val2) + +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. + +// Implements the helper function for {ASSERT|EXPECT}_NE +GTEST_IMPL_CMP_HELPER_(NE, !=); +// Implements the helper function for {ASSERT|EXPECT}_LE +GTEST_IMPL_CMP_HELPER_(LE, <=); +// Implements the helper function for {ASSERT|EXPECT}_LT +GTEST_IMPL_CMP_HELPER_(LT, < ); +// Implements the helper function for {ASSERT|EXPECT}_GE +GTEST_IMPL_CMP_HELPER_(GE, >=); +// Implements the helper function for {ASSERT|EXPECT}_GT +GTEST_IMPL_CMP_HELPER_(GT, > ); + +#undef GTEST_IMPL_CMP_HELPER_ + +// The helper function for {ASSERT|EXPECT}_STREQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual); + +// The helper function for {ASSERT|EXPECT}_STRCASEEQ. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRCASEEQ(const char* expected_expression, + const char* actual_expression, + const char* expected, + const char* actual); + +// The helper function for {ASSERT|EXPECT}_STRNE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + +// The helper function for {ASSERT|EXPECT}_STRCASENE. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRCASENE(const char* s1_expression, + const char* s2_expression, + const char* s1, + const char* s2); + + +// Helper function for *_STREQ on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTREQ(const char* expected_expression, + const char* actual_expression, + const wchar_t* expected, + const wchar_t* actual); + +// Helper function for *_STRNE on wide strings. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult CmpHelperSTRNE(const char* s1_expression, + const char* s2_expression, + const wchar_t* s1, + const wchar_t* s2); + +} // namespace internal + +// IsSubstring() and IsNotSubstring() are intended to be used as the +// first argument to {EXPECT,ASSERT}_PRED_FORMAT2(), not by +// themselves. They check whether needle is a substring of haystack +// (NULL is considered a substring of itself only), and return an +// appropriate error message when they fail. +// +// The {needle,haystack}_expr arguments are the stringified +// expressions that generated the two real arguments. 
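+
+// For example (a sketch; the literal strings are made up):
+//
+//   EXPECT_PRED_FORMAT2(::testing::IsSubstring, "needle",
+//                       "haystack with a needle");
+//   EXPECT_PRED_FORMAT2(::testing::IsNotSubstring, "weed", "tidy lawn");
+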
+GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const char* needle, const char* haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const wchar_t* needle, const wchar_t* haystack); +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::string& needle, const ::std::string& haystack); + +#if GTEST_HAS_STD_WSTRING +GTEST_API_ AssertionResult IsSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +GTEST_API_ AssertionResult IsNotSubstring( + const char* needle_expr, const char* haystack_expr, + const ::std::wstring& needle, const ::std::wstring& haystack); +#endif // GTEST_HAS_STD_WSTRING + +namespace internal { + +// Helper template function for comparing floating-points. +// +// Template parameter: +// +// RawType: the raw floating-point type (either float or double) +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +template +AssertionResult CmpHelperFloatingPointEQ(const char* expected_expression, + const char* actual_expression, + RawType expected, + RawType actual) { + const FloatingPoint lhs(expected), rhs(actual); + + if (lhs.AlmostEquals(rhs)) { + return AssertionSuccess(); + } + + ::std::stringstream expected_ss; + expected_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << expected; + + ::std::stringstream actual_ss; + actual_ss << std::setprecision(std::numeric_limits::digits10 + 2) + << actual; + + return EqFailure(expected_expression, + actual_expression, + StringStreamToString(&expected_ss), + StringStreamToString(&actual_ss), + false); +} + +// Helper function for implementing ASSERT_NEAR. +// +// INTERNAL IMPLEMENTATION - DO NOT USE IN A USER PROGRAM. +GTEST_API_ AssertionResult DoubleNearPredFormat(const char* expr1, + const char* expr2, + const char* abs_error_expr, + double val1, + double val2, + double abs_error); + +// INTERNAL IMPLEMENTATION - DO NOT USE IN USER CODE. +// A class that enables one to stream messages to assertion macros +class GTEST_API_ AssertHelper { + public: + // Constructor. + AssertHelper(TestPartResult::Type type, + const char* file, + int line, + const char* message); + ~AssertHelper(); + + // Message assignment is a semantic trick to enable assertion + // streaming; see the GTEST_MESSAGE_ macro below. + void operator=(const Message& message) const; + + private: + // We put our data in a struct so that the size of the AssertHelper class can + // be as small as possible. This is important because gcc is incapable of + // re-using stack space even for temporary variables, so every EXPECT_EQ + // reserves stack space for another AssertHelper. 
+ struct AssertHelperData { + AssertHelperData(TestPartResult::Type t, + const char* srcfile, + int line_num, + const char* msg) + : type(t), file(srcfile), line(line_num), message(msg) { } + + TestPartResult::Type const type; + const char* const file; + int const line; + String const message; + + private: + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelperData); + }; + + AssertHelperData* const data_; + + GTEST_DISALLOW_COPY_AND_ASSIGN_(AssertHelper); +}; + +} // namespace internal + +#if GTEST_HAS_PARAM_TEST +// The pure interface class that all value-parameterized tests inherit from. +// A value-parameterized class must inherit from both ::testing::Test and +// ::testing::WithParamInterface. In most cases that just means inheriting +// from ::testing::TestWithParam, but more complicated test hierarchies +// may need to inherit from Test and WithParamInterface at different levels. +// +// This interface has support for accessing the test parameter value via +// the GetParam() method. +// +// Use it with one of the parameter generator defining functions, like Range(), +// Values(), ValuesIn(), Bool(), and Combine(). +// +// class FooTest : public ::testing::TestWithParam { +// protected: +// FooTest() { +// // Can use GetParam() here. +// } +// virtual ~FooTest() { +// // Can use GetParam() here. +// } +// virtual void SetUp() { +// // Can use GetParam() here. +// } +// virtual void TearDown { +// // Can use GetParam() here. +// } +// }; +// TEST_P(FooTest, DoesBar) { +// // Can use GetParam() method here. +// Foo foo; +// ASSERT_TRUE(foo.DoesBar(GetParam())); +// } +// INSTANTIATE_TEST_CASE_P(OneToTenRange, FooTest, ::testing::Range(1, 10)); + +template +class WithParamInterface { + public: + typedef T ParamType; + virtual ~WithParamInterface() {} + + // The current parameter value. Is also available in the test fixture's + // constructor. This member function is non-static, even though it only + // references static data, to reduce the opportunity for incorrect uses + // like writing 'WithParamInterface::GetParam()' for a test that + // uses a fixture whose parameter type is int. + const ParamType& GetParam() const { return *parameter_; } + + private: + // Sets parameter value. The caller is responsible for making sure the value + // remains alive and unchanged throughout the current test. + static void SetParam(const ParamType* parameter) { + parameter_ = parameter; + } + + // Static value used for accessing parameter during a test lifetime. + static const ParamType* parameter_; + + // TestClass must be a subclass of WithParamInterface and Test. + template friend class internal::ParameterizedTestFactory; +}; + +template +const T* WithParamInterface::parameter_ = NULL; + +// Most value-parameterized classes can ignore the existence of +// WithParamInterface, and can just inherit from ::testing::TestWithParam. + +template +class TestWithParam : public Test, public WithParamInterface { +}; + +#endif // GTEST_HAS_PARAM_TEST + +// Macros for indicating success/failure in test code. + +// ADD_FAILURE unconditionally adds a failure to the current test. +// SUCCEED generates a success - it doesn't automatically make the +// current test successful, as a test is only successful when it has +// no failure. +// +// EXPECT_* verifies that a certain condition is satisfied. If not, +// it behaves like ADD_FAILURE. In particular: +// +// EXPECT_TRUE verifies that a Boolean condition is true. +// EXPECT_FALSE verifies that a Boolean condition is false. 
+// +// FAIL and ASSERT_* are similar to ADD_FAILURE and EXPECT_*, except +// that they will also abort the current function on failure. People +// usually want the fail-fast behavior of FAIL and ASSERT_*, but those +// writing data-driven tests often find themselves using ADD_FAILURE +// and EXPECT_* more. +// +// Examples: +// +// EXPECT_TRUE(server.StatusIsOK()); +// ASSERT_FALSE(server.HasPendingRequest(port)) +// << "There are still pending requests " << "on port " << port; + +// Generates a nonfatal failure with a generic message. +#define ADD_FAILURE() GTEST_NONFATAL_FAILURE_("Failed") + +// Generates a nonfatal failure at the given source file location with +// a generic message. +#define ADD_FAILURE_AT(file, line) \ + GTEST_MESSAGE_AT_(file, line, "Failed", \ + ::testing::TestPartResult::kNonFatalFailure) + +// Generates a fatal failure with a generic message. +#define GTEST_FAIL() GTEST_FATAL_FAILURE_("Failed") + +// Define this macro to 1 to omit the definition of FAIL(), which is a +// generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_FAIL +# define FAIL() GTEST_FAIL() +#endif + +// Generates a success with a generic message. +#define GTEST_SUCCEED() GTEST_SUCCESS_("Succeeded") + +// Define this macro to 1 to omit the definition of SUCCEED(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_SUCCEED +# define SUCCEED() GTEST_SUCCEED() +#endif + +// Macros for testing exceptions. +// +// * {ASSERT|EXPECT}_THROW(statement, expected_exception): +// Tests that the statement throws the expected exception. +// * {ASSERT|EXPECT}_NO_THROW(statement): +// Tests that the statement doesn't throw any exception. +// * {ASSERT|EXPECT}_ANY_THROW(statement): +// Tests that the statement throws an exception. + +#define EXPECT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_NONFATAL_FAILURE_) +#define EXPECT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define EXPECT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_NONFATAL_FAILURE_) +#define ASSERT_THROW(statement, expected_exception) \ + GTEST_TEST_THROW_(statement, expected_exception, GTEST_FATAL_FAILURE_) +#define ASSERT_NO_THROW(statement) \ + GTEST_TEST_NO_THROW_(statement, GTEST_FATAL_FAILURE_) +#define ASSERT_ANY_THROW(statement) \ + GTEST_TEST_ANY_THROW_(statement, GTEST_FATAL_FAILURE_) + +// Boolean assertions. Condition can be either a Boolean expression or an +// AssertionResult. For more information on how to use AssertionResult with +// these macros see comments on that class. +#define EXPECT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_NONFATAL_FAILURE_) +#define EXPECT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_NONFATAL_FAILURE_) +#define ASSERT_TRUE(condition) \ + GTEST_TEST_BOOLEAN_(condition, #condition, false, true, \ + GTEST_FATAL_FAILURE_) +#define ASSERT_FALSE(condition) \ + GTEST_TEST_BOOLEAN_(!(condition), #condition, true, false, \ + GTEST_FATAL_FAILURE_) + +// Includes the auto-generated header that implements a family of +// generic predicate assertion macros. +// Copyright 2006, Google Inc. +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// This file is AUTOMATICALLY GENERATED on 09/24/2010 by command +// 'gen_gtest_pred_impl.py 5'. DO NOT EDIT BY HAND! +// +// Implements a family of generic predicate assertion macros. + +#ifndef GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ +#define GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ + +// Makes sure this header is not included before gtest.h. +#ifndef GTEST_INCLUDE_GTEST_GTEST_H_ +# error Do not include gtest_pred_impl.h directly. Include gtest.h instead. +#endif // GTEST_INCLUDE_GTEST_GTEST_H_ + +// This header implements a family of generic predicate assertion +// macros: +// +// ASSERT_PRED_FORMAT1(pred_format, v1) +// ASSERT_PRED_FORMAT2(pred_format, v1, v2) +// ... +// +// where pred_format is a function or functor that takes n (in the +// case of ASSERT_PRED_FORMATn) values and their source expression +// text, and returns a testing::AssertionResult. See the definition +// of ASSERT_EQ in gtest.h for an example. +// +// If you don't care about formatting, you can use the more +// restrictive version: +// +// ASSERT_PRED1(pred, v1) +// ASSERT_PRED2(pred, v1, v2) +// ... +// +// where pred is an n-ary function or functor that returns bool, +// and the values v1, v2, ..., must support the << operator for +// streaming to std::ostream. +// +// We also define the EXPECT_* variations. +// +// For now we only support predicates whose arity is at most 5. +// Please email googletestframework@googlegroups.com if you need +// support for higher arities. + +// GTEST_ASSERT_ is the basic statement to which all of the assertions +// in this file reduce. Don't use this in your code. + +#define GTEST_ASSERT_(expression, on_failure) \ + GTEST_AMBIGUOUS_ELSE_BLOCKER_ \ + if (const ::testing::AssertionResult gtest_ar = (expression)) \ + ; \ + else \ + on_failure(gtest_ar.failure_message()) + + +// Helper function for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code. 
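+
+// As a usage sketch for the public EXPECT_PRED* macros described above
+// (MutuallyPrime and Gcd are hypothetical helpers, not part of Google Test):
+//
+//   bool MutuallyPrime(int m, int n) { return Gcd(m, n) == 1; }
+//
+//   EXPECT_PRED2(MutuallyPrime, 3, 7);   // succeeds
+//   EXPECT_PRED2(MutuallyPrime, 4, 10);  // fails and prints both arguments
+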
+template +AssertionResult AssertPred1Helper(const char* pred_text, + const char* e1, + Pred pred, + const T1& v1) { + if (pred(v1)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT1. +// Don't use this in your code. +#define GTEST_PRED_FORMAT1_(pred_format, v1, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, v1),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED1. Don't use +// this in your code. +#define GTEST_PRED1_(pred, v1, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred1Helper(#pred, \ + #v1, \ + pred, \ + v1), on_failure) + +// Unary predicate assertion macros. +#define EXPECT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT1(pred_format, v1) \ + GTEST_PRED_FORMAT1_(pred_format, v1, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED1(pred, v1) \ + GTEST_PRED1_(pred, v1, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +template +AssertionResult AssertPred2Helper(const char* pred_text, + const char* e1, + const char* e2, + Pred pred, + const T1& v1, + const T2& v2) { + if (pred(v1, v2)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT2. +// Don't use this in your code. +#define GTEST_PRED_FORMAT2_(pred_format, v1, v2, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, v1, v2),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED2. Don't use +// this in your code. +#define GTEST_PRED2_(pred, v1, v2, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred2Helper(#pred, \ + #v1, \ + #v2, \ + pred, \ + v1, \ + v2), on_failure) + +// Binary predicate assertion macros. +#define EXPECT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT2(pred_format, v1, v2) \ + GTEST_PRED_FORMAT2_(pred_format, v1, v2, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED2(pred, v1, v2) \ + GTEST_PRED2_(pred, v1, v2, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. +template +AssertionResult AssertPred3Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3) { + if (pred(v1, v2, v3)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT3. +// Don't use this in your code. +#define GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, v1, v2, v3),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED3. Don't use +// this in your code. 
+#define GTEST_PRED3_(pred, v1, v2, v3, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred3Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + pred, \ + v1, \ + v2, \ + v3), on_failure) + +// Ternary predicate assertion macros. +#define EXPECT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT3(pred_format, v1, v2, v3) \ + GTEST_PRED_FORMAT3_(pred_format, v1, v2, v3, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED3(pred, v1, v2, v3) \ + GTEST_PRED3_(pred, v1, v2, v3, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +template +AssertionResult AssertPred4Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4) { + if (pred(v1, v2, v3, v4)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ", " + << e4 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3 + << "\n" << e4 << " evaluates to " << v4; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT4. +// Don't use this in your code. +#define GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, v1, v2, v3, v4),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED4. Don't use +// this in your code. +#define GTEST_PRED4_(pred, v1, v2, v3, v4, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred4Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4), on_failure) + +// 4-ary predicate assertion macros. +#define EXPECT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT4(pred_format, v1, v2, v3, v4) \ + GTEST_PRED_FORMAT4_(pred_format, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED4(pred, v1, v2, v3, v4) \ + GTEST_PRED4_(pred, v1, v2, v3, v4, GTEST_FATAL_FAILURE_) + + + +// Helper function for implementing {EXPECT|ASSERT}_PRED5. Don't use +// this in your code. +template +AssertionResult AssertPred5Helper(const char* pred_text, + const char* e1, + const char* e2, + const char* e3, + const char* e4, + const char* e5, + Pred pred, + const T1& v1, + const T2& v2, + const T3& v3, + const T4& v4, + const T5& v5) { + if (pred(v1, v2, v3, v4, v5)) return AssertionSuccess(); + + return AssertionFailure() << pred_text << "(" + << e1 << ", " + << e2 << ", " + << e3 << ", " + << e4 << ", " + << e5 << ") evaluates to false, where" + << "\n" << e1 << " evaluates to " << v1 + << "\n" << e2 << " evaluates to " << v2 + << "\n" << e3 << " evaluates to " << v3 + << "\n" << e4 << " evaluates to " << v4 + << "\n" << e5 << " evaluates to " << v5; +} + +// Internal macro for implementing {EXPECT|ASSERT}_PRED_FORMAT5. +// Don't use this in your code. +#define GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(pred_format(#v1, #v2, #v3, #v4, #v5, v1, v2, v3, v4, v5),\ + on_failure) + +// Internal macro for implementing {EXPECT|ASSERT}_PRED5. 
Don't use +// this in your code. +#define GTEST_PRED5_(pred, v1, v2, v3, v4, v5, on_failure)\ + GTEST_ASSERT_(::testing::AssertPred5Helper(#pred, \ + #v1, \ + #v2, \ + #v3, \ + #v4, \ + #v5, \ + pred, \ + v1, \ + v2, \ + v3, \ + v4, \ + v5), on_failure) + +// 5-ary predicate assertion macros. +#define EXPECT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define EXPECT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_NONFATAL_FAILURE_) +#define ASSERT_PRED_FORMAT5(pred_format, v1, v2, v3, v4, v5) \ + GTEST_PRED_FORMAT5_(pred_format, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) +#define ASSERT_PRED5(pred, v1, v2, v3, v4, v5) \ + GTEST_PRED5_(pred, v1, v2, v3, v4, v5, GTEST_FATAL_FAILURE_) + + + +#endif // GTEST_INCLUDE_GTEST_GTEST_PRED_IMPL_H_ + +// Macros for testing equalities and inequalities. +// +// * {ASSERT|EXPECT}_EQ(expected, actual): Tests that expected == actual +// * {ASSERT|EXPECT}_NE(v1, v2): Tests that v1 != v2 +// * {ASSERT|EXPECT}_LT(v1, v2): Tests that v1 < v2 +// * {ASSERT|EXPECT}_LE(v1, v2): Tests that v1 <= v2 +// * {ASSERT|EXPECT}_GT(v1, v2): Tests that v1 > v2 +// * {ASSERT|EXPECT}_GE(v1, v2): Tests that v1 >= v2 +// +// When they are not, Google Test prints both the tested expressions and +// their actual values. The values must be compatible built-in types, +// or you will get a compiler error. By "compatible" we mean that the +// values can be compared by the respective operator. +// +// Note: +// +// 1. It is possible to make a user-defined type work with +// {ASSERT|EXPECT}_??(), but that requires overloading the +// comparison operators and is thus discouraged by the Google C++ +// Usage Guide. Therefore, you are advised to use the +// {ASSERT|EXPECT}_TRUE() macro to assert that two objects are +// equal. +// +// 2. The {ASSERT|EXPECT}_??() macros do pointer comparisons on +// pointers (in particular, C strings). Therefore, if you use it +// with two C strings, you are testing how their locations in memory +// are related, not how their content is related. To compare two C +// strings by content, use {ASSERT|EXPECT}_STR*(). +// +// 3. {ASSERT|EXPECT}_EQ(expected, actual) is preferred to +// {ASSERT|EXPECT}_TRUE(expected == actual), as the former tells you +// what the actual value is when it fails, and similarly for the +// other comparisons. +// +// 4. Do not depend on the order in which {ASSERT|EXPECT}_??() +// evaluate their arguments, which is undefined. +// +// 5. These macros evaluate their arguments exactly once. 
+// +// Examples: +// +// EXPECT_NE(5, Foo()); +// EXPECT_EQ(NULL, a_pointer); +// ASSERT_LT(i, array_size); +// ASSERT_GT(records.size(), 0) << "There is no record left."; + +#define EXPECT_EQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal:: \ + EqHelper::Compare, \ + expected, actual) +#define EXPECT_NE(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperNE, expected, actual) +#define EXPECT_LE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define EXPECT_LT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define EXPECT_GE(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define EXPECT_GT(val1, val2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +#define GTEST_ASSERT_EQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal:: \ + EqHelper::Compare, \ + expected, actual) +#define GTEST_ASSERT_NE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperNE, val1, val2) +#define GTEST_ASSERT_LE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLE, val1, val2) +#define GTEST_ASSERT_LT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperLT, val1, val2) +#define GTEST_ASSERT_GE(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGE, val1, val2) +#define GTEST_ASSERT_GT(val1, val2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperGT, val1, val2) + +// Define macro GTEST_DONT_DEFINE_ASSERT_XY to 1 to omit the definition of +// ASSERT_XY(), which clashes with some users' own code. + +#if !GTEST_DONT_DEFINE_ASSERT_EQ +# define ASSERT_EQ(val1, val2) GTEST_ASSERT_EQ(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_NE +# define ASSERT_NE(val1, val2) GTEST_ASSERT_NE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LE +# define ASSERT_LE(val1, val2) GTEST_ASSERT_LE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_LT +# define ASSERT_LT(val1, val2) GTEST_ASSERT_LT(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GE +# define ASSERT_GE(val1, val2) GTEST_ASSERT_GE(val1, val2) +#endif + +#if !GTEST_DONT_DEFINE_ASSERT_GT +# define ASSERT_GT(val1, val2) GTEST_ASSERT_GT(val1, val2) +#endif + +// C String Comparisons. All tests treat NULL and any non-NULL string +// as different. Two NULLs are equal. +// +// * {ASSERT|EXPECT}_STREQ(s1, s2): Tests that s1 == s2 +// * {ASSERT|EXPECT}_STRNE(s1, s2): Tests that s1 != s2 +// * {ASSERT|EXPECT}_STRCASEEQ(s1, s2): Tests that s1 == s2, ignoring case +// * {ASSERT|EXPECT}_STRCASENE(s1, s2): Tests that s1 != s2, ignoring case +// +// For wide or narrow string objects, you can use the +// {ASSERT|EXPECT}_??() macros. +// +// Don't depend on the order in which the arguments are evaluated, +// which is undefined. +// +// These macros evaluate their arguments exactly once. 
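+
+// For example (a sketch; GetGreeting() is a hypothetical function returning
+// const char*):
+//
+//   EXPECT_STREQ("hello", GetGreeting());      // compares by content
+//   EXPECT_STRCASEEQ("HELLO", GetGreeting());  // content, ignoring case
+//   EXPECT_STRNE(NULL, GetGreeting());         // NULL differs from any
+//                                              // non-NULL string
+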
+ +#define EXPECT_STREQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, expected, actual) +#define EXPECT_STRNE(s1, s2) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define EXPECT_STRCASEEQ(expected, actual) \ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, expected, actual) +#define EXPECT_STRCASENE(s1, s2)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +#define ASSERT_STREQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTREQ, expected, actual) +#define ASSERT_STRNE(s1, s2) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRNE, s1, s2) +#define ASSERT_STRCASEEQ(expected, actual) \ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASEEQ, expected, actual) +#define ASSERT_STRCASENE(s1, s2)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperSTRCASENE, s1, s2) + +// Macros for comparing floating-point numbers. +// +// * {ASSERT|EXPECT}_FLOAT_EQ(expected, actual): +// Tests that two float values are almost equal. +// * {ASSERT|EXPECT}_DOUBLE_EQ(expected, actual): +// Tests that two double values are almost equal. +// * {ASSERT|EXPECT}_NEAR(v1, v2, abs_error): +// Tests that v1 and v2 are within the given distance to each other. +// +// Google Test uses ULP-based comparison to automatically pick a default +// error bound that is appropriate for the operands. See the +// FloatingPoint template class in gtest-internal.h if you are +// interested in the implementation details. + +#define EXPECT_FLOAT_EQ(expected, actual)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define EXPECT_DOUBLE_EQ(expected, actual)\ + EXPECT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define ASSERT_FLOAT_EQ(expected, actual)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define ASSERT_DOUBLE_EQ(expected, actual)\ + ASSERT_PRED_FORMAT2(::testing::internal::CmpHelperFloatingPointEQ, \ + expected, actual) + +#define EXPECT_NEAR(val1, val2, abs_error)\ + EXPECT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +#define ASSERT_NEAR(val1, val2, abs_error)\ + ASSERT_PRED_FORMAT3(::testing::internal::DoubleNearPredFormat, \ + val1, val2, abs_error) + +// These predicate format functions work on floating-point values, and +// can be used in {ASSERT|EXPECT}_PRED_FORMAT2*(), e.g. +// +// EXPECT_PRED_FORMAT2(testing::DoubleLE, Foo(), 5.0); + +// Asserts that val1 is less than, or almost equal to, val2. Fails +// otherwise. In particular, it fails if either val1 or val2 is NaN. +GTEST_API_ AssertionResult FloatLE(const char* expr1, const char* expr2, + float val1, float val2); +GTEST_API_ AssertionResult DoubleLE(const char* expr1, const char* expr2, + double val1, double val2); + + +#if GTEST_OS_WINDOWS + +// Macros that test for HRESULT failure and success, these are only useful +// on Windows, and rely on Windows SDK macros and APIs to compile. +// +// * {ASSERT|EXPECT}_HRESULT_{SUCCEEDED|FAILED}(expr) +// +// When expr unexpectedly fails or succeeds, Google Test prints the +// expected result and the actual result with both a human-readable +// string representation of the error, if available, as well as the +// hex result code. 
+# define EXPECT_HRESULT_SUCCEEDED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define ASSERT_HRESULT_SUCCEEDED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTSuccess, (expr)) + +# define EXPECT_HRESULT_FAILED(expr) \ + EXPECT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +# define ASSERT_HRESULT_FAILED(expr) \ + ASSERT_PRED_FORMAT1(::testing::internal::IsHRESULTFailure, (expr)) + +#endif // GTEST_OS_WINDOWS + +// Macros that execute statement and check that it doesn't generate new fatal +// failures in the current thread. +// +// * {ASSERT|EXPECT}_NO_FATAL_FAILURE(statement); +// +// Examples: +// +// EXPECT_NO_FATAL_FAILURE(Process()); +// ASSERT_NO_FATAL_FAILURE(Process()) << "Process() failed"; +// +#define ASSERT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_FATAL_FAILURE_) +#define EXPECT_NO_FATAL_FAILURE(statement) \ + GTEST_TEST_NO_FATAL_FAILURE_(statement, GTEST_NONFATAL_FAILURE_) + +// Causes a trace (including the source file path, the current line +// number, and the given message) to be included in every test failure +// message generated by code in the current scope. The effect is +// undone when the control leaves the current scope. +// +// The message argument can be anything streamable to std::ostream. +// +// In the implementation, we include the current line number as part +// of the dummy variable name, thus allowing multiple SCOPED_TRACE()s +// to appear in the same block - as long as they are on different +// lines. +#define SCOPED_TRACE(message) \ + ::testing::internal::ScopedTrace GTEST_CONCAT_TOKEN_(gtest_trace_, __LINE__)(\ + __FILE__, __LINE__, ::testing::Message() << (message)) + +// Compile-time assertion for type equality. +// StaticAssertTypeEq() compiles iff type1 and type2 are +// the same type. The value it returns is not interesting. +// +// Instead of making StaticAssertTypeEq a class template, we make it a +// function template that invokes a helper class template. This +// prevents a user from misusing StaticAssertTypeEq by +// defining objects of that type. +// +// CAVEAT: +// +// When used inside a method of a class template, +// StaticAssertTypeEq() is effective ONLY IF the method is +// instantiated. For example, given: +// +// template class Foo { +// public: +// void Bar() { testing::StaticAssertTypeEq(); } +// }; +// +// the code: +// +// void Test1() { Foo foo; } +// +// will NOT generate a compiler error, as Foo::Bar() is never +// actually instantiated. Instead, you need: +// +// void Test2() { Foo foo; foo.Bar(); } +// +// to cause a compiler error. +template +bool StaticAssertTypeEq() { + (void)internal::StaticAssertTypeEqHelper(); + return true; +} + +// Defines a test. +// +// The first parameter is the name of the test case, and the second +// parameter is the name of the test within the test case. +// +// The convention is to end the test case name with "Test". For +// example, a test case for the Foo class can be named FooTest. +// +// The user should put his test code between braces after using this +// macro. Example: +// +// TEST(FooTest, InitializesCorrectly) { +// Foo foo; +// EXPECT_TRUE(foo.StatusIsOK()); +// } + +// Note that we call GetTestTypeId() instead of GetTypeId< +// ::testing::Test>() here to get the type ID of testing::Test. This +// is to work around a suspected linker bug when using Google Test as +// a framework on Mac OS X. 
The bug causes GetTypeId< +// ::testing::Test>() to return different values depending on whether +// the call is from the Google Test framework itself or from user test +// code. GetTestTypeId() is guaranteed to always return the same +// value, as it always calls GetTypeId<>() from the Google Test +// framework. +#define GTEST_TEST(test_case_name, test_name)\ + GTEST_TEST_(test_case_name, test_name, \ + ::testing::Test, ::testing::internal::GetTestTypeId()) + +// Define this macro to 1 to omit the definition of TEST(), which +// is a generic name and clashes with some other libraries. +#if !GTEST_DONT_DEFINE_TEST +# define TEST(test_case_name, test_name) GTEST_TEST(test_case_name, test_name) +#endif + +// Defines a test that uses a test fixture. +// +// The first parameter is the name of the test fixture class, which +// also doubles as the test case name. The second parameter is the +// name of the test within the test case. +// +// A test fixture class must be declared earlier. The user should put +// his test code between braces after using this macro. Example: +// +// class FooTest : public testing::Test { +// protected: +// virtual void SetUp() { b_.AddElement(3); } +// +// Foo a_; +// Foo b_; +// }; +// +// TEST_F(FooTest, InitializesCorrectly) { +// EXPECT_TRUE(a_.StatusIsOK()); +// } +// +// TEST_F(FooTest, ReturnsElementCountCorrectly) { +// EXPECT_EQ(0, a_.size()); +// EXPECT_EQ(1, b_.size()); +// } + +#define TEST_F(test_fixture, test_name)\ + GTEST_TEST_(test_fixture, test_name, test_fixture, \ + ::testing::internal::GetTypeId()) + +// Use this macro in main() to run all tests. It returns 0 if all +// tests are successful, or 1 otherwise. +// +// RUN_ALL_TESTS() should be invoked after the command line has been +// parsed by InitGoogleTest(). + +#define RUN_ALL_TESTS()\ + (::testing::UnitTest::GetInstance()->Run()) + +} // namespace testing + +#endif // GTEST_INCLUDE_GTEST_GTEST_H_ diff --git a/gtest/gtest_main.cc b/gtest/gtest_main.cc new file mode 100644 index 00000000000..0f829af8dfc --- /dev/null +++ b/gtest/gtest_main.cc @@ -0,0 +1,44 @@ +// Copyright 2006, Google Inc. +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include + +#include "glog/logging.h" +#include "gflags/gflags.h" +#include "gtest/gtest.h" + +DEFINE_string(caffe_test_root, "gen/", "The root of the caffe test folder."); + +GTEST_API_ int main(int argc, char **argv) { + // std::cout << "Running main() from gtest_main.cc\n"; + testing::InitGoogleTest(&argc, argv); + google::ParseCommandLineFlags(&argc, &argv, true); + google::InitGoogleLogging(argv[0]); + return RUN_ALL_TESTS(); +} diff --git a/pycaffe2/BREW b/pycaffe2/BREW new file mode 100644 index 00000000000..2a3819b0874 --- /dev/null +++ b/pycaffe2/BREW @@ -0,0 +1,59 @@ +cc_library( + name = "caffe2_python", + srcs = ["caffe2_python.cc"], + deps = [ + "//caffe2/core:core", + "//caffe2/db:db", + "//caffe2/operators:core_ops", + "//caffe2/operators:core_ops_cudnn", + "//caffe2/operators:core_ops_gpu", + "//caffe2/image:image_ops", + "//caffe2/image:image_ops_gpu", + "//caffe2/sgd:sgd_ops", + "//caffe2/sgd:sgd_ops_gpu", + ], + external_libs = Env.PYTHON_LIBS, + shared = True, +) + +py_library( + name = "pycaffe2", + srcs = [ + "__init__.py", + "caffe_translator.py", + "core.py", + "core_gradients.py", + "device_checker.py", + "gradient_checker.py", + "net_drawer.py", + "utils.py", + "visualize.py", + "workspace.py", + ], + deps = [ + ":caffe2_python", + "//caffe/proto:caffe_proto_py", + "//caffe2/proto:caffe2_proto_py", + "//pycaffe2/mint:mint", + ], +) + +py_test( + name = "workspace_test", + srcs = [ + "workspace_test.py", + ], + deps = [ + ":pycaffe2", + ], +) + +py_test( + name = "caffe_translator_test", + srcs = [ + "caffe_translator_test.py", + ], + deps = [ + ":pycaffe2", + ], +) diff --git a/pycaffe2/__init__.py b/pycaffe2/__init__.py new file mode 100644 index 00000000000..2ca08c5c340 --- /dev/null +++ b/pycaffe2/__init__.py @@ -0,0 +1,4 @@ +import atexit + +from . import core, core_gradients, utils, visualize, workspace +from caffe2.proto import caffe2_pb2 diff --git a/pycaffe2/caffe2_python.cc b/pycaffe2/caffe2_python.cc new file mode 100644 index 00000000000..35b7606f20e --- /dev/null +++ b/pycaffe2/caffe2_python.cc @@ -0,0 +1,453 @@ +#include +#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION +#include + +#include +#include +#include +#include + +#include "caffe2/core/context.h" +#include "caffe2/core/context_gpu.h" +#include "caffe2/core/net.h" +#include "caffe2/core/workspace.h" +#include "caffe2/proto/caffe2.pb.h" +#include "glog/logging.h" + +using std::map; +using std::string; +using std::unique_ptr; +using std::vector; +using namespace caffe2; // NOLINT + +// gWorkspaces allows us to define and switch between multiple workspaces in +// Python. +static map > gWorkspaces; +// gWorkspace is the pointer to the current workspace. The ownership is kept +// by the gWorkspaces map. 
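+//
+// SwitchWorkspaceInternal below is the single place that repoints gWorkspace
+// (and gCurrentWorkspaceName) at an entry of gWorkspaces, creating a fresh
+// Workspace on demand; every exported function in this file then implicitly
+// operates on whichever workspace is current.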
+static Workspace* gWorkspace = nullptr; +static string gCurrentWorkspaceName; + +namespace { + +bool SwitchWorkspaceInternal(const string& name, const bool create_if_missing) { + if (gWorkspaces.count(name)) { + gCurrentWorkspaceName = name; + gWorkspace = gWorkspaces[name].get(); + return true; + } else if (create_if_missing) { + std::unique_ptr new_workspace(new Workspace()); + gWorkspace = new_workspace.get(); + gWorkspaces.insert(std::make_pair(name, std::move(new_workspace))); + gCurrentWorkspaceName = name; + return true; + } else { + return false; + } +} + +inline string PyStringToStdString(PyObject* pystring) { + return string(PyString_AsString(pystring), PyString_Size(pystring)); +} + +inline PyObject* StdStringToPyString(const string& str) { + return PyString_FromStringAndSize(str.c_str(), str.size()); +} + +static_assert(sizeof(int) == sizeof(int32_t), + "Yangqing made a loose assumption that int will always be int32 " + "for numpy type mapping"); + +template struct NumpyTypeWrapper; +template<> struct NumpyTypeWrapper { + static const int type = NPY_FLOAT; +}; +template<> struct NumpyTypeWrapper { + static const int type = NPY_INT32; +}; + +template +PyObject* FetchTensor(const Blob& blob) { + DeviceContext context; + const Tensor& tensor = + blob.Get >(); + CHECK_GT(tensor.size(), 0); + // numpy requires long int as its dims. + vector npy_dims; // NOLINT + for (const int dim : tensor.dims()) { + npy_dims.push_back(dim); + } + PyObject* array = PyArray_SimpleNew( + tensor.ndim(), npy_dims.data(), NumpyTypeWrapper::type); + // Now, copy the data to the tensor. + // TODO(Yangqing): Is there an easier way to convert PyObject to + // PyArrayObject? + context.template Copy( + static_cast(PyArray_DATA(reinterpret_cast(array))), + tensor.data(), tensor.size()); + return array; +} + +template +PyObject* FeedTensor(const DeviceOption& option, PyArrayObject* original_array, + Blob* blob) { + PyArrayObject* array = PyArray_GETCONTIGUOUS(original_array); + DeviceContext context(option); + Tensor* tensor = + blob->GetMutable >(); + // numpy requires long int as its dims. + int ndim = PyArray_NDIM(array); + npy_intp* npy_dims = PyArray_DIMS(array); + vector dims; + for (int i = 0; i < ndim; ++i) { + dims.push_back(npy_dims[i]); + } + tensor->Reshape(dims); + // Now, copy the data to the tensor. 
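+  // The copy runs from the contiguous host-side numpy buffer into the freshly
+  // reshaped tensor through the device context, so feeding with a CUDAContext
+  // lands the data on the GPU.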
+ context.template Copy( + tensor->mutable_data(), + static_cast(PyArray_DATA(array)), + tensor->size()); + Py_XDECREF(array); + Py_RETURN_TRUE; +} + +} // namespace + +extern "C" { + +PyObject* SwitchWorkspace(PyObject* self, PyObject* args) { + PyObject* name = nullptr; + PyObject* create_if_missing = nullptr; + if (!PyArg_ParseTuple(args, "S|O", &name, &create_if_missing)) { + PyErr_SetString(PyExc_ValueError, + "SwitchWorkspace takes in a workspace name, and " + "an optional boolean value that specifies whether " + "we want to create the workspace if it is missing."); + return NULL; + } + bool success = SwitchWorkspaceInternal( + PyStringToStdString(name), + (create_if_missing != nullptr) && PyObject_IsTrue(create_if_missing)); + if (!success) { + PyErr_SetString( + PyExc_RuntimeError, + "Workspace of the given name does not exist, and I am not instructed " + "to create it either."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* CurrentWorkspace(PyObject* self, PyObject* args) { + return StdStringToPyString(gCurrentWorkspaceName); +} + +PyObject* Workspaces(PyObject* self, PyObject* args) { + PyObject* list = PyList_New(gWorkspaces.size()); + int i = 0; + for (auto const & it : gWorkspaces) { + CHECK_EQ(PyList_SetItem(list, i, StdStringToPyString(it.first)), 0); + i += 1; + } + return list; +} + +PyObject* ResetWorkspace(PyObject* self, PyObject* args) { + PyObject* root_folder = nullptr; + if (!PyArg_ParseTuple(args, "|S", &root_folder)) { + PyErr_SetString(PyExc_ValueError, + "ResetWorkspace takes in either no argument, or a string " + "specifying the root folder of the workspace."); + return NULL; + } + LOG(INFO) << "Resetting workspace."; + if (root_folder == nullptr) { + gWorkspaces[gCurrentWorkspaceName].reset( + new Workspace()); + } else { + gWorkspaces[gCurrentWorkspaceName].reset( + new Workspace(PyStringToStdString(root_folder))); + } + gWorkspace = gWorkspaces[gCurrentWorkspaceName].get(); + Py_RETURN_TRUE; +} + +PyObject* RootFolder(PyObject* self, PyObject* args) { + return StdStringToPyString(gWorkspace->RootFolder()); +} + +// This function should not be called by the user - only used during the +// destruction of the module. +PyObject* OnModuleExit(PyObject* self, PyObject* args) { + gWorkspaces.clear(); + Py_RETURN_TRUE; +} + +PyObject* Blobs(PyObject* self, PyObject* args) { + vector blob_strings = gWorkspace->Blobs(); + PyObject* list = PyList_New(blob_strings.size()); + for (int i = 0; i < blob_strings.size(); ++i) { + CHECK_EQ(PyList_SetItem(list, i, StdStringToPyString(blob_strings[i])), 0); + } + return list; +} + +PyObject* HasBlob(PyObject* self, PyObject* args) { + char* name; + if (!PyArg_ParseTuple(args, "s", &name)) { + return NULL; + } + if (gWorkspace->HasBlob(string(name))) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } +} + +PyObject* CreateNet(PyObject* self, PyObject* args) { + PyObject* proto_string; + if (!PyArg_ParseTuple(args, "S", &proto_string)) { + return NULL; + } + caffe2::NetDef proto; + if (!proto.ParseFromString(PyStringToStdString(proto_string))) { + PyErr_SetString(PyExc_ValueError, "Cannot parse input net string."); + return NULL; + } + if (!gWorkspace->CreateNet(proto)) { + PyErr_SetString( + PyExc_RuntimeError, + "Cannot create network. See console log for error messages."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* RunNet(PyObject* self, PyObject* args) { + char* name; + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_ValueError, + "Incorrect argument. 
Must pass in a single string."); + return NULL; + } + if (!gWorkspace->RunNet(string(name))) { + PyErr_SetString( + PyExc_RuntimeError, + "Cannot run network. See console log for error messages."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* DeleteNet(PyObject* self, PyObject* args) { + char* name; + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_ValueError, + "Incorrect argument. Must pass in a single string."); + return NULL; + } + gWorkspace->DeleteNet(string(name)); + Py_RETURN_TRUE; +} + +PyObject* Nets(PyObject* self, PyObject* args) { + vector net_strings = gWorkspace->Nets(); + PyObject* list = PyList_New(net_strings.size()); + for (int i = 0; i < net_strings.size(); ++i) { + CHECK_EQ(PyList_SetItem(list, i, StdStringToPyString(net_strings[i])), 0); + } + return list; +} + +PyObject* RunOperatorOnce(PyObject* self, PyObject* args) { + PyObject* proto_string; + if (!PyArg_ParseTuple(args, "S", &proto_string)) { + PyErr_SetString(PyExc_ValueError, + "Incorrect argument. Must pass in a single string."); + return NULL; + } + caffe2::OperatorDef proto; + if (!proto.ParseFromString(PyStringToStdString(proto_string))) { + PyErr_SetString(PyExc_ValueError, "Cannot parse input operator proto."); + return NULL; + } + if (!gWorkspace->RunOperatorOnce(proto)) { + PyErr_SetString( + PyExc_RuntimeError, + "Cannot run operator. See console log for error messages."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* RunNetOnce(PyObject* self, PyObject* args) { + PyObject* proto_string; + if (!PyArg_ParseTuple(args, "S", &proto_string)) { + PyErr_SetString(PyExc_ValueError, + "Incorrect argument. Must pass in a single string."); + return NULL; + } + caffe2::NetDef proto; + if (!proto.ParseFromString(PyStringToStdString(proto_string))) { + PyErr_SetString(PyExc_ValueError, "Cannot parse input net proto."); + return NULL; + } + if (!gWorkspace->RunNetOnce(proto)) { + PyErr_SetString( + PyExc_RuntimeError, + "Cannot run net. See console log for error messages."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* RunPlan(PyObject* self, PyObject* args) { + PyObject* proto_string; + if (!PyArg_ParseTuple(args, "S", &proto_string)) { + PyErr_SetString(PyExc_ValueError, + "Incorrect argument. Must pass in a single string."); + return NULL; + } + caffe2::PlanDef proto; + if (!proto.ParseFromString(PyStringToStdString(proto_string))) { + PyErr_SetString(PyExc_ValueError, "Cannot parse input plan proto."); + return NULL; + } + if (!gWorkspace->RunPlan(proto)) { + PyErr_SetString( + PyExc_RuntimeError, + "Cannot run plan. See console log for error messages."); + return NULL; + } + Py_RETURN_TRUE; +} + +PyObject* CreateBlob(PyObject* self, PyObject* args) { + char* name_char; + if (!PyArg_ParseTuple(args, "s", &name_char)) { + PyErr_SetString(PyExc_ValueError, "Incorrect arguments."); + return NULL; + } + string name(name_char); + Blob* blob = gWorkspace->CreateBlob(name); + Py_RETURN_TRUE; +} + +#define RETURN_TENSOR_IF_FORMAT(dtype, context) \ + if (blob.IsType >()) { \ + return FetchTensor(blob); \ + } + +PyObject* FetchBlob(PyObject* self, PyObject* args) { + char* name; + if (!PyArg_ParseTuple(args, "s", &name)) { + PyErr_SetString(PyExc_ValueError, "Incorrect arguments."); + return NULL; + } + if (!gWorkspace->HasBlob(string(name))) { + PyErr_SetString(PyExc_ValueError, "Requested blob does not exist."); + return NULL; + } + const caffe2::Blob& blob = *(gWorkspace->GetBlob(string(name))); + // We only support a subset of exporting capabilities. 
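+  // Concretely: float and int32 tensors held in either a CPUContext or a
+  // CUDAContext. Anything else falls through to the error handling below.
+  // CUDA tensors are copied back to host memory by FetchTensor before being
+  // handed to numpy.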
+ RETURN_TENSOR_IF_FORMAT(float, CPUContext) + RETURN_TENSOR_IF_FORMAT(int, CPUContext) + RETURN_TENSOR_IF_FORMAT(float, CUDAContext) + RETURN_TENSOR_IF_FORMAT(int, CUDAContext) + // If all branches failed, we should throw an error. + LOG(ERROR) << "Blob" << string(name) << " has unsupported data type: " + << blob.TypeName(); + PyErr_SetString(PyExc_TypeError, "Unsupported data type."); + return NULL; +} + + +PyObject* FeedBlob(PyObject* self, PyObject* args) { + char* name_char; + PyArrayObject* array = nullptr; + PyObject* device_option_string = nullptr; + if (!PyArg_ParseTuple(args, "sO!|O", &name_char, &PyArray_Type, &array, + &device_option_string)) { + PyErr_SetString(PyExc_ValueError, "Incorrect arguments."); + return NULL; + } + string name(name_char); + DeviceOption option; + if (device_option_string != nullptr) { + // If we have a device option passed in, read it. + if (!option.ParseFromString(PyStringToStdString(device_option_string))) { + PyErr_SetString(PyExc_ValueError, "Cannot parse device option string."); + return NULL; + } + } + Blob* blob = gWorkspace->CreateBlob(name); + int data_type = PyArray_TYPE(array); + + // Since there is really no polymorphism, we will have to do so... + switch (option.device_type()) { + case CPU: + switch (data_type) { + case NPY_INT: + return FeedTensor(option, array, blob); + case NPY_FLOAT: + return FeedTensor(option, array, blob); + default: + PyErr_SetString(PyExc_TypeError, "Unsupported numpy data type."); + return NULL; + } + case CUDA: + switch (data_type) { + case NPY_INT: + return FeedTensor(option, array, blob); + case NPY_FLOAT: + return FeedTensor(option, array, blob); + default: + PyErr_SetString(PyExc_TypeError, "Unsupported numpy data type."); + return NULL; + } + default: + PyErr_SetString(PyExc_TypeError, "Unknown device type."); + return NULL; + } +} + +// A simple macro to avoid writing repeated symbols. +#define _PYNAME(name) {#name, name, METH_VARARGS} + +static PyMethodDef gPycaffe2Methods[] = { + // TODO(Yangqing): write the methods string. + // Note(Yangqing): For any function that we are going to override in the + // python file, we prepend "cc_" here. + _PYNAME(SwitchWorkspace), + _PYNAME(CurrentWorkspace), + _PYNAME(Workspaces), + _PYNAME(ResetWorkspace), + _PYNAME(RootFolder), + _PYNAME(OnModuleExit), + _PYNAME(Blobs), + _PYNAME(HasBlob), + {"cc_CreateNet", CreateNet, METH_VARARGS}, + _PYNAME(RunNet), + _PYNAME(DeleteNet), + _PYNAME(Nets), + {"cc_RunOperatorOnce", RunOperatorOnce, METH_VARARGS}, + {"cc_RunNetOnce", RunNetOnce, METH_VARARGS}, + {"cc_RunPlan", RunPlan, METH_VARARGS}, + _PYNAME(CreateBlob), + _PYNAME(FetchBlob), + {"cc_FeedBlob", FeedBlob, METH_VARARGS}, + {NULL, NULL}, // end of python methods. +}; +#undef _PYNAME + +void initlibcaffe2_python(void) { + (void) Py_InitModule("libcaffe2_python", gPycaffe2Methods); + import_array(); // for numpy + // We will create a default workspace for us to run stuff. 
+ SwitchWorkspaceInternal("default", true); + gCurrentWorkspaceName = "default"; +} + +} // extern "C" + diff --git a/pycaffe2/caffe_translator.py b/pycaffe2/caffe_translator.py new file mode 100644 index 00000000000..4f956b014a4 --- /dev/null +++ b/pycaffe2/caffe_translator.py @@ -0,0 +1,184 @@ +from caffe2.proto import caffe2_pb2 +from caffe.proto import caffe_pb2 +from google.protobuf import text_format +import numpy as np +from pycaffe2 import utils + +MODE_TRAIN = 0 +MODE_TEST = 1 +__TRANSLATE_MODE__ = MODE_TRAIN + +def SetTranslateMode(mode): + global __TRANSLATE_MODE__ + __TRANSLATE_MODE__ = mode + +def IsTraining(): + return (__TRANSLATE_MODE__ == MODE_TRAIN) + +def IsTesting(): + return (__TRANSLATE_MODE__ == MODE_TEST) + + +class CacaRegistry(object): + registry_ = {} + + @classmethod + def Register(cls, op_name): + """A decorator for registering gradient mappings.""" + def Wrapper(func): + cls.registry_[op_name] = func + return func + return Wrapper + + @classmethod + def TranslateLayer(cls, layer, pretrained_blobs): + try: + caffe_ops, params = cls.registry_[layer.type](layer, pretrained_blobs) + except KeyError as err: + raise KeyError('No translator registered for layer: %s' % str(layer)) + if caffe_ops is None: + return [] + if type(caffe_ops) is not list: + caffe_ops = [caffe_ops] + return caffe_ops, params + + @classmethod + def TranslateModel(cls, caffe_net, pretrained_net): + net = caffe2_pb2.NetDef() + net.name = caffe_net.name + net_params = [] + if len(caffe_net.layer) == 0: + raise ValueError('I think something is wrong. This translation script ' + 'only accepts new style layers that are stored in the ' + 'layer field.') + for layer in caffe_net.layer: + print 'Translate layer', layer.name + # Get pretrained one + pretrained_layers = ( + [l for l in pretrained_net.layer if l.name == layer.name] + + [l for l in pretrained_net.layers if l.name == layer.name]) + if len(pretrained_layers) > 1: + raise ValueError('huh? more than one pretrained layer of one name?') + elif len(pretrained_layers) == 1: + pretrained_blobs = [utils.CaffeBlobToNumpyArray(blob) + for blob in pretrained_layers[0].blobs] + else: + # No pretrained layer for the given layer name. We'll just pass no + # parameter blobs. + # print 'No pretrained layer for layer', layer.name + pretrained_blobs = [] + operators, params = cls.TranslateLayer(layer, pretrained_blobs) + net.operators.extend(operators) + net_params.extend(params) + return net, net_params + + +def TranslateModel(caffe_net, pretrained_net): + return CacaRegistry.TranslateModel(caffe_net, pretrained_net) + + +def BaseTranslate(layer, caffe2_type): + caffe2_op = caffe2_pb2.OperatorDef() + caffe2_op.type = caffe2_type + caffe2_op.inputs.extend(layer.bottom) + caffe2_op.outputs.extend(layer.top) + return caffe2_op + + +def AddArgument(op, key, value): + """Makes an argument based on the value type.""" + op.args.extend([utils.MakeArgument(key, value)]) + + +################################################################################ +# Common translators for layers. 
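+# Every translator below is registered on the Caffe layer type it handles and
+# returns a (caffe2_ops, params) pair. A minimal sketch of adding one for a
+# hypothetical pass-through layer type "MyLayer" mapped to an operator named
+# "MyOp" (both names are illustrative, not part of this file):
+#
+#   @CacaRegistry.Register("MyLayer")
+#   def TranslateMyLayer(layer, pretrained_blobs):
+#     caffe_op = BaseTranslate(layer, "MyOp")
+#     AddArgument(caffe_op, "order", "NCHW")
+#     return caffe_op, []  # no parameter blobs to convert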
+################################################################################ + +@CacaRegistry.Register("Convolution") +def TranslateConv(layer, pretrained_blobs): + caffe_op = BaseTranslate(layer, "Conv") + output = caffe_op.outputs[0] + caffe_op.inputs.extend([output + '_w', output + '_b']) + param = layer.convolution_param + AddArgument(caffe_op, "stride", param.stride) + AddArgument(caffe_op, "kernel", param.kernel_size) + AddArgument(caffe_op, "pad", param.pad) + AddArgument(caffe_op, "order", "NCHW") + if param.group > 1: + # Now, if the model is grouped convolution, let's do a backward hack and make + # things working but in an efficient way by inserting zero parameters. + n, c, h, w = pretrained_blobs[0].shape + g = param.group + og = int(n / g) + if (og * g != n): + raise ValueError("This should not happen") + weight = np.zeros((n, c * g, h, w), dtype=np.float32) + for i in range(param.group): + weight[i * og : (i + 1) * og, i * c : (i+1) * c, :, :] = pretrained_blobs[0][i * og : (i + 1) * og] + else: + weight = pretrained_blobs[0] + weight = utils.NumpyArrayToCaffe2Tensor(weight, output + '_w') + bias = utils.NumpyArrayToCaffe2Tensor( + pretrained_blobs[1].flatten(), output + '_b') + # Todo: deal with parameters. + return caffe_op, [weight, bias] + +@CacaRegistry.Register("ReLU") +def TranslateRelu(layer, pretrained_blobs): + return BaseTranslate(layer, "Relu"), [] + +@CacaRegistry.Register("Pooling") +def TranslatePool(layer, pretrained_blobs): + param = layer.pooling_param + if param.pool == caffe_pb2.PoolingParameter.MAX: + caffe_op = BaseTranslate(layer, "MaxPool") + caffe_op.outputs.extend(['_' + caffe_op.outputs[0] + '_maxid']) + elif param.pool == caffe_pb2.PoolingParameter.AVE: + caffe_op = BaseTranslate(layer, "AveragePool") + AddArgument(caffe_op, "stride", int(param.stride)) + AddArgument(caffe_op, "kernel", int(param.kernel_size)) + AddArgument(caffe_op, "pad", int(param.pad)) + AddArgument(caffe_op, "order", "NCHW") + return caffe_op, [] + +@CacaRegistry.Register("LRN") +def TranslateLRN(layer, pretrained_blobs): + caffe_op = BaseTranslate(layer, "LRN") + caffe_op.outputs.extend(['_' + caffe_op.outputs[0] + '_scale']) + param = layer.lrn_param + if param.norm_region != caffe_pb2.LRNParameter.ACROSS_CHANNELS: + raise ValueError("Does not support norm region other than across channels.") + AddArgument(caffe_op, "size", int(param.local_size)) + AddArgument(caffe_op, "alpha", float(param.alpha)) + AddArgument(caffe_op, "beta", float(param.beta)) + AddArgument(caffe_op, "bias", float(param.k)) + AddArgument(caffe_op, "order", "NCHW") + return caffe_op, [] + +@CacaRegistry.Register("InnerProduct") +def TranslateInnerProduct(layer, pretrained_blobs): + caffe_op = BaseTranslate(layer, "FC") + output = caffe_op.outputs[0] + caffe_op.inputs.extend([output + '_w', output + '_b']) + weight = utils.NumpyArrayToCaffe2Tensor( + pretrained_blobs[0][0,0], output + '_w') + bias = utils.NumpyArrayToCaffe2Tensor( + pretrained_blobs[1].flatten(), output + '_b') + return caffe_op, [weight, bias] + +@CacaRegistry.Register("Dropout") +def TranslateDropout(layer, pretrained_blobs): + if IsTraining(): + caffe_op = BaseTranslate(layer, "Dropout") + caffe_op.outputs.extend(['_' + caffe_op.outputs[0] + '_mask']) + param = layer.dropout_param + AddArgument(caffe_op, "ratio", param.dropout_ratio) + return caffe_op, [] + else: + return BaseTranslate(layer, "Alias"), [] + + +@CacaRegistry.Register("Softmax") +def TranslateSoftmax(layer, pretrained_blobs): + caffe_op = BaseTranslate(layer, 
"Softmax") + return caffe_op, [] diff --git a/pycaffe2/caffe_translator_test.py b/pycaffe2/caffe_translator_test.py new file mode 100644 index 00000000000..e9c0b2b12e7 --- /dev/null +++ b/pycaffe2/caffe_translator_test.py @@ -0,0 +1,56 @@ +# This a large test that goes through the translation of the bvlc caffenet +# model, runs an example through the whole model, and verifies numerically +# that all the results look right. In default, it is disabled unless you +# explicitly want to run it. + +from caffe2.proto import caffe2_pb2 +from caffe.proto import caffe_pb2 +from google.protobuf import text_format +import numpy as np +import os +from pycaffe2 import caffe_translator, utils, workspace +import sys +import unittest + +class TestNumericalEquivalence(unittest.TestCase): + def testBlobs(self): + names = ["conv1", "pool1", "norm1", "conv2", "pool2", "norm2", "conv3", + "conv4", "conv5", "pool5", "fc6", "fc7", "fc8", "prob"] + for name in names: + print 'Verifying ', name + caffe2_result = workspace.FetchBlob(name) + reference = np.load( + 'data/testdata/caffe_translator/' + name + '_dump.npy') + self.assertEqual(caffe2_result.shape, reference.shape) + scale = np.max(caffe2_result) + np.testing.assert_almost_equal(caffe2_result / scale, reference / scale, + decimal=5) + +if __name__ == '__main__': + if len(sys.argv) == 1: + print ('If you do not explicitly ask to run this test, I will not run it. ' + 'Pass in any argument to have the test run for you.') + sys.exit(0) + if not os.path.exists('data/testdata/caffe_translator'): + print 'No testdata existing for the caffe translator test. Exiting.' + sys.exit(0) + # We will do all the computation stuff in the global space. + caffenet = caffe_pb2.NetParameter() + caffenet_pretrained = caffe_pb2.NetParameter() + text_format.Merge(open('data/testdata/caffe_translator/deploy.prototxt').read(), + caffenet) + caffenet_pretrained.ParseFromString( + open('data/testdata/caffe_translator/bvlc_reference_caffenet.caffemodel') + .read()) + caffe_translator.SetTranslateMode(caffe_translator.MODE_TEST) + net, pretrained_params = caffe_translator.TranslateModel( + caffenet, caffenet_pretrained) + + for param in pretrained_params: + workspace.FeedBlob(param.name, utils.Caffe2TensorToNumpyArray(param)) + # Let's also feed in the data from the Caffe test code. + data = np.load('data/testdata/caffe_translator/data_dump.npy').astype(np.float32) + workspace.FeedBlob('data', data) + # Actually running the test. + workspace.RunNetOnce(net.SerializeToString()) + unittest.main() \ No newline at end of file diff --git a/pycaffe2/core.py b/pycaffe2/core.py new file mode 100644 index 00000000000..92a95f9504c --- /dev/null +++ b/pycaffe2/core.py @@ -0,0 +1,292 @@ +from caffe2.proto import caffe2_pb2 +from collections import Counter, defaultdict +from pycaffe2 import utils + +def GetGradientName(name): + """The function that returns the gradient name for a blob.""" + return name + '_grad' + +class BlobReference(object): + """A wrapper around a blob in a net. + + BlobReference gives us a way to refer to the network that the blob is + generated from. Note that blobs are, essentially, just strings in the current + workspace. + """ + def __init__(self, name, net): + self._name = name + self._from_net = net + + def __str__(self): + return self._name + + def Net(self): + return self._from_net + + def Grad(self): + return GetGradientName(self._name) + + def __getattr__(self, op_type): + """A wrapper allowing one to initiate operators from a blob reference. 
+ + Example: for a blob reference b that comes from network n, doing + b.Relu(...) + is equivalent to doing + net.Relu([b], ...) + """ + def _CreateAndAddToNet(inputs=[], *args, **kwargs): + """Internal function that routes the operator generation to the network's + __getattr__ function. + """ + # add self to the input list. + inputs.insert(0, self) + return self._from_net.__getattr__(op_type)(inputs, *args, **kwargs) + return _CreateAndAddToNet + +def CreateOperator(operator_type): + """A function wrapper that allows one to create operators based on the + operator type. The type should be a string corresponding to an operator + registered with Caffe2. + """ + def ReallyCreate(inputs, outputs, name='', device_option=None, + args=None, **kwargs): + operator = caffe2_pb2.OperatorDef() + operator.type = operator_type + operator.name = name + if type(inputs) is str or type(inputs) is BlobReference: + inputs = [inputs] + elif type(inputs) is not list: + raise ValueError("Unknown input format: %s." % str(inputs)) + if type(outputs) is str or type(outputs) is BlobReference: + outputs = [outputs] + elif type(outputs) is not list: + raise ValueError("Unknown output format: %s of type %s." + % (str(outputs), type(outputs))) + operator.inputs.extend([str(i) for i in inputs]) + operator.outputs.extend([str(o) for o in outputs]) + if device_option: + operator.device_option.CopyFrom(device_option) + # random seed is defined in the device option, so we need to do special + # care. + if 'random_seed' in kwargs: + operator.device_option.random_seed = kwargs['random_seed'] + del kwargs['random_seed'] + # Add given arguments that do not need parsing + if args: + operator.args.extend(args) + # Add all other arguments + for key, value in kwargs.iteritems(): + operator.args.add().CopyFrom(utils.MakeArgument(key, value)) + return operator + return ReallyCreate + + +class GradientRegistry(object): + """GradientRegistry holds the mapping from operators to their gradients.""" + gradient_registry_ = {} + + @classmethod + def RegisterGradient(cls, op_type): + """A decorator for registering gradient mappings.""" + def Wrapper(func): + cls.gradient_registry_[op_type] = func + return func + return Wrapper + + @classmethod + def GetGradient(cls, op): + try: + gradient_ops = cls.gradient_registry_[op.type](op) + except KeyError as err: + raise KeyError('No gradient registered for op: %s' % op.type) + if gradient_ops is None: + return [] + if type(gradient_ops) is not list: + gradient_ops = [gradient_ops] + if op.HasField("device_option"): + for gradient_op in gradient_ops: + gradient_op.device_option.CopyFrom(op.device_option) + return gradient_ops + + +class Net(object): + operator_registry_ = {} + + def __init__(self, name): + if type(name) is caffe2_pb2.NetDef: + # We rae initializing a network by a NetDef. In this case, we will + # initialize our network with the given netdef. + self._net = caffe2_pb2.NetDef() + self._net.CopyFrom(name) + # Set the next name index properly. 
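+      # Auto-generated blobs are named '<net name>_blob_<index>' by NextName()
+      # below; scan the existing NetDef for such names so that newly generated
+      # names resume after the largest index already in use instead of
+      # colliding with it.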
+ existing_names = set( + sum([list(op.inputs) for op in self._net.operators], []) + + sum([list(op.outputs) for op in self._net.operators], [])) + prefix_len = len(self._net.name + '_blob_') + autogen_indices = [int(name[prefix_len:]) for name in existing_names + if name.startswith(self._net.name + '_blob_')] + if len(autogen_indices): + self._next_name_index = max(autogen_indices) + 1 + else: + self._next_name_index = 0 + else: + self._net = caffe2_pb2.NetDef() + self._net.name = name + self._next_name_index = 0 + + def __str__(self): + return self._net.name + + def Proto(self): + return self._net + + def NextName(self): + """Returns the next name to be used, if you do not want to explicitly + name your blob.""" + output_name = self._net.name + '_blob_' + str(self._next_name_index) + self._next_name_index += 1 + return str(output_name) + + def AddGradientOperators(self, skip=0): + """Add the gradient for operators in the net. + + Inputs: + skip: skips the first n operators. This is provided mainly because a lot + of nets may use the first few operators for data generation like stuff + which really do not need to have gradients. + + Currently, this is hard-coded for float operators if there are branches + (i.e. a blob is used as input to multiple operators). This is because the + inserted SplitOp is hard-coded for float (its gradient, SumOp, is float + only). Supporting other formats is a todo item. + """ + # (1) Make sure that the network is "legal" in terms of computing gradients: + # for every blob there is only going to be one operator that generates it. + all_outputs = sum([list(op.outputs) for op in self._net.operators], []) + if len(all_outputs) != len(set(all_outputs)): + # There is some output that is produced by multiple operators. This is not + # good. + raise RuntimeError("Some blobs are produced multiple times. A count is " + "as follows: " + str(Counter(all_outputs))) + # (2) For cases when a blob is being used by multiple operators, we will + # need to take special care. Currently, we will ask the operators to compute + # the gradients, and add aggregation operators to get the final gradient. + input_counts = Counter( + sum([list(op.inputs) for op in self._net.operators], [])) + multiple_use_blobs = set( + [key for key in input_counts if input_counts[key] > 1]) + if len(multiple_use_blobs): + # There are some blobs that are used multiple times; As a result, we will + # manually insert split operators and make sure that they are correctly + # dealt with. + new_ops = [] + current_input_id = defaultdict(int) + for op in self._net.operators: + # For the input, if it is one of the mutiple use blobs, change it to + # an autosplit version. + for i, name in enumerate(op.inputs): + if name in multiple_use_blobs: + op.inputs[i] = '_' + name + '_autosplit_%d' % current_input_id[name] + current_input_id[name] += 1 + new_ops.append(op) + # For the output, if it is one of the multiple use blobs, we add a split + # operator after it is created. + for name in op.outputs: + if name in multiple_use_blobs: + new_ops.append(CreateOperator("Split")( + [name], + ['_' + name + '_autosplit_%d' % i + for i in range(input_counts[name])])) + # After we create all the new ops, we write them back to the operators + # that the network currently holds. We have to do this instead of + # inserting things midway because protobuf python only supports appending + # to the end. 
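+      # The net effect: a blob consumed k times is renamed at its consumers to
+      # '_<blob>_autosplit_0' ... '_<blob>_autosplit_<k-1>', and a Split op
+      # producing those copies is appended right after the op that created the
+      # blob, so that the Split gradient (a Sum, see core_gradients) can later
+      # accumulate the k partial gradients.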
+ del self._net.operators[:] + self._net.operators.extend(new_ops) + # (3) Now that the cleaning has been done, we can simply look into the + # gradient registry and add gradient operators. + for i in xrange(len(self._net.operators) - 1, skip - 1, -1): + gradient_ops = GradientRegistry.GetGradient(self._net.operators[i]) + self._net.operators.extend(gradient_ops) + + def RunAllOnGPU(self, gpu_id=0): + """A convenient function to run everything on the GPU.""" + device_option = caffe2_pb2.DeviceOption() + device_option.device_type = caffe2_pb2.CUDA + device_option.cuda_gpu_id = gpu_id + for op in self._net.operators: + op.device_option.CopyFrom(device_option) + + def __getattr__(self, operator_type): + if operator_type in self.__class__.operator_registry_: + # Not finished. Operator registry allows one to define custon functions, + # but so far that functionality is not complete. + return self.__class__.operator_registry_ + def _CreateAndAddToSelf(inputs, outputs=None, **kwargs): + if outputs is None: + # If we do not specify an output, we will assume that this operator + # produces one output in this case. + outputs = self.NextName() + elif type(outputs) is int: + # In this case, we will auto-fill the given number of outputs with + # auto-generated names. + outputs = [self.NextName() for i in range(outputs)] + op = CreateOperator(operator_type)(inputs, outputs, **kwargs) + self._net.operators.extend([op]) + if len(op.outputs) == 0: + return + elif len(op.outputs) == 1: + return BlobReference(str(op.outputs[0]), self) + else: + return tuple(BlobReference(str(o), self) for o in op.outputs) + return _CreateAndAddToSelf + + +class ExecutionStep(object): + def __init__(self, name): + self._step = caffe2_pb2.ExecutionStep() + self._step.name = name + + def __init__(self, name, nets, iterations=None): + self._step = caffe2_pb2.ExecutionStep() + self._step.name = name + if type(nets) is Net: + nets = [nets] + self._step.networks.extend([str(n) for n in nets]) + if iterations: + self._step.iterations = iterations + + def __str__(self): + return self._step.name + + def Proto(self): + return self._step + + def SetIter(self, iterations): + self._step.iterations = iterations + + def AddSubstep(self, substep): + self._step.substeps.add().CopyFrom(substep) + + def AddNet(self, net): + self._step.networks.add(str(net)) + + +class Plan(object): + def __init__(self, name): + self._plan = caffe2_pb2.PlanDef() + self._plan.name = name + + def __str__(self): + return self._plan.name + + def Proto(self): + return self._plan + + def AddNets(self, nets): + for net in nets: + self._plan.networks.add().CopyFrom(net.Proto()) + + def AddStep(self, step): + self._plan.execution_steps.add().CopyFrom(step.Proto()) + diff --git a/pycaffe2/core_gradients.py b/pycaffe2/core_gradients.py new file mode 100644 index 00000000000..783a66df873 --- /dev/null +++ b/pycaffe2/core_gradients.py @@ -0,0 +1,117 @@ +from caffe2.proto import caffe2_pb2 +from pycaffe2.core import * # I know, I know... 
will fix later + +@GradientRegistry.RegisterGradient('FC') +def AddFCGradient(op): + return CreateOperator('FCGradient')( + list(op.inputs) + [GetGradientName(op.outputs[0])], + [GetGradientName(name) for name in + [op.inputs[1], op.inputs[2], op.inputs[0]]]) + +@GradientRegistry.RegisterGradient('SquaredL2Distance') +def AddSquaredL2DistanceGradient(op): + return CreateOperator('SquaredL2DistanceGradient')( + list(op.inputs) + [GetGradientName(op.outputs[0])], + [GetGradientName(name) for name in op.inputs]) + +@GradientRegistry.RegisterGradient("LabelCrossEntropy") +def AddLabelCrossEntropyGradient(op): + return CreateOperator('LabelCrossEntropyGradient')( + list(op.inputs) + [GetGradientName(op.outputs[0])], + [GetGradientName(op.inputs[0])]) + +@GradientRegistry.RegisterGradient("Softmax") +def AddSoftmaxGradient(op): + return CreateOperator('SoftmaxGradient')( + [op.outputs[0], GetGradientName(op.outputs[0])], + [GetGradientName(op.inputs[0])]) + +@GradientRegistry.RegisterGradient("Flatten") +def AddFlattenGradient(op): + return CreateOperator('ReshapeLike')( + [GetGradientName(op.outputs[0]), op.inputs[0]], + [GetGradientName(op.inputs[0])]) + +@GradientRegistry.RegisterGradient("AveragedLoss") +def CheckAveragedLossNaming(op): + if op.outputs[1] != GetGradientName(op.inputs[0]): + raise ValueError( + "AveragedLoss output[1] should be named as the gradient of input[0]. " + "Please name your output[1] to %s.", GetGradientName(op.inputs[0])) + return + + +@GradientRegistry.RegisterGradient("TensorProtosDBInput") +@GradientRegistry.RegisterGradient("GaussianFill") +def NoGradientToCompute(op): + return + +@GradientRegistry.RegisterGradient("Accuracy") +@GradientRegistry.RegisterGradient("Print") +def UtilityOperatorsShouldNotBeAddedBeforeGradients(op): + raise RuntimeError("Utility operators should be added after you add " + "gradient operators to a net.") + + +@GradientRegistry.RegisterGradient("Relu") +def AddReluGradient(op): + return CreateOperator("ReluGradient")( + [op.inputs[0], GetGradientName(op.outputs[0])], + [GetGradientName(op.inputs[0])]) + +@GradientRegistry.RegisterGradient("MaxPool") +def AddMaxPoolGradient(op): + return CreateOperator("MaxPoolGradient")( + [op.inputs[0], GetGradientName(op.outputs[0]), op.outputs[1]], + [GetGradientName(op.inputs[0])], args=op.args) + + +@GradientRegistry.RegisterGradient("AveragePool") +def AddAveragePoolGradient(op): + return CreateOperator("AveragePoolGradient")( + [op.inputs[0], GetGradientName(op.outputs[0])], + [GetGradientName(op.inputs[0])], args=op.args) + +@GradientRegistry.RegisterGradient('Conv') +def AddFCGradient(op): + return CreateOperator('ConvGradient')( + list(op.inputs) + [GetGradientName(op.outputs[0])], + [GetGradientName(name) for name in + [op.inputs[1], op.inputs[2], op.inputs[0]]], + args=op.args) + + +@GradientRegistry.RegisterGradient('DepthSplit') +def AddDepthSplitGradient(op): + return CreateOperator('DepthConcat')( + [GetGradientName(name) for name in op.outputs], + [GetGradientName(op.inputs[0]), '_' + GetGradientName(op.inputs[0]) + '_dims'], + args = op.args) + +@GradientRegistry.RegisterGradient('DepthConcat') +def AddDepthConcatGradient(op): + return CreateOperator('DepthSplit')( + [GetGradientName(op.outputs[0]), op.outputs[1]], + [GetGradientName(name) for name in op.inputs], + args = op.args) + +@GradientRegistry.RegisterGradient('Dropout') +def AddDropoutGradient(op): + return CreateOperator('DropoutGrad')( + [GetGradientName(op.outputs[0]), op.outputs[1]], + [GetGradientName(op.inputs[0])], + 
args = op.args) + +@GradientRegistry.RegisterGradient('LRN') +def AddLRNGradient(op): + return CreateOperator('LRNGradient')( + [op.inputs[0], op.outputs[0], op.outputs[1], + GetGradientName(op.outputs[0])], + [GetGradientName(op.inputs[0])], + args = op.args) + +@GradientRegistry.RegisterGradient('Split') +def AddSplitGradient(op): + return CreateOperator('Sum')( + [GetGradientName(name) for name in op.outputs], + [GetGradientName(op.inputs[0])]) \ No newline at end of file diff --git a/pycaffe2/device_checker.py b/pycaffe2/device_checker.py new file mode 100644 index 00000000000..cd87f21326b --- /dev/null +++ b/pycaffe2/device_checker.py @@ -0,0 +1,91 @@ +import numpy as np +from pycaffe2 import core, workspace + +class DeviceChecker(object): + """A gradient checker in Python. + + This is not the most efficient way to check gradients, as the Python interface + will involve a lot of copy back and forth operations. Use at your own risk. + """ + def __init__(self, threshold, device_options): + self._threshold = threshold + self._device_options = device_options + + def CheckSimple(self, op, inputs, outputs_to_check): + """Checks the operator in a very simple fashion by stacking a sum of squares + on the top. + + Inputs: + op: the operator to be checked. + inputs: the input data in numpy arrays. + input_to_check: an index specifying which input blob we should + check. + outputs_with_grads: indices specifying which output blobs will we + need to check gradients with. For these outputs, we will collect a + squared sum and also feed in their gradients. + grad_operator: the gradient operator. If not given, we will get the + gradient operator from the gradient registry. + Outputs: + boolean: True if it passes, False if it does not pass. + """ + # Entering the checker workspace + old_ws_name = workspace.CurrentWorkspace() + results = [] + workspace.SwitchWorkspace("_device_check_", True) + for i, device_option in enumerate(self._device_options): + for i, arr in enumerate(inputs): + workspace.FeedBlob(op.inputs[i], arr, device_option) + op.device_option.CopyFrom(device_option) + workspace.RunOperatorOnce(op) + results.append( + [workspace.FetchBlob(op.outputs[idx]) for idx in outputs_to_check]) + # Everything is done, reset the workspace. + workspace.ResetWorkspace() + # After running on all devices, check correctness + success = True + for i in range(1, len(self._device_options)): + for j in range(len(outputs_to_check)): + x = results[i][j] + y = results[0][j] + if np.any(np.abs(x - y) > self._threshold): + print 'Failure in checking device option', i, 'and output ', + print op.outputs[j], '. The outputs are:' + print x.flatten() + print y.flatten() + success = False + continue + workspace.SwitchWorkspace(old_ws_name) + return success + + def CheckNet(self, net, inputs={}, ignore=set()): + """Checks a network by inspecting all of its intermediate results, and see + if things match. 
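+
+    Example (a minimal sketch; the net, its input array and the threshold are
+    illustrative, and a CUDA build of the operators involved is assumed):
+      from caffe2.proto import caffe2_pb2
+      cpu_option = caffe2_pb2.DeviceOption()
+      gpu_option = caffe2_pb2.DeviceOption()
+      gpu_option.device_type = caffe2_pb2.CUDA
+      checker = DeviceChecker(1e-4, [cpu_option, gpu_option])
+      checker.CheckNet(net.Proto(), inputs={'data': data})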
+ """ + old_ws_name = workspace.CurrentWorkspace() + results = [] + blobs_to_check = sum([list(op.outputs) for op in net.operators], []) + blobs_to_check = [b for b in blobs_to_check if b not in ignore] + workspace.SwitchWorkspace("_device_check_", True) + for i, device_option in enumerate(self._device_options): + for name, arr in inputs.iteritems(): + workspace.FeedBlob(name, arr, device_option) + for op in net.operators: + op.device_option.CopyFrom(device_option) + workspace.RunNetOnce(net) + results.append( + [workspace.FetchBlob(name) for name in blobs_to_check]) + # After running on all devices, check correctness + success = True + for i in range(1, len(results)): + for j in range(len(blobs_to_check)): + x = results[i][j] + y = results[0][j] + if np.any(np.abs(x - y) > self._threshold): + print 'Failure in checking device option', i, 'and blob ', + print blobs_to_check[j], '. The outputs are:' + print x.flatten() + print y.flatten() + success = False + continue + workspace.SwitchWorkspace(old_ws_name) + return success diff --git a/pycaffe2/gradient_checker.py b/pycaffe2/gradient_checker.py new file mode 100644 index 00000000000..c8c8347fbd4 --- /dev/null +++ b/pycaffe2/gradient_checker.py @@ -0,0 +1,106 @@ +import numpy as np +from pycaffe2 import core, workspace +from caffe2.proto import caffe2_pb2 + +class GradientChecker: + """A gradient checker in Python. + + This is not the most efficient way to check gradients, as the Python interface + will involve a lot of copy back and forth operations. Use at your own risk. + """ + def __init__(self, stepsize, threshold, + device_option=caffe2_pb2.DeviceOption(), + workspace_name="gradient_check"): + self._stepsize = stepsize + self._threshold = threshold + self._device_option = device_option + self._workspace_name = workspace_name + + def GetLossAndGrad(self, op, grad_ops, x, input_name, outputs_with_grads): + # First, feed in the current input. Note that we are not changing anything + # else, so we don't need to feed in others. + workspace.FeedBlob(input_name, x, self._device_option) + # Run. + workspace.RunOperatorOnce(op) + loss = 0. + # Get Loss and feed in the gradients, run gradient ops. + for idx in outputs_with_grads: + name = op.outputs[idx] + arr = workspace.FetchBlob(name) + loss += (arr ** 2).sum() + workspace.FeedBlob(core.GetGradientName(name), arr, self._device_option) + loss /= 2. + # Run gradient ops + workspace.RunOperatorsOnce(grad_ops) + # Get gradients + grad = workspace.FetchBlob(core.GetGradientName(input_name)) + return loss, grad + + + def CheckSimple(self, op, inputs, input_to_check, + outputs_with_grads, grad_ops=None): + """Checks the operator in a very simple fashion by stacking a sum of squares + on the top. + + Inputs: + op: the operator to be checked. + inputs: the input data in numpy arrays. + input_to_check: an index specifying which input blob we should + check. + outputs_with_grads: indices specifying which output blobs will we + need to check gradients with. For these outputs, we will collect a + squared sum and also feed in their gradients. + grad_operator: the gradient operator. If not given, we will get the + gradient operator from the gradient registry. + Outputs: + boolean: True if it passes, False if it does not pass. 
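+
+    Example (a minimal sketch; the operator, shapes and tolerances are
+    illustrative, and the SquaredL2Distance op plus its registered gradient
+    are assumed to be built in):
+      op = core.CreateOperator('SquaredL2Distance')(['X', 'Y'], ['dist'])
+      X = np.random.rand(4, 10).astype(np.float32)
+      Y = np.random.rand(4, 10).astype(np.float32)
+      checker = GradientChecker(stepsize=1e-2, threshold=1e-3)
+      ok, grad, grad_estimate = checker.CheckSimple(op, [X, Y], 0, [0])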
+ """ + # Entering the checker workspace + old_ws_name = workspace.CurrentWorkspace() + if self._workspace_name != old_ws_name: + workspace.SwitchWorkspace(self._workspace_name, True) + + op.device_option.CopyFrom(self._device_option) + if grad_ops is None: + grad_ops = core.GradientRegistry.GetGradient(op) + + dims_to_check = inputs[input_to_check].size + # First, feed in the input. + for i, arr in enumerate(inputs): + workspace.FeedBlob(op.inputs[i], arr, self._device_option) + + # Get the loss and gradient for the original. + input_name = op.inputs[input_to_check] + loss, grad = self.GetLossAndGrad(op, grad_ops, inputs[input_to_check], + input_name, outputs_with_grads) + grad_estimate = np.zeros_like(inputs[input_to_check]) + for current_dim in range(dims_to_check): + # Positive gradient + inputs[input_to_check].flat[current_dim] += self._stepsize + pos_loss, _ = self.GetLossAndGrad(op, grad_ops, inputs[input_to_check], + input_name, outputs_with_grads) + # Negative gradient + inputs[input_to_check].flat[current_dim] -= self._stepsize * 2 + neg_loss, _ = self.GetLossAndGrad(op, grad_ops, inputs[input_to_check], + input_name, outputs_with_grads) + # Recover the value + inputs[input_to_check].flat[current_dim] += self._stepsize + grad_estimate.flat[current_dim] = (pos_loss - neg_loss) / self._stepsize / 2 + # Now, check correctness + scale = np.maximum(np.maximum(np.abs(grad), np.abs(grad_estimate)), 1) + fail_mat = (np.abs(grad - grad_estimate) > scale * self._threshold) + if np.any(fail_mat): + idx = np.flatnonzero(fail_mat) + #print 'Failed. [idx, grad, grad_estimate] are:' + #print np.vstack([idx, grad.flat[idx], grad_estimate.flat[idx]]).T + ret = False + else: + ret = True + # After finishing, cleaning up things. + if self._workspace_name != old_ws_name: + # We reset the workspace to make sure everything intermediate is cleaned + # up. Note that there is no need to delete a workspace - when empty it + # takes a very limited amount of memory. + workspace.ResetWorkspace() + workspace.SwitchWorkspace(old_ws_name) + return ret, grad, grad_estimate \ No newline at end of file diff --git a/pycaffe2/mint/BREW b/pycaffe2/mint/BREW new file mode 100644 index 00000000000..24d72811a69 --- /dev/null +++ b/pycaffe2/mint/BREW @@ -0,0 +1,9 @@ +py_library( + name = "mint", + srcs = [ + "__init__.py", + "app.py", + "static/css/simple-sidebar.css", + "templates/index.html", + ], +) diff --git a/pycaffe2/mint/__init__.py b/pycaffe2/mint/__init__.py new file mode 100644 index 00000000000..e69de29bb2d diff --git a/pycaffe2/mint/app.py b/pycaffe2/mint/app.py new file mode 100644 index 00000000000..bf2353ee740 --- /dev/null +++ b/pycaffe2/mint/app.py @@ -0,0 +1,133 @@ +import argparse +import flask +import glob +import numpy as np +import nvd3 +import os +import sys +import tornado.httpserver +import tornado.wsgi + +__folder__ = os.path.abspath(os.path.dirname(__file__)) + +app = flask.Flask(__name__, + template_folder=os.path.join(__folder__, "templates"), + static_folder=os.path.join(__folder__, "static")) +args = None + +def jsonify_nvd3(chart): + chart.buildcontent() + # Note(Yangqing): python-nvd3 does not seem to separate the built HTML part + # and the script part. Luckily, it seems to be the case that the HTML part is + # only a
<div>, which can be accessed by chart.container; the script
+  # part occupies the rest of the html content, which we can then find by
+  # chart.htmlcontent.find('<script>').
+  script_start = chart.htmlcontent.find('<script>') + len('<script>')
+  script_end = chart.htmlcontent.find('</script>')
+  return flask.jsonify(
+      result=chart.container,
+      script=chart.htmlcontent[script_start:script_end].strip())
+
+def visualize_summary(filename):
+  try:
+    data = np.loadtxt(filename)
+  except IOError as e:
+    return 'Cannot load file {}: {}'.format(filename, str(e))
+  chart_name = os.path.splitext(os.path.basename(filename))[0]
+  chart = nvd3.lineChart(name=chart_name + '_summary_chart',
+                         height=args.chart_height,
+                         y_axis_format='.03g')
+  if args.sample < 0:
+    step = max(data.shape[0] / -args.sample, 1)
+  else:
+    step = args.sample
+  xdata = np.arange(0, data.shape[0], step)
+  # data has 4 columns: min, max, mean and std.
+  chart.add_serie(x=xdata, y=data[xdata, 0], name='min')
+  chart.add_serie(x=xdata, y=data[xdata, 1], name='max')
+  chart.add_serie(x=xdata, y=data[xdata, 2], name='mean')
+  chart.add_serie(x=xdata, y=data[xdata, 2] + data[xdata, 3], name='m+std')
+  chart.add_serie(x=xdata, y=data[xdata, 2] - data[xdata, 3], name='m-std')
+  return jsonify_nvd3(chart)
+
+def visualize_print_log(filename):
+  try:
+    data = np.loadtxt(filename)
+    if data.ndim == 1:
+      data = data[:, np.newaxis]
+  except IOError as e:
+    return 'Cannot load file {}: {}'.format(filename, str(e))
+  chart_name = os.path.splitext(os.path.basename(filename))[0]
+  chart = nvd3.lineChart(name=chart_name + '_log_chart',
+                         height=args.chart_height,
+                         y_axis_format='.03g')
+  if args.sample < 0:
+    step = max(data.shape[0] / -args.sample, 1)
+  else:
+    step = args.sample
+  xdata = np.arange(0, data.shape[0], step)
+  # if there is only one curve, we also show the running min and max
+  if data.shape[1] == 1:
+    # We also print the running min and max for the steps.
+    trunc_size = data.shape[0] / step
+    running_mat = data[:trunc_size * step].reshape((trunc_size, step))
+    chart.add_serie(x=xdata[:trunc_size], y=running_mat.min(axis=1),
+                    name='running_min')
+    chart.add_serie(x=xdata[:trunc_size], y=running_mat.max(axis=1),
+                    name='running_max')
+    chart.add_serie(x=xdata, y=data[xdata, 0], name=chart_name)
+  else:
+    for i in range(0, min(data.shape[1], args.max_curves)):
+      # Plot each remaining column as its own curve, capped at --max_curves.
+ chart.add_serie(x=xdata, y=data[xdata, i], + name='{}[{}]'.format(chart_name, i)) + + return jsonify_nvd3(chart) + +def visualize_file(filename): + fullname = os.path.join(args.root, filename) + if filename.endswith('summary'): + return visualize_summary(fullname) + elif filename.endswith('log'): + return visualize_print_log(fullname) + else: + return flask.jsonify(result='Unsupport file: {}'.format(filename), + script='') + +@app.route('/') +def index(): + files = glob.glob(os.path.join(args.root, "*.*")) + files.sort() + names = [os.path.basename(f) for f in files] + return flask.render_template( + 'index.html', root=args.root, names=names, debug_messages=names) + +@app.route('/visualization/') +def visualization(name): + ret = visualize_file(name) + print 'debug:', ret + return ret + +def main(argv): + parser = argparse.ArgumentParser("The mint visualizer.") + parser.add_argument('-p', '--port', type=int, default=5000, + help="The flask port to use.") + parser.add_argument('-r', '--root', type=str, default='.', + help="The root folder to read files for visualization.") + parser.add_argument('--max_curves', type=int, default=5, + help="The max number of curves to show in a dump tensor.") + parser.add_argument('--chart_height', type=int, default=300, + help="The chart height for nvd3.") + parser.add_argument('-s', '--sample', type=int, default=-200, + help="Sample every given number of data points. A negative " + "number means the total points we will sample on the " + "whole curve. Default 100 points.") + global args + args = parser.parse_args(argv) + server = tornado.httpserver.HTTPServer(tornado.wsgi.WSGIContainer(app)) + server.listen(args.port) + print "Tornado server starting on port {}.".format(args.port) + tornado.ioloop.IOLoop.instance().start() + +if __name__ == '__main__': + main(sys.argv[1:]) \ No newline at end of file diff --git a/pycaffe2/mint/static/css/simple-sidebar.css b/pycaffe2/mint/static/css/simple-sidebar.css new file mode 100644 index 00000000000..6bb18e9f930 --- /dev/null +++ b/pycaffe2/mint/static/css/simple-sidebar.css @@ -0,0 +1,125 @@ +/*! + * Start Bootstrap - Simple Sidebar HTML Template (http://startbootstrap.com) + * Code licensed under the Apache License v2.0. + * For details, see http://www.apache.org/licenses/LICENSE-2.0. 
+ */ + +/* Toggle Styles */ + +#wrapper { + padding-left: 0; + -webkit-transition: all 0.5s ease; + -moz-transition: all 0.5s ease; + -o-transition: all 0.5s ease; + transition: all 0.5s ease; +} + +#wrapper.toggled { + padding-left: 250px; +} + +#sidebar-wrapper { + z-index: 1000; + position: fixed; + left: 250px; + width: 0; + height: 100%; + margin-left: -250px; + overflow-y: auto; + background: rgb(193,237,201); + -webkit-transition: all 0.5s ease; + -moz-transition: all 0.5s ease; + -o-transition: all 0.5s ease; + transition: all 0.5s ease; +} + +#wrapper.toggled #sidebar-wrapper { + width: 250px; +} + +#page-content-wrapper { + width: 100%; + position: absolute; + padding: 15px; +} + +#wrapper.toggled #page-content-wrapper { + position: absolute; + margin-right: -250px; +} + +/* Sidebar Styles */ + +.sidebar-nav { + position: absolute; + top: 0; + width: 250px; + margin-bottom: 40px; + padding: 0; + list-style: none; +} + +.sidebar-nav li { + text-indent: 20px; + line-height: 30px; +} + +.sidebar-nav li a { + display: block; + text-decoration: none; + color: #999999; +} + +.sidebar-nav li a:hover { + text-decoration: none; + color: #000; + background: rgba(255,255,255,0.8); +} + +.sidebar-nav li a:active, +.sidebar-nav li a:focus { + text-decoration: none; +} + +.sidebar-nav > .sidebar-brand { + height: 65px; + font-size: 18px; + line-height: 60px; +} + +.sidebar-nav > .sidebar-brand a { + color: #999999; +} + +.sidebar-nav > .sidebar-brand a:hover { + color: #fff; + background: none; +} + +@media(min-width:768px) { + #wrapper { + padding-left: 250px; + } + + #wrapper.toggled { + padding-left: 0; + } + + #sidebar-wrapper { + width: 250px; + } + + #wrapper.toggled #sidebar-wrapper { + width: 0; + } + + #page-content-wrapper { + padding: 20px; + position: relative; + } + + #wrapper.toggled #page-content-wrapper { + position: relative; + margin-right: 0; + } +} \ No newline at end of file diff --git a/pycaffe2/mint/templates/index.html b/pycaffe2/mint/templates/index.html new file mode 100644 index 00000000000..506f9698f23 --- /dev/null +++ b/pycaffe2/mint/templates/index.html @@ -0,0 +1,134 @@ + + + + + + + Mint + + + + + + + + + + + + + +
+ <!-- [Template markup lost; recoverable content follows.] -->
+ <!-- Header: "Visualizing folder: {{ root }}." with "Toggle sidebar" and
+      "Refresh all" links. -->
+ <!-- Main content: {% for name in names %} a panel titled "{{ name }}" with a
+      "Top" link, a "Loading..." body and a "Last updated: NA" footer
+      {% endfor %} -->
+ <!-- Debug list: {% for message in debug_messages %} a bullet "{{ message }}"
+      {% endfor %} -->
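The summary files plotted by visualize_summary() above are plain text with four columns per row (min, max, mean, std). A minimal usage sketch, not part of this commit: it assumes the compiled libcaffe2_python extension is importable, and the folder name mint_demo and file name conv1_w.summary are made up for illustration.

# Hypothetical sketch: write a *.summary file in the four-column format that
# visualize_summary() expects, then serve the folder with mint.
import os
import numpy as np
from pycaffe2 import workspace

root = 'mint_demo'
if not os.path.exists(root):
  os.makedirs(root)
# One row per step; columns are min, max, mean, std of some monitored blob.
data = np.random.randn(1000, 16)
summary = np.vstack([data.min(axis=1), data.max(axis=1),
                     data.mean(axis=1), data.std(axis=1)]).T
np.savetxt(os.path.join(root, 'conv1_w.summary'), summary)
# StartMint (defined in pycaffe2/workspace.py below) launches the flask/tornado
# app in a separate process and prints the URL to browse.
process = workspace.StartMint(root_folder=root)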
+ + + + + + + + + + + \ No newline at end of file diff --git a/pycaffe2/net_drawer.py b/pycaffe2/net_drawer.py new file mode 100644 index 00000000000..2c1c767d286 --- /dev/null +++ b/pycaffe2/net_drawer.py @@ -0,0 +1,93 @@ +from collections import defaultdict +from pycaffe2 import utils +import sys +import subprocess + +try: + import pydot +except ImportError: + print ('Cannot import pydot, which is required for drawing a network. This ' + 'can usually be installed in python with "pip install pydot". Also, ' + 'pydot requires graphviz to convert dot files to pdf: in ubuntu, this ' + 'can usually be installed with "sudo apt-get install graphviz".') + print ('net_drawer will now exit. Please install the correct dependencies.') + sys.exit(1) + +from caffe2.proto import caffe2_pb2 +from google.protobuf import text_format + +OP_STYLE = {'shape': 'box', 'color': '#0F9D58', 'style': 'filled', + 'fontcolor': '#FFFFFF'} +BLOB_STYLE = {'shape': 'octagon'} + +def GetPydotGraph(operators, name): + graph = pydot.Dot(name, rankdir='LR') + pydot_nodes = {} + pydot_node_counts = defaultdict(int) + node_id = 0 + for op_id, op in enumerate(operators): + if op.name: + op_node = pydot.Node( + '%s/%s (op#%d)' % (op.name, op.type, op_id), **OP_STYLE) + else: + op_node = pydot.Node( + '%s (op#%d)' % (op.type, op_id), **OP_STYLE) + graph.add_node(op_node) + # print 'Op: %s' % op.name + # print 'inputs: %s' % str(op.inputs) + # print 'outputs: %s' % str(op.outputs) + for input_name in op.inputs: + if input_name not in pydot_nodes: + input_node = pydot.Node( + input_name + str(pydot_node_counts[input_name]), + label=input_name, **BLOB_STYLE) + pydot_nodes[input_name] = input_node + else: + input_node = pydot_nodes[input_name] + graph.add_node(input_node) + graph.add_edge(pydot.Edge(input_node, op_node)) + for output_name in op.outputs: + if output_name in pydot_nodes: + # we are overwriting an existing blob. need to updat the count. + pydot_node_counts[output_name] += 1 + output_node = pydot.Node( + output_name + str(pydot_node_counts[output_name]), + label=output_name, **BLOB_STYLE) + pydot_nodes[output_name] = output_node + graph.add_node(output_node) + graph.add_edge(pydot.Edge(op_node, output_node)) + return graph + +def GetOperatorMapForPlan(plan_def): + graphs = {} + for net_id, net in enumerate(plan_def.networks): + if net.HasField('name'): + graphs[plan_def.name + "_" + net.name] = net.operators + else: + graphs[plan_def.name + "_network_%d" % net_id] = net.operators + return graphs + +def main(): + with open(sys.argv[1], 'r') as fid: + content = fid.read() + graphs = utils.GetContentFromProtoString( + content,{ + caffe2_pb2.PlanDef: lambda x: GetOperatorMapForPlan(x), + caffe2_pb2.NetDef: lambda x: {x.name: x.operators}, + }) + for key, operators in graphs.iteritems(): + graph = GetPydotGraph(operators, key) + filename = graph.get_name() + '.dot' + graph.write(filename, format='raw') + pdf_filename = filename[:-3] + 'pdf' + with open(pdf_filename, 'w') as fid: + try: + subprocess.call(['dot', '-Tpdf', filename], stdout=fid) + except OSError: + print ('pydot requires graphviz to convert dot files to pdf: in ubuntu ' + 'this can usually be installed with "sudo apt-get install ' + 'graphviz". 
We have generated the .dot file but will not ' + 'generate pdf file for now due to missing graphviz binaries.') + +if __name__ == '__main__': + main() diff --git a/pycaffe2/utils.py b/pycaffe2/utils.py new file mode 100644 index 00000000000..7b01c065ac7 --- /dev/null +++ b/pycaffe2/utils.py @@ -0,0 +1,83 @@ +from caffe2.proto import caffe2_pb2 +from caffe.proto import caffe_pb2 +from google.protobuf.message import DecodeError, Message +from google.protobuf import text_format +import numpy as np + +def CaffeBlobToNumpyArray(blob): + return np.asarray(blob.data, dtype=np.float32).reshape( + blob.num, blob.channels, blob.height, blob.width) + +def Caffe2TensorToNumpyArray(tensor): + return np.asarray(tensor.float_data, dtype=np.float32).reshape(tensor.dims) + +def NumpyArrayToCaffe2Tensor(arr, name): + tensor = caffe2_pb2.TensorProto() + tensor.data_type = caffe2_pb2.TensorProto.FLOAT + tensor.name = name + tensor.dims.extend(arr.shape) + tensor.float_data.extend(list(arr.flatten().astype(float))) + return tensor + +def MakeArgument(key, value): + """Makes an argument based on the value type.""" + argument = caffe2_pb2.Argument() + argument.name = key + if type(value) is float: + argument.f = value + elif type(value) is int: + argument.i = value + elif type(value) is str: + argument.s = value + elif type(value) is Message: + argument.s = value.SerializeToString() + elif all(type(v) is float for v in value): + argument.floats.extend(value) + elif all(type(v) is int for v in value): + argument.ints.extend(value) + elif all(type(v) is str for v in value): + argument.strings.extend(value) + elif all(type(v) is Message for v in value): + argument.strings.extend([v.SerializeToString() for v in values]) + else: + raise ValueError("Unknown argument type: key=%s value=%s, value type=%s" % + (key, str(value), str(type(value)))) + return argument + +def TryReadProtoWithClass(cls, s): + """Reads a protobuffer with the given proto class. + + Inputs: + cls: a protobuffer class. + s: a string of either binary or text protobuffer content. + + Outputs: + proto: the protobuffer of cls + + Throws: + google.protobuf.message.DecodeError: if we cannot decode the message. + """ + obj = cls() + try: + text_format.Parse(s, obj) + return obj + except text_format.ParseError as e: + obj.ParseFromString(s) + return obj + +def GetContentFromProto(obj, function_map): + """Gets a specific field from a protocol buffer that matches the given class. + """ + for cls, func in function_map.iteritems(): + if type(obj) is cls: + return func(obj) + +def GetContentFromProtoString(s, function_map): + for cls, func in function_map.iteritems(): + try: + obj = TryReadProtoWithClass(cls, s) + return func(obj) + except DecodeError: + continue + else: + raise DecodeError("Cannot find a fit protobuffer class.") \ No newline at end of file diff --git a/pycaffe2/visualize.py b/pycaffe2/visualize.py new file mode 100644 index 00000000000..ec442924dcf --- /dev/null +++ b/pycaffe2/visualize.py @@ -0,0 +1,144 @@ +"""Functions that could be used to visualize Tensors. + +This is adapted from the old-time iceberk package that Yangqing wrote... Oh gold +memories. Before decaf and caffe. Why iceberk? Because I was at Berkeley, +bears are vegetarian, and iceberg lettuce has layers of leaves. + +(This joke is so lame.) 
+""" + +import numpy as np +from matplotlib import cm, pyplot + +def ChannelFirst(arr): + """Convert a HWC array to CHW.""" + ndim = arr.ndim + return arr.swapaxes(ndim-1, ndim-2).swapaxes(ndim-2, ndim-3) + +def ChannelLast(arr): + """Convert a CHW array to HWC.""" + ndim = arr.ndim + return arr.swapaxes(ndim-3, ndim-2).swapaxes(ndim-2, ndim-1) + +class PatchVisualizer(object): + """PatchVisualizer visualizes patches. + """ + def __init__(self, gap=1): + self.gap = gap + + def ShowSingle(self, patch, cmap=None): + """Visualizes one single patch. + + The input patch could be a vector (in which case we try to infer the shape + of the patch), a 2-D matrix, or a 3-D matrix whose 3rd dimension has 3 + channels. + """ + if len(patch.shape) == 1: + patch = patch.reshape(self.get_patch_shape(patch)) + elif len(patch.shape) > 2 and patch.shape[2] != 3: + raise ValueError("The input patch shape isn't correct.") + # determine color + if len(patch.shape) == 2 and cmap is None: + cmap = cm.gray + pyplot.imshow(patch, cmap=cmap) + return patch + + def ShowMultiple(self, patches, ncols=None, cmap=None, bg_func=np.mean): + """Visualize multiple patches. + + In the passed in patches matrix, each row is a patch, in the shape of either + n*n, n*n*1 or n*n*3, either in a flattened format (so patches would be a + 2-D array), or a multi-dimensional tensor. We will try our best to figure + out automatically the patch size. + """ + num_patches = patches.shape[0] + if ncols is None: + ncols = int(np.ceil(np.sqrt(num_patches))) + nrows = int(np.ceil(num_patches / float(ncols))) + if len(patches.shape) == 2: + patches = patches.reshape((patches.shape[0],) + + self.get_patch_shape(patches[0])) + patch_size_expand = np.array(patches.shape[1:3]) + self.gap + image_size = patch_size_expand * np.array([nrows, ncols]) - self.gap + if len(patches.shape) == 4: + if patches.shape[3] == 1: + # gray patches + patches = patches.reshape(patches.shape[:-1]) + image_shape = tuple(image_size) + if cmap is None: + cmap = cm.gray + elif patches.shape[3] == 3: + # color patches + image_shape = tuple(image_size) + (3,) + else: + raise ValueError, "The input patch shape isn't expected." + else: + image_shape = tuple(image_size) + if cmap is None: + cmap = cm.gray + image = np.ones(image_shape) * bg_func(patches) + for pid in range(num_patches): + row = pid / ncols * patch_size_expand[0] + col = pid % ncols * patch_size_expand[1] + image[row:row+patches.shape[1], col:col+patches.shape[2]] = \ + patches[pid] + pyplot.imshow(image, cmap=cmap, interpolation='nearest') + pyplot.axis('off') + return image + + def ShowImages(self, patches, *args, **kwargs): + """Similar to ShowMultiple, but always normalize the values between 0 and 1 + for better visualization of image-type data. + """ + patches = patches - np.min(patches) + patches /= np.max(patches) + np.finfo(np.float64).eps + return self.ShowMultiple(patches, *args, **kwargs) + + def ShowChannels(self, patch, cmap=None, bg_func=np.mean): + """ This function shows the channels of a patch. + + The incoming patch should have shape [w, h, num_channels], and each channel + will be visualized as a separate gray patch. + """ + if len(patch.shape) != 3: + raise ValueError, "The input patch shape isn't correct." + patch_reordered = np.swapaxes(patch.T, 1, 2) + return self.ShowMultiple(patch_reordered, cmap=cmap, bg_func=bg_func) + + def get_patch_shape(self, patch): + """Gets the shape of a single patch. 
+ + Basically it tries to interprete the patch as a square, and also check if it + is in color (3 channels) + """ + edgeLen = np.sqrt(patch.size) + if edgeLen != np.floor(edgeLen): + # we are given color patches + edgeLen = np.sqrt(patch.size / 3.) + if edgeLen != np.floor(edgeLen): + raise ValueError, "I can't figure out the patch shape." + return (edgeLen, edgeLen, 3) + else: + edgeLen = int(edgeLen) + return (edgeLen, edgeLen) + +_default_visualizer = PatchVisualizer() + +"""Utility functions that directly point to functions in the default visualizer. + +These functions don't return anything, so you won't see annoying printouts of +the visualized images. If you want to save the images for example, you should +explicitly instantiate a patch visualizer, and call those functions. +""" + +def ShowSingle(*args, **kwargs): + _default_visualizer.ShowSingle(*args, **kwargs) + +def ShowMultiple(*args, **kwargs): + _default_visualizer.ShowMultiple(*args, **kwargs) + +def ShowImages(*args, **kwargs): + _default_visualizer.ShowImages(*args, **kwargs) + +def ShowChannels(*args, **kwargs): + _default_visualizer.ShowChannels(*args, **kwargs) diff --git a/pycaffe2/workspace.py b/pycaffe2/workspace.py new file mode 100644 index 00000000000..7d17275a312 --- /dev/null +++ b/pycaffe2/workspace.py @@ -0,0 +1,96 @@ +import atexit +from multiprocessing import Process +import socket + +from .libcaffe2_python import * +# libcaffe2_python contains a global Workspace that we need to properly delete +# when exiting. Otherwise, cudart will cause segfaults sometimes. +atexit.register(OnModuleExit) + +try: + import pycaffe2.mint.app + _has_mint = True +except ImportError as err: + print 'Mint is not available, possibly due to some downstream dependencies.' + _has_mint = False + +def _GetFreeFlaskPort(): + """Get a free flask port.""" + # We will prefer to use 5000. If not, we will then pick a random port. + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + result = sock.connect_ex(('127.0.0.1',5000)) + if result == 0: + return 5000 + else: + s = socket.socket() + s.bind(('', 0)) + port = s.getsockname()[1] + s.close() + # Race condition: between the interval we close the socket and actually + # start a mint process, another process might have occupied the port. We + # don't do much here as this is mostly for convenience in research rather + # than 24x7 service. + return port + +def StartMint(root_folder=None, port=None): + """Start a mint instance.""" + if not _has_mint: + print 'Mint is not available. Not starting the server.' + return None + if root_folder is None: + root_folder = RootFolder() + if port is None: + port = _GetFreeFlaskPort() + process = Process(target=pycaffe2.mint.app.main, args=( + ['-p', str(port), '-r', root_folder],)) + process.start() + print 'Mint running at http://{}:{}'.format(socket.getfqdn(), port) + return process + +def StringfyProto(obj): + """Stringfy a protocol buffer object. + + Inputs: + obj: a protocol buffer object, or a Pycaffe2 object that has a Proto() + function. + Outputs: + string: the output protobuf string. + Raises: + AttributeError: if the passed in object does not have the right attribute. + """ + if type(obj) is str: + return obj + else: + try: + # First, see if this object is a protocol buffer, which we can simply + # serialize with the SerializeToString() call. + return obj.SerializeToString() + except AttributeError: + # Secind, see if this is an object defined in Pycaffe2, which exposes a + # Proto() function that gives you the protocol buffer. 
+ return obj.Proto().SerializeToString() + +def CreateNet(net): + return cc_CreateNet(StringfyProto(net)) + +def RunOperatorOnce(operator): + return cc_RunOperatorOnce(StringfyProto(operator)) + +def RunOperatorsOnce(operators): + for op in operators: + success = RunOperatorOnce(op) + if not success: + return False + return True + +def RunNetOnce(net): + return cc_RunNetOnce(StringfyProto(net)) + +def RunPlan(plan): + return cc_RunPlan(StringfyProto(plan)) + +def FeedBlob(name, arr, device_option=None): + if device_option is not None: + return cc_FeedBlob(name, arr, StringfyProto(device_option)) + else: + return cc_FeedBlob(name, arr) \ No newline at end of file diff --git a/pycaffe2/workspace_test.py b/pycaffe2/workspace_test.py new file mode 100644 index 00000000000..0a4f50ae775 --- /dev/null +++ b/pycaffe2/workspace_test.py @@ -0,0 +1,114 @@ +import numpy as np +import unittest + +from caffe2.proto import caffe2_pb2 +from pycaffe2 import core, workspace + +class TestWorkspace(unittest.TestCase): + def setUp(self): + self.net = core.Net("test-net") + self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0) + workspace.ResetWorkspace() + + def testRootFolder(self): + self.assertEqual(workspace.ResetWorkspace(), True) + self.assertEqual(workspace.RootFolder(), ".") + self.assertEqual(workspace.ResetWorkspace("/home/test"), True) + self.assertEqual(workspace.RootFolder(), "/home/test") + + def testWorkspaceHasBlobWithNonexistingName(self): + self.assertEqual(workspace.HasBlob("non-existing"), False) + + def testRunOperatorOnce(self): + self.assertEqual( + workspace.RunOperatorOnce( + self.net.Proto().operators[0].SerializeToString()), + True) + self.assertEqual(workspace.HasBlob("testblob"), True) + blobs = workspace.Blobs() + self.assertEqual(len(blobs), 1) + self.assertEqual(blobs[0], "testblob") + + def testRunNetOnce(self): + self.assertEqual(workspace.RunNetOnce(self.net.Proto().SerializeToString()), True) + self.assertEqual(workspace.HasBlob("testblob"), True) + + def testRunPlan(self): + plan = core.Plan("test-plan") + plan.AddNets([self.net]) + plan.AddStep(core.ExecutionStep("test-step", self.net)) + self.assertEqual(workspace.RunPlan(plan.Proto().SerializeToString()), True); + self.assertEqual(workspace.HasBlob("testblob"), True) + + def testResetWorkspace(self): + self.assertEqual(workspace.RunNetOnce(self.net.Proto().SerializeToString()), True) + self.assertEqual(workspace.HasBlob("testblob"), True) + self.assertEqual(workspace.ResetWorkspace(), True) + self.assertEqual(workspace.HasBlob("testblob"), False) + + def testFetchFeedBlob(self): + self.assertEqual(workspace.RunNetOnce(self.net.Proto().SerializeToString()), True) + fetched = workspace.FetchBlob("testblob") + # check if fetched is correct. + self.assertEqual(fetched.shape, (1, 2, 3, 4)) + np.testing.assert_array_equal(fetched, 1.0) + fetched[:] = 2.0 + self.assertEqual(workspace.FeedBlob("testblob", fetched), True) + fetched_again = workspace.FetchBlob("testblob") + self.assertEqual(fetched_again.shape, (1, 2, 3, 4)) + np.testing.assert_array_equal(fetched_again, 2.0) + +class TestWorkspaceGPU(unittest.TestCase): + def setUp(self): + self.net = core.Net("test-net") + self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0) + self.net.RunAllOnGPU() + + def testFetchBlobGPU(self): + self.assertEqual(workspace.RunNetOnce(self.net.Proto().SerializeToString()), True) + fetched = workspace.FetchBlob("testblob") + # check if fetched is correct. 
+ self.assertEqual(fetched.shape, (1, 2, 3, 4)) + np.testing.assert_array_equal(fetched, 1.0) + fetched[:] = 2.0 + self.assertEqual(workspace.FeedBlob("testblob", fetched), True) + fetched_again = workspace.FetchBlob("testblob") + self.assertEqual(fetched_again.shape, (1, 2, 3, 4)) + np.testing.assert_array_equal(fetched_again, 2.0) + + +class TestMultiWorkspaces(unittest.TestCase): + def setUp(self): + workspace.SwitchWorkspace("default") + workspace.ResetWorkspace() + + def testCreateWorkspace(self): + workspaces = workspace.Workspaces() + self.assertEqual(len(workspaces), 1) + self.assertEqual(workspaces[0], "default") + self.net = core.Net("test-net") + self.net.ConstantFill([], "testblob", shape=[1, 2, 3, 4], value=1.0) + self.assertEqual( + workspace.RunNetOnce(self.net.Proto().SerializeToString()), True) + self.assertEqual(workspace.HasBlob("testblob"), True) + self.assertEqual(workspace.SwitchWorkspace("test", True), True) + self.assertEqual(workspace.HasBlob("testblob"), False) + self.assertEqual(workspace.SwitchWorkspace("default"), True) + self.assertEqual(workspace.HasBlob("testblob"), True) + + try: + # The following should raise an error. + workspace.SwitchWorkspace("non-existing") + # so this should never happen. + self.assertEqual(True, False) + except RuntimeError: + pass + + workspaces = workspace.Workspaces() + self.assertEqual(len(workspaces), 2) + workspaces.sort() + self.assertEqual(workspaces[0], "default") + self.assertEqual(workspaces[1], "test") + +if __name__ == '__main__': + unittest.main() diff --git a/third_party/README b/third_party/README new file mode 100644 index 00000000000..ad7794bf4e1 --- /dev/null +++ b/third_party/README @@ -0,0 +1,2 @@ +Subfolders in the third_party folder are used to help install things more easily +and you should pre-install them on your machine. \ No newline at end of file diff --git a/third_party/cudnn/BREW b/third_party/cudnn/BREW new file mode 100644 index 00000000000..25a9d7647b0 --- /dev/null +++ b/third_party/cudnn/BREW @@ -0,0 +1,7 @@ +cc_thirdparty_target( + name = "cudnn", + srcs = ["BREW"], + commands=[], + cc_obj_files = ["-lcudnn"], +) + diff --git a/third_party/eigen3/BREW b/third_party/eigen3/BREW new file mode 100644 index 00000000000..c96f731fbed --- /dev/null +++ b/third_party/eigen3/BREW @@ -0,0 +1,7 @@ +cc_thirdparty_target( + name = "eigen", + srcs = ["BREW"], + commands=[], + # Eigen is a header-only library so there is no cc_obj_files. 
+ cc_obj_files = [], +) diff --git a/third_party/gflags/BREW b/third_party/gflags/BREW new file mode 100644 index 00000000000..cc9ab8323b6 --- /dev/null +++ b/third_party/gflags/BREW @@ -0,0 +1,6 @@ +cc_thirdparty_target( + name="gflags", + srcs=["BREW"], + commands=[], + cc_obj_files=["-lgflags"], +) diff --git a/third_party/glog/BREW b/third_party/glog/BREW new file mode 100644 index 00000000000..3180b5fdfc9 --- /dev/null +++ b/third_party/glog/BREW @@ -0,0 +1,6 @@ +cc_thirdparty_target( + name="glog", + srcs=["BREW"], + commands=[], + cc_obj_files=["-lglog"], +) diff --git a/third_party/leveldb/BREW b/third_party/leveldb/BREW new file mode 100644 index 00000000000..3f21da13ae0 --- /dev/null +++ b/third_party/leveldb/BREW @@ -0,0 +1,7 @@ +cc_thirdparty_target( + name = "leveldb", + srcs = ["BREW"], + commands=[], + deps = ["//third_party/snappy:snappy"], + cc_obj_files = [ "-lleveldb" ], +) diff --git a/third_party/liblmdb/BREW b/third_party/liblmdb/BREW new file mode 100644 index 00000000000..0f2b4d59274 --- /dev/null +++ b/third_party/liblmdb/BREW @@ -0,0 +1,8 @@ +cc_thirdparty_target( + name = "lmdb", + srcs = ["BREW"], + commands=[], + cc_obj_files = [ + "-llmdb" + ], +) diff --git a/third_party/libzmq/BREW b/third_party/libzmq/BREW new file mode 100644 index 00000000000..3e99bb80155 --- /dev/null +++ b/third_party/libzmq/BREW @@ -0,0 +1,6 @@ +cc_thirdparty_target( + name = "libzmq", + srcs = ["BREW"], + commands=[], + cc_obj_files = [ "-lzmq" ], +) \ No newline at end of file diff --git a/third_party/snappy/BREW b/third_party/snappy/BREW new file mode 100644 index 00000000000..4820567cdaf --- /dev/null +++ b/third_party/snappy/BREW @@ -0,0 +1,8 @@ +cc_thirdparty_target( + name = "snappy", + srcs = ["BREW"], + commands=[], + cc_obj_files = [ + "-lsnappy" + ], +)
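Taken together, workspace.py and workspace_test.py above show the intended Python round trip: serialize a net proto, run it, and exchange numpy arrays with the workspace. A minimal sketch mirroring workspace_test.py, assuming the libcaffe2_python extension has been built; the net and blob names are arbitrary.

# Hypothetical sketch: build a one-operator net, run it once, and read the
# result back as a numpy array.
import numpy as np
from pycaffe2 import core, workspace

net = core.Net("demo-net")
net.ConstantFill([], "ones", shape=[2, 3], value=1.0)  # fill blob "ones" with 1.0
workspace.ResetWorkspace()
assert workspace.RunNetOnce(net.Proto().SerializeToString())
arr = workspace.FetchBlob("ones")      # numpy array of shape (2, 3), all 1.0
workspace.FeedBlob("ones", arr * 2.0)  # push modified data back into the blob
print workspace.FetchBlob("ones")      # all 2.0 now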