pytorch/torch/csrc/jit/codegen/cuda/utils.h
jiej ac146c4820 [nvFuser] Switching to CudaFusionGuard from BailOut for nvfuser - update 2 (#46452)
Summary:
1. Added CudaFusionGuard as the custom TypeCheck for nvfuser; enabled dynamic shape support with profiling executor;
2. dropped support for legacy fuser;
3. re-enabled nvfuser tests;
4. added registration for profiling record to allow profiling on user specified nodes.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/46452

Reviewed By: zou3519, anjali411

Differential Revision: D24364642

Pulled By: ngimel

fbshipit-source-id: daf53a9a6b6636e1ede420a3a6d0397d4a8b450b
2020-10-19 15:44:31 -07:00

81 lines
1.7 KiB
C++

#pragma once
#include <c10/util/Exception.h>
namespace torch {
namespace jit {
namespace fuser {
namespace cuda {
// Common Functions
// Integer ceiling division: returns the smallest integer that is greater than
// or equal to the exact quotient a / b.
//
// The result is computed from the truncated quotient and the remainder rather
// than as (a + b - 1) / b, because that form:
//  - overflows (undefined behavior) when a + b - 1 exceeds the int64_t range;
//  - rounds the wrong way for some negative operands, e.g. (-4, 3) or (0, -3).
//
// For every non-negative `a` and positive `b` where the old formula was well
// defined, the result is unchanged.
//
// Preconditions: b != 0, and (a, b) != (INT64_MIN, -1); both cases are
// undefined behavior for the built-in / and % operators.
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  // a / b truncates toward zero. That is already the ceiling when the exact
  // quotient is negative or integral; when there is a remainder and the
  // operands share a sign (exact quotient is positive), round up by one.
  return a / b + ((a % b != 0 && ((a < 0) == (b < 0))) ? 1 : 0);
}
// Mixin base class that disables copy and move operations for any class
// deriving from it. Usage:
//
//   class Foo : public NonCopyable {
//     ...
//   };
//
class NonCopyable {
 public:
  // Copy construction and copy assignment are disabled.
  NonCopyable(const NonCopyable&) = delete;
  NonCopyable& operator=(const NonCopyable&) = delete;

  // Move construction and move assignment are disabled too. They were already
  // unavailable because of the deleted copy operations; they are spelled out
  // here only to make the intent explicit.
  NonCopyable(NonCopyable&&) = delete;
  NonCopyable& operator=(NonCopyable&&) = delete;

  // Declaring the constructors above suppresses the implicit default
  // constructor, so it is requested explicitly.
  NonCopyable() = default;
};
// Common root for hierarchies of polymorphic classes:
//  - declares a virtual destructor;
//  - provides the base->as<Derived>() and node->isA<T>() helpers.
class PolymorphicBase {
 public:
  virtual ~PolymorphicBase() = default;

  // Downcast helper, used as ptr->as<T>() in place of static_cast<T*>(ptr).
  //
  // When NDEBUG is not defined the cast goes through dynamic_cast and the
  // result is checked with TORCH_INTERNAL_ASSERT; when NDEBUG is defined it
  // is an unchecked static_cast.
  template <class T>
  T* as() {
#ifndef NDEBUG
    T* const downcast_ptr = dynamic_cast<T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<T*>(this);
#endif
  }

  // Const overload of as<T>(): same checking policy, yields a pointer to
  // const T.
  template <class T>
  const T* as() const {
#ifndef NDEBUG
    const T* const downcast_ptr = dynamic_cast<const T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<const T*>(this);
#endif
  }

  // Returns true when the runtime type of this object is T (or a class
  // derived from T).
  //
  // NOTE: Don't use this for conditional casts. Prefer:
  //
  //   if (auto t = dynamic_cast<T*>(p)) { ... }
  //
  // over:
  //
  //   if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //
  template <class T>
  bool isA() const {
    const T* const probe = dynamic_cast<const T*>(this);
    return probe != nullptr;
  }
};
} // namespace cuda
} // namespace fuser
} // namespace jit
} // namespace torch