mirror of
https://github.com/zebrajr/pytorch.git
synced 2025-12-07 12:21:27 +01:00
Summary: 1. Added CudaFusionGuard as the custom TypeCheck for nvfuser; enabled dynamic shape support with profiling executor; 2. dropped support for legacy fuser; 3. re-enabled nvfuser tests; 4. added registration for profiling record to allow profiling on user specified nodes. Pull Request resolved: https://github.com/pytorch/pytorch/pull/46452 Reviewed By: zou3519, anjali411 Differential Revision: D24364642 Pulled By: ngimel fbshipit-source-id: daf53a9a6b6636e1ede420a3a6d0397d4a8b450b
81 lines
1.7 KiB
C++
81 lines
1.7 KiB
C++
#pragma once
|
|
|
|
#include <c10/util/Exception.h>
|
|
|
|
namespace torch {
|
|
namespace jit {
|
|
namespace fuser {
|
|
namespace cuda {
|
|
|
|
// Common Functions
|
|
// Integer ceiling division: returns the smallest integer that is greater than
// or equal to the exact quotient a / b.
//
// The result is built from the truncated quotient and the remainder rather
// than from (a + b - 1) / b, so no intermediate sum can overflow int64_t when
// `a` is close to its maximum, and the rounding is also correct when `a`
// and/or `b` are negative.
//
// Precondition: b must be non-zero.
constexpr int64_t ceilDiv(int64_t a, int64_t b) {
  // Integer division truncates toward zero and a non-zero remainder carries
  // the sign of `a`. When that sign matches the sign of `b`, the exact
  // quotient is positive and truncation rounded it down, so step up by one.
  return a / b + (((a % b != 0) && ((a % b < 0) == (b < 0))) ? 1 : 0);
}
|
|
|
|
// Simple mixin for suppressing copy & move operations, e.g.:
//
//  class Foo : public NonCopyable {
//   ...
//  };
//
class NonCopyable {
 public:
  // Default construction stays available so that derived classes can be
  // created normally.
  NonCopyable() = default;

  // Both copy operations are deleted. Because a copy operation is
  // user-declared, no implicit move operations are generated either, so
  // objects of this type can be neither copied nor moved.
  NonCopyable& operator=(const NonCopyable&) = delete;
  NonCopyable(const NonCopyable&) = delete;
};
|
|
|
|
// A generic root for a hierarchy of polymorphic classes:
// - It ensures a virtual destructor
// - It provides the base->as<Derived>() and node->isA<T>() notation
class PolymorphicBase {
 public:
  // Virtual, so derived objects are destroyed correctly through a pointer to
  // this base class.
  virtual ~PolymorphicBase() = default;

  // Downcast helper: write ptr->as<T>() instead of static_cast<T*>(ptr).
  //
  // When NDEBUG is not defined the cast goes through dynamic_cast and the
  // result is asserted to be non-null; when NDEBUG is defined it is a plain,
  // unchecked static_cast.
  template <class T>
  T* as() {
#ifndef NDEBUG
    auto downcast_ptr = dynamic_cast<T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<T*>(this);
#endif
  }

  // Const overload of as<T>(): same checking rules, yields a pointer to
  // const T.
  template <class T>
  const T* as() const {
#ifndef NDEBUG
    auto downcast_ptr = dynamic_cast<const T*>(this);
    TORCH_INTERNAL_ASSERT(downcast_ptr != nullptr);
    return downcast_ptr;
#else
    return static_cast<const T*>(this);
#endif
  }

  // Returns true if the runtime type of this object is T (or derived from T).
  //
  // NOTE: Don't use this for conditional casts. Prefer a single dynamic_cast:
  //
  //  if (auto t = dynamic_cast<T*>(p)) { ... }
  //
  // over a type test followed by a separate cast:
  //
  //  if (p->isA<T>()) { auto t = p->as<T>(); ... }
  //
  template <class T>
  bool isA() const {
    return nullptr != dynamic_cast<const T*>(this);
  }
};
|
|
|
|
} // namespace cuda
|
|
} // namespace fuser
|
|
} // namespace jit
|
|
} // namespace torch
|