Commit Graph

51 Commits

Author SHA1 Message Date
Joel Schlosser
a0309f89f4 Initial ModuleInfo implementation (#61935)
Summary:
This PR contains the initial version of `ModuleInfo` for use in testing modules. The design philosophy taken here is to start small and simple and build out / refactor as needed when more test coverage or `ModuleInfo` entries are added. As such, it's not intended for general usage yet. The PR contains the following:

* (new file) `torch/testing/_internal/common_modules.py`
  * `ModuleInfo` definition - metadata for each module to use in testing
  * `module_db` - the actual `ModuleInfo` database; currently contains entries for two modules
  * `ModuleInput` - analogous to `SampleInput` from OpInfo; contains `FunctionInput`s for both constructor and forward pass inputs
      * Constructor and forward pass inputs are tied together within a `ModuleInput` because they are likely correlated
  * `FunctionInput` - just contains args and kwargs to pass to a function (is there a nicer way to do this?)
  * `modules` decorator - analogous to `ops`; specifies a set of modules to run a test over
  * Some constants used to keep track of all modules under torch.nn:
      * `MODULE_NAMESPACES` - list of all namespaces containing modules
      * `MODULE_CLASSES` - list of all module class objects
      * `MODULE_CLASS_NAMES` - dict from module class object to nice name (e.g. torch.nn.Linear -> "nn.Linear")
* (new file) `test/test_modules.py`
    * Uses the above to define tests over modules
    * Currently, there is one test for demonstration, `test_forward`, which instantiates a module, runs its forward pass, and compares it to a reference, if one is defined

Pull Request resolved: https://github.com/pytorch/pytorch/pull/61935

Reviewed By: mruberry

Differential Revision: D29881832

Pulled By: jbschlosser

fbshipit-source-id: cc05c7d85f190a3aa42d55d4c8b01847d1efd57f
2021-07-27 07:42:07 -07:00