mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 00:21:07 +01:00
Summary: PR https://github.com/pytorch/pytorch/pull/30523 attempted to fix https://github.com/pytorch/pytorch/issues/30508 and https://github.com/pytorch/pytorch/issues/30462, but the fix wasn't complete. This PR makes the following improvements: 1. Fixes https://github.com/pytorch/pytorch/issues/30508 and https://github.com/pytorch/pytorch/issues/30462 properly by excluding undefined tensors in the result of `Module::parameters()` / `named_parameters()` / `buffers()` / `named_buffers()`, which mirrors the Python API behavior. 2. Audits all use sites of `Module::parameters_` / `buffers_` and change them to `Module::named_parameters(/*recurse=*/false)` / `named_buffers(/*recurse=*/false)` when appropriate, so that use sites of module parameters / buffers never need to worry about undefined tensors. Pull Request resolved: https://github.com/pytorch/pytorch/pull/30626 Differential Revision: D18777507 Pulled By: yf225 fbshipit-source-id: 55b64b69779e1186342efd3c44857f416334ed6b
85 lines
2.1 KiB
C++
85 lines
2.1 KiB
C++
#pragma once

#include <test/cpp/common/support.h>

#include <gtest/gtest.h>

#include <torch/nn/cloneable.h>
#include <torch/types.h>
#include <torch/utils.h>

#include <iostream>
#include <mutex>
#include <string>
#include <utility>
|
|
|
namespace torch {
|
|
namespace test {
|
|
|
|
// Lets you use a container without making a new class,
|
|
// for experimental implementations
|
|
class SimpleContainer : public nn::Cloneable<SimpleContainer> {
|
|
public:
|
|
void reset() override {}
|
|
|
|
template <typename ModuleHolder>
|
|
ModuleHolder add(
|
|
ModuleHolder module_holder,
|
|
std::string name = std::string()) {
|
|
return Module::register_module(std::move(name), module_holder);
|
|
}
|
|
};
|
|
|
|
// gtest fixture that reseeds torch's global RNG to a fixed value before each
// test, so tests that draw random numbers are deterministic run-to-run.
struct SeedingFixture : public ::testing::Test {
  SeedingFixture() {
    torch::manual_seed(/*seed=*/0);
  }
};
|
|
|
|
// RAII helper that swaps std::cerr's stream buffer for the one supplied at
// construction, and restores the original buffer on destruction. Handy in
// tests for capturing output written to stderr.
struct CerrRedirect {
  CerrRedirect(std::streambuf* new_buffer)
      : prev_buffer_(std::cerr.rdbuf(new_buffer)) {}

  ~CerrRedirect() {
    // Put std::cerr back the way we found it.
    std::cerr.rdbuf(prev_buffer_);
  }

 private:
  // The buffer std::cerr was using before the redirect.
  std::streambuf* prev_buffer_;
};
|
|
|
|
inline bool pointer_equal(at::Tensor first, at::Tensor second) {
|
|
return first.data_ptr() == second.data_ptr();
|
|
}
|
|
|
|
// Counts non-overlapping occurrences of `substr` in `str`.
// Returns 0 for an empty `substr`: without this guard the loop below never
// terminates, because find("") matches at every position and
// `pos + substr.size()` equals `pos`.
inline int count_substr_occurrences(const std::string& str, const std::string& substr) {
  if (substr.empty()) {
    return 0;
  }
  int count = 0;
  for (size_t pos = str.find(substr); pos != std::string::npos;
       pos = str.find(substr, pos + substr.size())) {
    count++;
  }
  return count;
}
|
|
|
|
// A RAII, thread local (!) guard that changes default dtype upon
|
|
// construction, and sets it back to the original dtype upon destruction.
|
|
//
|
|
// Usage of this guard is synchronized across threads, so that at any given time,
|
|
// only one guard can take effect.
|
|
struct AutoDefaultDtypeMode {
|
|
static std::mutex default_dtype_mutex;
|
|
|
|
AutoDefaultDtypeMode(c10::ScalarType default_dtype) : prev_default_dtype(torch::typeMetaToScalarType(torch::get_default_dtype())) {
|
|
default_dtype_mutex.lock();
|
|
torch::set_default_dtype(torch::scalarTypeToTypeMeta(default_dtype));
|
|
}
|
|
~AutoDefaultDtypeMode() {
|
|
default_dtype_mutex.unlock();
|
|
torch::set_default_dtype(torch::scalarTypeToTypeMeta(prev_default_dtype));
|
|
}
|
|
c10::ScalarType prev_default_dtype;
|
|
};
|
|
|
|
} // namespace test
|
|
} // namespace torch
|