mirror of https://github.com/zebrajr/pytorch.git
synced 2025-12-06 12:20:52 +01:00
[BE][MPS] Apply clang-format to mps headers (#140906)
It was a mistake to omit them in the past. All changes in this PR, except the ones to .lintrunner.toml, were generated by running `lintrunner -a --take CLANGFORMAT --all-files`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140906
Approved by: https://github.com/Skylion007
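For anyone reproducing this formatting pass locally, a minimal sketch follows; the install/init steps are an assumption based on the usual lintrunner setup, and only the final command is the one quoted in the commit message:

    pip install lintrunner                          # assumption: installs the lintrunner CLI
    lintrunner init                                 # assumption: fetches the linter binaries pinned by the repo
    lintrunner -a --take CLANGFORMAT --all-files    # command from the commit message; -a applies the fixes in place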
This commit is contained in:
parent 5a7e147ef3
commit 99014a297c
@@ -56,10 +56,12 @@ code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/mps/**/*.h',
'aten/src/ATen/xpu/**/*.h',
'aten/src/ATen/xpu/**/*.cpp',
'aten/src/ATen/native/mps/**/*.metal',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.h',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'aten/src/ATen/native/cuda/MultiTensorApply.cuh',

@@ -12,8 +12,7 @@ C10_EXPORT TensorBase empty_mps(
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
C10_EXPORT TensorBase empty_mps(
IntArrayRef size, const TensorOptions &options);
C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options);

C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,

@@ -6,12 +6,12 @@
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>

#include <c10/util/flat_hash_map.h>
#include <mach/vm_page_size.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>

// this implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve the performance with buffer allocation.

@@ -24,8 +24,10 @@ static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 M
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kXLargeHeapD =
MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU =
MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation

// buffer pools could be customized with a combination of usage flags

@@ -67,10 +69,8 @@ struct BufferBlock {
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
MPSEventPtr event;

BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
HeapBlock* Heap = nullptr) :
buffer(Buffer), size(Size), requested_size(RequestedSize),
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr, HeapBlock* Heap = nullptr)
: buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}

static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;

@@ -79,15 +79,19 @@ struct BufferBlock {
assert(((Alignment - 1) & Alignment) == 0);
return ((Size + Alignment - 1) & ~(Alignment - 1));
}
uint32_t retainCount() const { return [buffer retainCount]; }
uint32_t retainCount() const {
return [buffer retainCount];
}
};
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);

struct BufferPool;
struct AllocParams {
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
size_t size() const { return search_key.size; }
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
: search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
size_t size() const {
return search_key.size;
}

BufferBlock search_key;
BufferPool* pool;

@@ -102,7 +106,9 @@ struct AllocParams {

struct HeapBlock {
id<MTLHeap> heap;
struct { size_t total, available; } size;
struct {
size_t total, available;
} size;
BufferPool* pool;
unsigned int n_buffers = 0;
id_t heap_id;

@@ -111,9 +117,12 @@ struct HeapBlock {
// counter to assign unique ids to heap blocks
static uint64_t heap_counter;

HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool* Pool = nullptr)
: heap(Heap),
size({.total = Size, .available = Size}),
pool(Pool),
heap_id(Heap ? ++heap_counter : 0),
is_split(true) {}

static MTLResourceOptions getOptions(uint32_t usage) {
// TODO: check the caching performance of write-combined mode

@@ -126,7 +135,8 @@ struct HeapBlock {
else
options |= MTLResourceStorageModePrivate;

options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
options |=
(usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;

return options;
}

@@ -152,7 +162,8 @@ struct HeapBlock {
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
// this automatically handles Metal buffer access synchronizations at the
// cost of slightly lower performance.
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.hazardTrackingMode =
(usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.resourceOptions = getOptions(usage);
d.type = MTLHeapTypeAutomatic;
id<MTLHeap> heap = [device newHeapWithDescriptor:d];

@@ -169,8 +180,8 @@ struct HeapBlock {
return heapBlock;
}
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
(uintptr_t)a->heap < (uintptr_t)b->heap;
return (a->size.available != b->size.available) ? a->size.available < b->size.available
: (uintptr_t)a->heap < (uintptr_t)b->heap;
}
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
return [heap maxAvailableSizeWithAlignment:Alignment];

@@ -205,8 +216,12 @@ struct HeapBlock {
size.available = 0;
return retainCount;
}
uint32_t retainCount() const { return [heap retainCount]; }
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
uint32_t retainCount() const {
return [heap retainCount];
}
void updateAvailableSize() {
size.available = heapAvailableSize(heap);
}
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);

@@ -219,9 +234,8 @@ struct BufferPool {
SCALAR,
};

BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
device(Device), usage(Usage),
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
BufferPool(const id<MTLDevice> Device, uint32_t Usage)
: device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}

const id<MTLDevice> device;
// usage flags to customize the pool for various purposes (see UsageFlags enum)

@ -249,8 +263,8 @@ struct BufferPool {
|
|||
|
||||
class MPSHeapAllocatorImpl {
|
||||
public:
|
||||
explicit MPSHeapAllocatorImpl() :
|
||||
m_device(at::mps::MPSDevice::getInstance()->device()),
|
||||
explicit MPSHeapAllocatorImpl()
|
||||
: m_device(at::mps::MPSDevice::getInstance()->device()),
|
||||
m_max_buffer_size([m_device maxBufferLength]),
|
||||
m_stream(getDefaultMPSStream()),
|
||||
m_event_pool(getMPSEventPool()) {
|
||||
|
|
@ -298,22 +312,38 @@ public:
|
|||
// (see m_high_watermark_ratio for description)
|
||||
void setHighWatermarkRatio(double ratio);
|
||||
// (see m_low_watermark_limit for description)
|
||||
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
|
||||
size_t getLowWatermarkLimit() const {
|
||||
return m_low_watermark_limit;
|
||||
}
|
||||
// (see m_max_total_allowed_size for description)
|
||||
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
|
||||
size_t getHighWatermarkLimit() const {
|
||||
return m_max_total_allowed_size;
|
||||
}
|
||||
// (see m_total_allocated_memory for description)
|
||||
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
|
||||
size_t getTotalAllocatedMemory() const {
|
||||
return m_total_allocated_memory;
|
||||
}
|
||||
// (see m_current_allocated_memory for description)
|
||||
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
|
||||
size_t getCurrentAllocatedMemory() const {
|
||||
return m_current_allocated_memory;
|
||||
}
|
||||
// total GPU memory allocated in the process by Metal driver; including
|
||||
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
|
||||
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
|
||||
size_t getDriverAllocatedMemory() const {
|
||||
return current_allocated_size();
|
||||
}
|
||||
// recommended Max memory for Metal
|
||||
size_t getRecommendedMaxMemory() const { return max_device_size(); }
|
||||
size_t getRecommendedMaxMemory() const {
|
||||
return max_device_size();
|
||||
}
|
||||
// (see enum DebugVerbosity for description)
|
||||
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
|
||||
uint32_t getDebugVerbosity() const {
|
||||
return m_debug_verbosity;
|
||||
}
|
||||
// returns the device that we allocate from
|
||||
inline id<MTLDevice> Device() const { return m_device; }
|
||||
inline id<MTLDevice> Device() const {
|
||||
return m_device;
|
||||
}
|
||||
|
||||
// TODO: make a common function to do size unit conversions in PyTorch.
|
||||
inline std::string format_size(uint64_t size) const;
|
||||
|
|
@ -387,14 +417,19 @@ private:
|
|||
size_t get_allocation_size(size_t size, uint32_t usage) const;
|
||||
// maximum size of device memory available for allocation in current process
|
||||
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
|
||||
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
|
||||
size_t max_device_size() const {
|
||||
return [m_device recommendedMaxWorkingSetSize];
|
||||
}
|
||||
// there are implicit allocations from MPS backend, so we need to query the 'device' for
|
||||
// total allocated size instead of manually tracking in MPSAllocator
|
||||
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
|
||||
size_t current_allocated_size() const {
|
||||
return [m_device currentAllocatedSize];
|
||||
}
|
||||
|
||||
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
|
||||
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
|
||||
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
|
||||
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
|
||||
buffer_block ? buffer_block->buffer : nullptr, event);
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
#include <c10/core/Allocator.h>
|
||||
#include <c10/util/Registry.h>
|
||||
#include <ATen/core/ATen_fwd.h>
|
||||
|
||||
#define MB(x) (x * 1048576UL)
|
||||
|
||||
|
|
@ -20,10 +20,12 @@ public:
|
|||
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
|
||||
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
|
||||
virtual id_t getBufferId(const void* ptr) const = 0;
|
||||
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
|
||||
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
|
||||
const = 0;
|
||||
virtual bool isSharedBuffer(const void* ptr) const = 0;
|
||||
virtual bool isSharedStorageSupported() const = 0;
|
||||
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
|
||||
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
|
||||
const = 0;
|
||||
virtual std::string formatSize(size_t size) const = 0;
|
||||
virtual void setLowWatermarkRatio(double ratio) const = 0;
|
||||
virtual void setHighWatermarkRatio(double ratio) const = 0;
|
||||
|
|
@ -34,7 +36,8 @@ public:
|
|||
virtual size_t getCurrentAllocatedMemory() const = 0;
|
||||
virtual size_t getDriverAllocatedMemory() const = 0;
|
||||
virtual size_t getRecommendedMaxMemory() const = 0;
|
||||
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
|
||||
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
|
||||
const void* ptr) const = 0;
|
||||
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
||||
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
|
||||
};
|
||||
|
|
@ -52,7 +55,8 @@ class IMpsAllocatorCallback {
|
|||
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
|
||||
};
|
||||
|
||||
// MPS allocator will execute every registered callback when a block of memory is freed.
|
||||
// MPS allocator will execute every registered callback when a block of memory
|
||||
// is freed.
|
||||
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
|
||||
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
|
||||
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)
|
||||
|
|
|
|||
|
|
@ -5,7 +5,6 @@
|
|||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
#include <Metal/Metal.h>
|
||||
|
|
|
|||
|
|
@ -26,12 +26,17 @@ public:
|
|||
// blocks the CPU thread until all the GPU work that were scheduled
|
||||
// prior to recording this event are completed.
|
||||
bool synchronize();
|
||||
// resets this event with new parameters in case it gets reused from the event pool
|
||||
// resets this event with new parameters in case it gets reused from the event
|
||||
// pool
|
||||
void reset(MPSStream* stream, bool enable_timing);
|
||||
// returns the unique ID of the event instance
|
||||
id_t getID() const { return m_id; }
|
||||
id_t getID() const {
|
||||
return m_id;
|
||||
}
|
||||
// returns the completion timestamp of the event
|
||||
uint64_t getCompletionTime() const { return m_completion_time; }
|
||||
uint64_t getCompletionTime() const {
|
||||
return m_completion_time;
|
||||
}
|
||||
// if already recorded, waits for cpu_sync_cv to be signaled
|
||||
void waitForCpuSync();
|
||||
|
||||
|
|
|
|||
|
|
@ -17,7 +17,8 @@ struct rng_data_pod {
|
|||
};
|
||||
|
||||
TORCH_API const Generator& getDefaultMPSGenerator();
|
||||
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
|
||||
TORCH_API Generator
|
||||
createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
|
||||
|
||||
} // namespace mps::detail
|
||||
|
||||
|
|
@ -37,10 +38,18 @@ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
|
|||
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
|
||||
void update_philox_counters();
|
||||
|
||||
void set_engine(at::Philox4_32 engine) { engine_ = engine; }
|
||||
at::Philox4_32 engine() { return engine_; }
|
||||
uint32_t* state_data() { return data_.state.data(); }
|
||||
static DeviceType device_type() { return DeviceType::MPS; }
|
||||
void set_engine(at::Philox4_32 engine) {
|
||||
engine_ = engine;
|
||||
}
|
||||
at::Philox4_32 engine() {
|
||||
return engine_;
|
||||
}
|
||||
uint32_t* state_data() {
|
||||
return data_.state.data();
|
||||
}
|
||||
static DeviceType device_type() {
|
||||
return DeviceType::MPS;
|
||||
}
|
||||
|
||||
private:
|
||||
mps::detail::rng_data_pod data_;
|
||||
|
|
|
|||
|
|
@ -1,12 +1,12 @@
|
|||
// Copyright © 2022 Apple Inc.
|
||||
|
||||
#pragma once
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/mps/MPSEvent.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <c10/core/impl/DeviceGuardImplInterface.h>
|
||||
#include <c10/macros/Macros.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <ATen/Context.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSEvent.h>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
|
|
@ -18,11 +18,10 @@
|
|||
#include <c10/core/MemoryFormat.h>
|
||||
#include <c10/core/Storage.h>
|
||||
#include <c10/core/TensorImpl.h>
|
||||
#include <sys/_types/_size_t.h>
|
||||
#include <memory>
|
||||
#include <c10/core/UndefinedTensorImpl.h>
|
||||
#include <c10/util/intrusive_ptr.h>
|
||||
|
||||
#include <sys/_types/_size_t.h>
|
||||
#include <memory>
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
|
|
@ -30,7 +29,8 @@ typedef MPSEvent* mpsEvent_t;
|
|||
|
||||
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
|
||||
// https://github.com/pytorch/pytorch/issues/77170
|
||||
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
|
||||
struct TORCH_API MPSGuardImpl final
|
||||
: public c10::impl::DeviceGuardImplInterface {
|
||||
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
|
||||
|
||||
// constructor
|
||||
|
|
@ -91,13 +91,10 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
|
|||
}
|
||||
|
||||
// Event-related functions
|
||||
void createEvent(
|
||||
mpsEvent_t* event,
|
||||
const EventFlag flag) const;
|
||||
void createEvent(mpsEvent_t* event, const EventFlag flag) const;
|
||||
|
||||
void destroyEvent(
|
||||
void* event,
|
||||
const DeviceIndex device_index) const noexcept override;
|
||||
void destroyEvent(void* event, const DeviceIndex device_index)
|
||||
const noexcept override;
|
||||
|
||||
void record(
|
||||
void** event,
|
||||
|
|
@ -105,14 +102,11 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
|
|||
const DeviceIndex device_index,
|
||||
const EventFlag flag) const override;
|
||||
|
||||
void block(
|
||||
void* event,
|
||||
const Stream& stream) const override;
|
||||
void block(void* event, const Stream& stream) const override;
|
||||
|
||||
bool queryEvent(void* event) const override;
|
||||
|
||||
void synchronizeDevice(const DeviceIndex device_index) const override;
|
||||
|
||||
};
|
||||
|
||||
/// A variant of OptionalDeviceGuard that is specialized for MPS.
|
||||
|
|
@ -175,7 +169,6 @@ struct OptionalMPSGuard {
|
|||
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
|
||||
};
|
||||
|
||||
|
||||
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
|
||||
|
||||
} // namespace at::mps
|
||||
|
|
|
|||
|
|
@ -2,8 +2,8 @@
|
|||
|
||||
#pragma once
|
||||
|
||||
#include <ATen/detail/MPSHooksInterface.h>
|
||||
#include <ATen/Generator.h>
|
||||
#include <ATen/detail/MPSHooksInterface.h>
|
||||
#include <ATen/mps/MPSEvent.h>
|
||||
#include <optional>
|
||||
|
||||
|
|
@ -38,7 +38,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
|||
Allocator* getPinnedMemoryAllocator() const override;
|
||||
|
||||
// MPSProfiler interface
|
||||
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
|
||||
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
|
||||
const override;
|
||||
void profilerStopTrace() const override;
|
||||
|
||||
// MPSEvent interface
|
||||
|
|
@ -48,7 +49,8 @@ struct MPSHooks : public at::MPSHooksInterface {
|
|||
void waitForEvent(uint32_t event_id) const override;
|
||||
void synchronizeEvent(uint32_t event_id) const override;
|
||||
bool queryEvent(uint32_t event_id) const override;
|
||||
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
|
||||
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
|
||||
const override;
|
||||
|
||||
// Compatibility with Accelerator API
|
||||
bool hasPrimaryContext(DeviceIndex device_index) const override {
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@
|
|||
#pragma once
|
||||
|
||||
#include <ATen/Tensor.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
#include <ATen/mps/MPSAllocatorInterface.h>
|
||||
#include <ATen/mps/MPSStream.h>
|
||||
|
||||
#include <os/signpost.h>
|
||||
#include <os/log.h>
|
||||
#include <os/signpost.h>
|
||||
|
||||
#include <atomic>
|
||||
#include <ctime>
|
||||
|
|
@ -29,8 +29,8 @@ struct BaseInfo {
|
|||
CPU_FALLBACK,
|
||||
};
|
||||
|
||||
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
|
||||
type(infoType), profileId(Id), handle(Handle) { }
|
||||
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
|
||||
: type(infoType), profileId(Id), handle(Handle) {}
|
||||
virtual ~BaseInfo() = default;
|
||||
|
||||
// type of profiling info
|
||||
|
|
@ -41,30 +41,36 @@ struct BaseInfo {
|
|||
// since it's possible to use event and interval-based signposts at the
|
||||
// same time, we need separate IDs for each.
|
||||
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
|
||||
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
|
||||
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
|
||||
// GPUStartTime")
|
||||
std::atomic<double> totalGpuTime{0.0};
|
||||
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
|
||||
// accumulated Scheduling time in ms (obtained from CompletionHandler's
|
||||
// "KernelEndTime - KernelStartTime")
|
||||
std::atomic<double> totalSchedulingTime{0.0};
|
||||
// indicates if the operation or copy execution has completed
|
||||
std::atomic_bool completed{false};
|
||||
// handle used to identify the profile info's instance (usually the pointer)
|
||||
const uintptr_t handle;
|
||||
|
||||
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
|
||||
virtual const std::string toString(
|
||||
double gpuTime = 0,
|
||||
double schedulingTime = 0) const;
|
||||
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
|
||||
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
|
||||
static std::string buildTensorString(
|
||||
const Tensor& tensor,
|
||||
bool includeBufferId = false) {
|
||||
if (tensor.defined()) {
|
||||
std::stringstream tensorStr;
|
||||
auto deviceType = tensor.device().type();
|
||||
tensorStr << c10::DeviceTypeName(deviceType);
|
||||
// see comments for INCLUDE_BUFFER_ID
|
||||
if (includeBufferId && deviceType == at::kMPS) {
|
||||
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
|
||||
<< ":" << buffer.retainCount << ")";
|
||||
id<MTLBuffer> buffer =
|
||||
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
|
||||
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
|
||||
<< buffer.retainCount << ")";
|
||||
}
|
||||
tensorStr << ":"
|
||||
<< tensor.scalar_type() << tensor.sizes();
|
||||
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
|
||||
return tensorStr.str();
|
||||
} else {
|
||||
return "undefined";
|
||||
|
|
@ -76,16 +82,23 @@ struct BaseInfo {
|
|||
};
|
||||
|
||||
struct OperationInfo : BaseInfo {
|
||||
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
|
||||
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
|
||||
OperationInfo(
|
||||
const void* Handle,
|
||||
bool IsGraph,
|
||||
uint64_t Id,
|
||||
const std::string& StrKey)
|
||||
: BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
|
||||
strKey(StrKey) {}
|
||||
|
||||
uint64_t runCount = 0;
|
||||
std::string strKey;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
// builds a string for a kernel
|
||||
static std::string buildKernelString(const std::string& kernelName,
|
||||
static std::string buildKernelString(
|
||||
const std::string& kernelName,
|
||||
const TensorList& tensors,
|
||||
bool includeBufferId = false) {
|
||||
std::stringstream kernelStr;
|
||||
|
|
@ -98,19 +111,20 @@ struct OperationInfo : BaseInfo {
|
|||
};
|
||||
|
||||
struct CpuFbInfo : BaseInfo {
|
||||
CpuFbInfo(uint64_t Id, const std::string& OpName) :
|
||||
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
|
||||
CpuFbInfo(uint64_t Id, const std::string& OpName)
|
||||
: BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
|
||||
|
||||
uint64_t runCount = 0;
|
||||
// the current and total overhead of copies in bytes required to convert the Op's
|
||||
// input tensors from MPS to CPU and then output from CPU back to MPS
|
||||
// the current and total overhead of copies in bytes required to convert the
|
||||
// Op's input tensors from MPS to CPU and then output from CPU back to MPS
|
||||
size_t currentCopyOverhead = 0;
|
||||
size_t totalCopyOverhead = 0;
|
||||
std::string opName;
|
||||
std::string strKey;
|
||||
uint64_t startTime = 0;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
void updateCopyOverhead(const TensorList& tensors) {
|
||||
currentCopyOverhead = 0;
|
||||
|
|
@ -130,9 +144,17 @@ struct CopyInfo : BaseInfo {
|
|||
CPU_TO_MPS,
|
||||
};
|
||||
|
||||
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
|
||||
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
|
||||
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
|
||||
CopyInfo(
|
||||
const void* Handle,
|
||||
size_t Length,
|
||||
uint64_t Id,
|
||||
bool IsNonBlocking,
|
||||
bool UsesBlitter)
|
||||
: BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
|
||||
kind(Kind::MPS_TO_MPS),
|
||||
length(Length),
|
||||
isNonBlocking(IsNonBlocking),
|
||||
usesBlitter(UsesBlitter) {}
|
||||
|
||||
Kind kind;
|
||||
size_t length;
|
||||
|
|
@ -143,11 +165,17 @@ struct CopyInfo : BaseInfo {
|
|||
// for copies that don't use blitters, we measure CPU time
|
||||
uint64_t startTime = 0;
|
||||
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
|
||||
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
|
||||
const override;
|
||||
|
||||
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
|
||||
static std::string buildTensorString(
|
||||
const void* buffer,
|
||||
const OptionalTensorRef tensor,
|
||||
bool includeBufferId = false);
|
||||
|
||||
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
|
||||
static bool isStorageOnMPS(
|
||||
const void* buffer,
|
||||
const OptionalTensorRef tensor) {
|
||||
if (tensor.has_value()) {
|
||||
return tensor->device().type() == at::kMPS;
|
||||
}
|
||||
|
|
@ -156,8 +184,11 @@ struct CopyInfo : BaseInfo {
|
|||
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
|
||||
}
|
||||
|
||||
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
|
||||
static Kind getCopyKind(
|
||||
const void* srcBuffer,
|
||||
const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor,
|
||||
const OptionalTensorRef dstTensor) {
|
||||
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
|
||||
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
|
||||
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
|
||||
|
|
@ -171,8 +202,9 @@ struct CopyInfo : BaseInfo {
|
|||
};
|
||||
|
||||
struct CopyStat : CopyInfo {
|
||||
explicit CopyStat(std::string CopyKindStr) :
|
||||
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
|
||||
explicit CopyStat(std::string CopyKindStr)
|
||||
: CopyInfo(nullptr, 0, 0, false, false),
|
||||
kindStr(std::move(CopyKindStr)) {}
|
||||
// total number of copies
|
||||
size_t totalCount = 0;
|
||||
// number of Scalar copies (i.e., less than sizeof(int64))
|
||||
|
|
@ -192,8 +224,8 @@ public:
|
|||
// lower 16 bits used for profiler options
|
||||
enum ProfileOptions : uint32_t {
|
||||
OPTIONS_NONE = 0,
|
||||
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
|
||||
// (used for convenience to not compute bit flags by OR-ing manually)
|
||||
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
|
||||
// etc.) (used for convenience to not compute bit flags by OR-ing manually)
|
||||
// trace all signpost types using events
|
||||
ALL_SIGNPOST_EVENTS = (1 << 0),
|
||||
// trace all signpost types using intervals
|
||||
|
|
@ -206,8 +238,8 @@ public:
|
|||
// and not schedule time.
|
||||
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
|
||||
|
||||
// use these if you need to trace signposts types individually (rarely required)
|
||||
// trace signpost using intervals
|
||||
// use these if you need to trace signposts types individually (rarely
|
||||
// required) trace signpost using intervals
|
||||
USE_INTERVALS = (1 << 4),
|
||||
// trace signpost by emitting events
|
||||
USE_EVENTS = (1 << 5),
|
||||
|
|
@ -238,17 +270,21 @@ public:
|
|||
OPERATION_INFO = (1 << 0),
|
||||
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
|
||||
COPY_INFO = (1 << 1),
|
||||
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
|
||||
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during
|
||||
// execution
|
||||
CPU_FALLBACK_INFO = (1 << 2),
|
||||
|
||||
// Profiling Statistics logging options when process terminates
|
||||
// ------------------------------------------------------------
|
||||
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
|
||||
// this is convenient to not combine following stats bit flags manually
|
||||
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
|
||||
// process terminates this is convenient to not combine following stats bit
|
||||
// flags manually
|
||||
ALL_STATS = (1 << 3),
|
||||
// prints operation stats (GPU times, run count, etc.) before process terminates
|
||||
// prints operation stats (GPU times, run count, etc.) before process
|
||||
// terminates
|
||||
OPERATION_STATS = (1 << 4),
|
||||
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
|
||||
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process
|
||||
// terminates
|
||||
COPY_STATS = (1 << 5),
|
||||
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
|
||||
// for tensors, etc.) before process terminates
|
||||
|
|
@ -256,17 +292,18 @@ public:
|
|||
|
||||
// Metadata format options when logging the info
|
||||
// ---------------------------------------------
|
||||
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
|
||||
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
|
||||
// if enabled, includes GPU run time in metadata (i.e.,
|
||||
// GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
|
||||
// ms])
|
||||
INCLUDE_GPU_TIME = (1 << 7),
|
||||
// if enabled, includes GPU scheduling time in metadata separately
|
||||
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
|
||||
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
|
||||
INCLUDE_KERNEL_TIME = (1 << 8),
|
||||
// if enabled, includes the unique buffer ID in metadata for the storage
|
||||
// of a tensor that was allocated on MPSAllocator. This is useful (along with
|
||||
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
|
||||
// with various operations.
|
||||
// of a tensor that was allocated on MPSAllocator. This is useful (along
|
||||
// with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
|
||||
// involved with various operations.
|
||||
INCLUDE_BUFFER_ID = (1 << 9),
|
||||
|
||||
// used for sanity check (Change this when new option added)
|
||||
|
|
@ -276,15 +313,28 @@ public:
|
|||
explicit MPSProfiler();
|
||||
~MPSProfiler();
|
||||
|
||||
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
||||
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
|
||||
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
|
||||
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
|
||||
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
|
||||
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
|
||||
// Kernels the beginProfile*() functions return a profileId which is unique
|
||||
// per graph/kernel/copy
|
||||
uint64_t beginProfileKernel(
|
||||
const void* handle,
|
||||
const std::string& strKey,
|
||||
bool isGraph);
|
||||
uint64_t beginProfileKernel(
|
||||
const void* handle,
|
||||
const std::string& kernelName,
|
||||
const TensorList& tensors);
|
||||
uint64_t beginProfileCopy(
|
||||
const void* srcBuffer,
|
||||
const void* dstBuffer,
|
||||
const OptionalTensorRef srcTensor,
|
||||
const OptionalTensorRef dstTensor,
|
||||
size_t length, bool isNonBlocking, bool usesBlitter = true);
|
||||
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
|
||||
size_t length,
|
||||
bool isNonBlocking,
|
||||
bool usesBlitter = true);
|
||||
uint64_t beginProfileCPUFallback(
|
||||
const std::string& opName,
|
||||
const TensorList& tensors);
|
||||
void beginProfileGPUInterval(const void* handle);
|
||||
|
||||
void endProfileCopy(uint64_t profileId, SyncType syncType);
|
||||
|
|
@ -309,7 +359,8 @@ public:
|
|||
// logging are enabled for the SignpostTypes
|
||||
bool isOperationProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
|
||||
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
||||
(m_log_options &
|
||||
(LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
|
||||
}
|
||||
bool isCopyProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
|
||||
|
|
@ -317,14 +368,16 @@ public:
|
|||
}
|
||||
bool isCPUFallbackProfilingEnabled() const {
|
||||
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
|
||||
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
||||
(m_log_options &
|
||||
(LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
|
||||
}
|
||||
bool isSignpostTracingEnabled() const {
|
||||
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
|
||||
}
|
||||
|
||||
private:
|
||||
// indicates what type of signpost types are enabled and traced by MPS profiler.
|
||||
// indicates what type of signpost types are enabled and traced by MPS
|
||||
// profiler.
|
||||
uint32_t m_signpost_types = 0;
|
||||
uint32_t m_profile_options = 0;
|
||||
uint32_t m_log_options = 0;
|
||||
|
|
@ -332,14 +385,15 @@ public:
|
|||
uint64_t m_graph_counter = 0;
|
||||
uint64_t m_cpu_fb_counter = 0;
|
||||
uint64_t m_copy_counter = 0;
|
||||
// technically, it's possible to trace both events and intervals at the same time
|
||||
// so we use separate os_log categories for them
|
||||
// technically, it's possible to trace both events and intervals at the same
|
||||
// time so we use separate os_log categories for them
|
||||
os_log_t m_os_log_events;
|
||||
os_log_t m_os_log_intervals;
|
||||
// stats logging could run either from destructor or signal handler
|
||||
// so this is used to check if logging has already started.
|
||||
std::atomic_bool hasLoggedStats{false};
|
||||
// indicates there are pending completionHandler callbacks that haven't been called yet.
|
||||
// indicates there are pending completionHandler callbacks that haven't been
|
||||
// called yet.
|
||||
std::atomic_bool hasPendingCompletionHandlers{false};
|
||||
// used to capture sigint signal to log profiling stats
|
||||
static struct sigaction currentSigint, previousSigint;
|
||||
|
|
@ -347,40 +401,62 @@ public:
|
|||
// We use the following lists for two reasons:
|
||||
// 1- for interval-based signposts the "begin" point won't be in same function
|
||||
// as the "end" point where we need to be able to retrieve signpost's info
|
||||
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
|
||||
// 2- if Operations info need to be logged when process ends using
|
||||
// LogOptions::OPERATION_INFO.
|
||||
|
||||
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
|
||||
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
||||
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
|
||||
// the string key for this map is the op name that we fall back to execute on CPU
|
||||
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
|
||||
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
|
||||
// the pointer key for this map is either "MPSGraph*" or
|
||||
// "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
|
||||
// could be logged along with aggregate profiling numbers when the process
|
||||
// ends.
|
||||
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
|
||||
m_op_info_list{};
|
||||
// the string key for this map is the op name that we fall back to execute on
|
||||
// CPU this list is retained and could be logged along with aggregate
|
||||
// profiling numbers when the process ends.
|
||||
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
|
||||
m_cpu_fb_info_list{};
|
||||
// this list contains the info for copies, and its key is the unique profileId
|
||||
// which is generated from m_copy_counter
|
||||
// The copyInfo list is not retained.
|
||||
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
|
||||
// a short list that contains copy stats
|
||||
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
|
||||
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
|
||||
m_copy_stat_list{};
|
||||
|
||||
mutable MTLCaptureManager* captureManager = nil;
|
||||
unsigned captureCount = 0;
|
||||
|
||||
void initialize();
|
||||
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
|
||||
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
|
||||
void endProfileExecution(
|
||||
BaseInfo& info,
|
||||
os_signpost_id_t event_signpost_id,
|
||||
os_signpost_id_t interval_signpost_id,
|
||||
double gpuTime, double schedulingTime);
|
||||
double gpuTime,
|
||||
double schedulingTime);
|
||||
void addProfilerScheduledHandler(BaseInfo& info);
|
||||
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
|
||||
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
||||
void emitSignpostEvent(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
|
||||
void beginSignpostInterval(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id,
|
||||
const std::string& msg) const;
|
||||
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
|
||||
void endSignpostInterval(
|
||||
SignpostTypes signpost_type,
|
||||
os_signpost_id_t signpost_id) const;
|
||||
|
||||
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
|
||||
// returns true if logging the profiling info "during the execution" is enabled
|
||||
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
|
||||
void updateCopyStats(
|
||||
const CopyInfo& copyInfo,
|
||||
double gpuTime,
|
||||
double schedulingTime);
|
||||
// returns true if logging the profiling info "during the execution" is
|
||||
// enabled
|
||||
bool isProfileInfoLoggingEnabled(
|
||||
BaseInfo::Type infoType,
|
||||
bool isExecutionEnded);
|
||||
// logs all the profiling stats that are enabled
|
||||
void logProfilingStats();
|
||||
// logs kernel profiling stats when the process ends.
|
||||
|
|
@ -390,7 +466,9 @@ public:
|
|||
// logs copy profiling stats when the process ends.
|
||||
void logCopyProfilingStats(std::FILE* f) const;
|
||||
|
||||
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
|
||||
os_signpost_id_t generateSignpostId(
|
||||
os_signpost_type_t signpostType,
|
||||
const void* ptr = nullptr);
|
||||
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
|
||||
static void handleIntSignal(int signal);
|
||||
};
|
||||
|
|
|
|||
|
|
@ -5,10 +5,10 @@
|
|||
#include <cstdint>
|
||||
#include <utility>
|
||||
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/util/Exception.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <ATen/mps/MPSDevice.h>
|
||||
#include <c10/core/DeviceGuard.h>
|
||||
#include <c10/core/Stream.h>
|
||||
#include <c10/util/Exception.h>
|
||||
|
||||
#ifdef __OBJC__
|
||||
#include <Foundation/Foundation.h>
|
||||
|
|
@ -32,7 +32,6 @@ typedef void* MTLDevice_t;
|
|||
#define nil NULL;
|
||||
#endif
|
||||
|
||||
|
||||
namespace at::mps {
|
||||
|
||||
//-----------------------------------------------------------------
|
||||
|
|
@ -47,8 +46,7 @@ enum class SyncType {
|
|||
COMMIT_ADAPTIVE, // commit adaptively based on available memory
|
||||
};
|
||||
|
||||
class TORCH_API MPSStream
|
||||
{
|
||||
class TORCH_API MPSStream {
|
||||
public:
|
||||
enum Unchecked { UNCHECKED };
|
||||
|
||||
|
|
@ -57,32 +55,55 @@ public:
|
|||
explicit MPSStream(Stream stream);
|
||||
|
||||
~MPSStream();
|
||||
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
|
||||
dispatch_queue_t queue() const { return _serialQueue; }
|
||||
MTLCommandQueue_t commandQueue() const {
|
||||
return _commandQueue;
|
||||
};
|
||||
dispatch_queue_t queue() const {
|
||||
return _serialQueue;
|
||||
}
|
||||
|
||||
MPSCommandBuffer* commandBuffer();
|
||||
MTLComputeCommandEncoder_t commandEncoder();
|
||||
void endKernelCoalescing();
|
||||
void synchronize(SyncType syncType);
|
||||
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
|
||||
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset,
|
||||
uint64_t profileId, SyncType syncType = SyncType::NONE);
|
||||
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
|
||||
size_t length, size_t srcOffset, size_t dstOffset,
|
||||
bool non_blocking, uint64_t profileId);
|
||||
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
|
||||
void copy(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
uint64_t profileId,
|
||||
SyncType syncType = SyncType::NONE);
|
||||
void copy_and_sync(id<MTLBuffer> srcBuffer,
|
||||
id<MTLBuffer> dstBuffer,
|
||||
size_t length,
|
||||
size_t srcOffset,
|
||||
size_t dstOffset,
|
||||
bool non_blocking,
|
||||
uint64_t profileId);
|
||||
void executeMPSGraph(MPSGraph* mpsGraph,
|
||||
NSDictionary* feeds,
|
||||
NSDictionary* results,
|
||||
SyncType syncType = SyncType::NONE);
|
||||
void addCompletedHandler(MTLCommandBufferHandler block);
|
||||
|
||||
/// Get the MPS device index that this stream is associated with.
|
||||
c10::DeviceIndex device_index() const { return _stream.device_index(); }
|
||||
c10::DeviceIndex device_index() const {
|
||||
return _stream.device_index();
|
||||
}
|
||||
|
||||
MTLCommandQueue_t stream() const { return _commandQueue; };
|
||||
MTLCommandQueue_t stream() const {
|
||||
return _commandQueue;
|
||||
};
|
||||
|
||||
MTLDevice_t device() const { return [_commandQueue device];}
|
||||
MTLDevice_t device() const {
|
||||
return [_commandQueue device];
|
||||
}
|
||||
|
||||
/// Explicit conversion to Stream.
|
||||
Stream unwrap() const { return _stream; }
|
||||
Stream unwrap() const {
|
||||
return _stream;
|
||||
}
|
||||
|
||||
private:
|
||||
Stream _stream;
|
||||
|
|
@ -117,8 +138,7 @@ TORCH_API MPSStream* getDefaultMPSStream();
|
|||
// MPSStreamImpl
|
||||
//-----------------------------------------------------------------
|
||||
|
||||
class TORCH_API MPSStreamImpl
|
||||
{
|
||||
class TORCH_API MPSStreamImpl {
|
||||
public:
|
||||
/**
|
||||
* Gets single instance of the MPSStream.
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
|
||||
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
||||
|
||||
#if !defined(__MAC_15_0) && \
|
||||
(!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
|
||||
#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
|
||||
|
||||
@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
|
||||
- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id<MTLCommandBuffer>)cmdBuf
|
||||
|
|
@ -20,19 +19,16 @@
|
|||
- (nonnull instancetype)initWithBuffer:(id<MTLBuffer> _Nonnull)buffer
|
||||
offset:(NSUInteger)offset
|
||||
descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor;
|
||||
-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape
|
||||
strides:(MPSShape * _Nonnull) strides;
|
||||
- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides;
|
||||
@end
|
||||
|
||||
typedef NS_ENUM(NSInteger, MTLMathMode)
|
||||
{
|
||||
typedef NS_ENUM(NSInteger, MTLMathMode) {
|
||||
MTLMathModeSafe = 0,
|
||||
MTLMathModeRelaxed = 1,
|
||||
MTLMathModeFast = 2,
|
||||
};
|
||||
|
||||
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions)
|
||||
{
|
||||
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) {
|
||||
MTLMathFloatingPointFunctionsFast = 0,
|
||||
MTLMathFloatingPointFunctionsPrecise = 1,
|
||||
};
|
||||
|
|
|
|||
|
|
@ -2,11 +2,9 @@
|
|||
|
||||
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
||||
|
||||
#if !defined(__MAC_14_0) && \
|
||||
(!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
|
||||
#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
|
||||
|
||||
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
|
||||
{
|
||||
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) {
|
||||
MPSGraphFFTScalingModeNone = 0L,
|
||||
MPSGraphFFTScalingModeSize = 1L,
|
||||
MPSGraphFFTScalingModeUnitary = 2L,
|
||||
|
|
@ -22,12 +20,9 @@ typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
|
|||
@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;
|
||||
|
||||
@interface MPSGraph (SonomaOps)
|
||||
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
|
||||
name:(NSString * _Nullable) name;
|
||||
|
||||
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
|
||||
name:(NSString * _Nullable) name;
|
||||
- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor
|
||||
axes:(NSArray<NSNumber*>* _Nonnull)axes
|
||||
|
|
|
|||
|
|
@ -2,8 +2,7 @@
|
|||
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
|
||||
|
||||
// TODO: Remove me when moved to MacOS 13
|
||||
#if !defined(__MAC_13_2) && \
|
||||
(!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
|
||||
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
|
||||
|
||||
@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>
|
||||
|
||||
|
|
@ -35,11 +34,9 @@
|
|||
|
||||
@interface MPSGraph (VenturaOps)
|
||||
|
||||
#if !defined(__MAC_13_0) && \
|
||||
(!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
|
||||
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
|
||||
|
||||
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
|
||||
{
|
||||
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
|
||||
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
|
||||
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
|
||||
MPSGraphResizeNearestRoundingModeCeil = 2L,
|
||||
|
|
@ -59,16 +56,20 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
|
|||
descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor
|
||||
name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
|
||||
- (MPSGraphTensor* _Nonnull)
|
||||
convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
|
||||
weightsTensor:(MPSGraphTensor* _Nonnull)weights
|
||||
outputShape:(MPSShape* _Nonnull)outputShape
|
||||
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
|
||||
forwardConvolutionDescriptor:
|
||||
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
|
||||
name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
|
||||
- (MPSGraphTensor* _Nonnull)
|
||||
convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
|
||||
sourceTensor:(MPSGraphTensor* _Nonnull)source
|
||||
outputShape:(MPSShape* _Nonnull)outputShape
|
||||
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
|
||||
forwardConvolutionDescriptor:
|
||||
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
|
||||
name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor
|
||||
|
|
@ -111,8 +112,7 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
|
|||
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
|
||||
name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
|
||||
name:(NSString * _Nullable)name;
|
||||
- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name;
|
||||
|
||||
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
|
||||
sizeTensor:(MPSGraphTensor* _Nonnull)size
|
||||
|
|
@ -191,7 +191,6 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
|
|||
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
|
||||
constantValue:(double)constantValue
|
||||
name:(NSString* _Nullable)name;
|
||||
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
|
||||
name:(NSString * _Nullable) name;
|
||||
- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
|
||||
|
||||
@end
|
||||

@@ -35,7 +35,9 @@ namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());

struct MPSScalar {
  id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
  id<MTLBuffer> getMTLBuffer() const {
    return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
  }

  size_t size = 0;
  ScalarType type = ScalarType::Undefined;

@@ -51,10 +53,7 @@ struct MPSScalar {
  } value{};
};

void runMPSGraph(MPSStream* mpsStream,
                 MPSGraph* mpsGraph,
                 NSDictionary* feeds,
                 NSDictionary* results);
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);

MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {

@@ -82,9 +81,17 @@ std::string getArrayRefString(const IntArrayRef s);
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src,
                                                 MPSShape* mpsShape,
                                                 const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
                               MPSGraphTensor* inputTensor,
                               const TensorBase& input,
                               bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
                                 MPSGraphTensor* inputTensor,
                                 const TensorBase& input,
                                 bool includesInt64 = false);

MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);

@@ -102,8 +109,12 @@ class Placeholder {
  Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
  Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
  Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
              bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
  Placeholder(MPSGraphTensor* mpsGraphTensor,
              const Tensor& self,
              MPSShape* mpsShape = nullptr,
              bool gatherTensorData = true,
              MPSDataType dataType = MPSDataTypeInvalid,
              bool useMPSStridedAPI = true);
  MPSGraphTensor* getMPSGraphTensor() {
    return _placeholder;
  }

@@ -145,8 +156,7 @@ using MPSCacheKey = uint64_t;

// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph
{
struct MPSCachedGraph {
  MPSCachedGraph(NSObject* object) : _object([object retain]) {}
  virtual ~MPSCachedGraph() {
    [_object release];

@@ -158,21 +168,24 @@ struct MPSCachedGraph
    return static_cast<T*>(this);
  }

  MPSGraph *graph() const { return (MPSGraph *)_object; }
  NSObject *object() const { return _object; }
  MPSGraph* graph() const {
    return (MPSGraph*)_object;
  }
  NSObject* object() const {
    return _object;
  }

 private:
  NSObject* _object = nullptr;
};

struct MPSUnaryCachedGraph : public MPSCachedGraph
{
struct MPSUnaryCachedGraph : public MPSCachedGraph {
  MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
  MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;

@@ -180,16 +193,14 @@ struct MPSUnaryGradCachedGraph : public MPSCachedGraph
  MPSGraphTensor* gradInputTensor_ = nil;
};

struct MPSBinaryCachedGraph : public MPSCachedGraph
{
struct MPSBinaryCachedGraph : public MPSCachedGraph {
  MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* inputTensor_ = nil;
  MPSGraphTensor* otherTensor_ = nil;
  MPSGraphTensor* outputTensor_ = nil;
};

struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
  MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
  MPSGraphTensor* gradOutputTensor_ = nil;
  MPSGraphTensor* inputTensor_ = nil;

@@ -200,8 +211,7 @@ struct MPSBinaryGradCachedGraph : public MPSCachedGraph
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
struct MPSGraphCache {
  typedef MPSCachedGraph* (^CreateCachedGraphBlock)();

  struct CacheEntry {

@@ -211,7 +221,6 @@ struct MPSGraphCache
  };

 public:

  static MPSGraphCache* getInstance() {
    if (_instance_cache == nullptr) {
      _instance_cache = new MPSGraphCache();

@@ -232,7 +241,6 @@ struct MPSGraphCache
  void operator=(const MPSGraphCache&) = delete;

  MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {

    __block MPSCachedGraph* cachedGraph = nil;

    MPSCacheKey hash = std::hash<std::string>{}(key);

@@ -259,13 +267,11 @@ struct MPSGraphCache
  }

  MPSCachedGraph* LookUp(const std::string& key) const {

    __block MPSCachedGraph* cachedGraph = nullptr;

    MPSCacheKey hash = std::hash<std::string>{}(key);

    dispatch_sync(serialQueue_, ^() {

      if (cache_.count(hash) != 0) {
        auto& entry = cache_.at(hash);
        TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");

@@ -292,7 +298,6 @@ struct MPSGraphCache
  static MPSGraphCache* _instance_cache;
  std::unordered_map<MPSCacheKey, CacheEntry> cache_;
  dispatch_queue_t serialQueue_ = nullptr;

};

// Common template for creating graph with a specified cache if missing
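As a hedged usage sketch of the cache types above (not part of this commit): an operation derives one of the cached-graph structs, keys it by a string, and goes through LookUp/CreateCachedGraph. The helper name, the key, and the graph-building body below are illustrative assumptions.

// Hypothetical helper: fetch a cached unary graph, or build and cache it.
static MPSUnaryCachedGraph* getOrCreateUnaryGraph(const std::string& key) {
  auto* cache = MPSGraphCache::getInstance();
  if (auto* hit = static_cast<MPSUnaryCachedGraph*>(cache->LookUp(key))) {
    return hit;
  }
  return static_cast<MPSUnaryCachedGraph*>(cache->CreateCachedGraph(key, ^MPSCachedGraph*() {
    MPSGraph* graph = [[MPSGraph new] autorelease];  // retained by MPSCachedGraph's constructor
    auto* newGraph = new MPSUnaryCachedGraph(graph);
    // ... build the graph tensors here and assign newGraph->inputTensor_ / outputTensor_ ...
    return newGraph;
  }));
}

At a call site, the cached graph's tensors are then typically wrapped in Placeholder objects and executed via runMPSGraph with feeds/results dictionaries, such as those produced by dictionaryFromPlaceholders further down in this header.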
@@ -319,7 +324,9 @@ MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);

#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
  if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
    TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
    TORCH_WARN_ONCE( \
        "MPS: no support for int64 for ", \
        op_name, \
        ", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
  }

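A hypothetical call site for the macro above; the tensor variable, the operator name, and the capability flag are placeholders, and the real flag would come from a runtime macOS-version check that is elided here.

const bool mac_os_13_3_plus = false;  // placeholder for the real capability check
MPS_CHECK_INT64_OP_SUPPORTED(input, mac_os_13_3_plus, "example_op");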
@@ -335,17 +342,19 @@ inline bool is_dense_in_storage(const TensorBase& t) {
  return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}

class MetalShaderLibrary {
 public:
  MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
  MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_)
      : shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
  MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
      : shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
  MetalShaderLibrary(const MetalShaderLibrary&) = delete;
  inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
    return getLibraryPipelineState(getLibrary(), fname).first;
  }
  id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
  id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
                                                      const std::initializer_list<std::string>& params) {
    return getLibraryPipelineState(getLibrary(params), fname).first;
  }
  inline id<MTLFunction> getMTLFunction(const std::string& fname) {

@@ -355,12 +364,15 @@ public:
    return getLibraryPipelineState(getLibrary(params), fname).second;
  }
  static MetalShaderLibrary& getBundledLibrary();

 protected:
  virtual id<MTLLibrary> getLibrary();
  virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
  id<MTLLibrary> library = nil;

 private:
  std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
  std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib,
                                                                                  const std::string& fname);

  id<MTLLibrary> compileLibrary(const std::string& src);
  std::string shaderSource;

@@ -371,21 +383,18 @@ private:
};

template <typename encoder_t,
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
          typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
                                      std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
  [encoder setBuffer:getMTLBufferStorage(t)
              offset:t.storage_offset() * t.element_size()
             atIndex:idx];
  [encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}

template<typename T,
         typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
template <typename T, typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
  [encoder setBytes:&val length:sizeof(T) atIndex:idx];
}

template<typename Container,
         typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
template <typename Container, typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
  [encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}

@@ -400,7 +409,9 @@ static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
  [encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}

id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
                                        const TensorIteratorBase& iter,
                                        bool use_64bit_index = false);

inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
  return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};
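To show how the pieces of this header compose, here is a hedged sketch (not from this commit) that compiles an embedded kernel with MetalShaderLibrary and binds arguments with the mtl_* helpers. The kernel source, function name, buffer indices, and the trailing mtl_dispatch1DJob arguments are assumptions; that signature is truncated in the hunk above, and the compute encoder is assumed to come from the active MPSStream.

static void fill_tensor_example(id<MTLComputeCommandEncoder> encoder, const at::TensorBase& out, float value) {
  // Library is compiled lazily from an embedded Metal source string.
  static at::native::mps::MetalShaderLibrary lib(R"MTL(
      kernel void fill_float(device float* out [[buffer(0)]],
                             constant float& value [[buffer(1)]],
                             uint idx [[thread_position_in_grid]]) {
        out[idx] = value;
      })MTL");
  id<MTLComputePipelineState> pso = lib.getPipelineStateForFunc("fill_float");
  [encoder setComputePipelineState:pso];
  at::native::mps::mtl_setBuffer(encoder, out, 0);   // tensor storage at buffer(0)
  at::native::mps::mtl_setBytes(encoder, value, 1);  // scalar at buffer(1)
  at::native::mps::mtl_dispatch1DJob(encoder, pso, out.numel());  // assumed (encoder, pso, length) order
}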

@@ -2,10 +2,11 @@

#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
  AT_DISPATCH_SWITCH( \
      TYPE, NAME, \
      AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
      AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
      AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
      TYPE, \
      NAME, \
      AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \
          at::ScalarType::Half, \
          __VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
      AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
      AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
      AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
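This macro is used like the other AT_DISPATCH_* helpers in ATen; a minimal sketch of a call site follows, with a placeholder op name and lambda body (self is assumed to be a tensor in scope).

AT_DISPATCH_MPS_TYPES(self.scalar_type(), "example_op_mps", [&] {
  // scalar_t is provided per dispatched case by AT_DISPATCH_CASE.
  auto* data = self.data_ptr<scalar_t>();
  (void)data;
});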

@@ -1,5 +1,8 @@
#pragma once

namespace at::native::mps {
void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
void complex_mul_out(
    const Tensor& input,
    const Tensor& other,
    const Tensor& output);
}

@@ -17,8 +17,7 @@ void _fused_adam_amsgrad_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);

void _fused_adam_amsgrad_mps_impl_(
    TensorList params,

@@ -34,7 +33,6 @@ void _fused_adam_amsgrad_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<at::Tensor>& grad_scale,
    const std::optional<at::Tensor>& found_inf
);
    const std::optional<at::Tensor>& found_inf);

} // namespace at::native::mps

@@ -16,8 +16,7 @@ void _fused_adam_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);

void _fused_adam_mps_impl_(
    TensorList params,

@@ -32,6 +31,5 @@ void _fused_adam_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

@@ -17,8 +17,7 @@ void _fused_adamw_amsgrad_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);

void _fused_adamw_amsgrad_mps_impl_(
    TensorList params,

@@ -34,6 +33,5 @@ void _fused_adamw_amsgrad_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

@@ -16,8 +16,7 @@ void _fused_adamw_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);

void _fused_adamw_mps_impl_(
    TensorList params,

@@ -32,7 +31,6 @@ void _fused_adamw_mps_impl_(
    const double eps,
    const bool maximize,
    const std::optional<Tensor>& grad_scale,
    const std::optional<Tensor>& found_inf
);
    const std::optional<Tensor>& found_inf);

} // namespace at::native::mps

@@ -436,9 +436,11 @@ REGISTER_FUSED_SGD_MOMENTUM_OP(half);

)METAL";

static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(
    const std::string& fname) {
  static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
  return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
  return std::make_pair(
      lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}

} // namespace mps

@@ -16,8 +16,7 @@ struct MetadataArguments { // the size of this struct must be less than 4 bytes
};

struct FusedAdamEncodingFunctor {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double lr,

@@ -25,9 +24,7 @@ struct FusedAdamEncodingFunctor {
                  const double beta2,
                  const double weight_decay,
                  const double eps,
                  const bool maximize
  ) const {

                  const bool maximize) const {
    float lr_lv = lr;
    float beta1_lv = beta1;
    float beta2_lv = beta2;

@@ -35,12 +32,8 @@ struct FusedAdamEncodingFunctor {
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, lr_lv, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);

@@ -49,8 +42,7 @@ struct FusedAdamEncodingFunctor {
    mtl_setBytes(computeEncoder, maximize_lv, 7);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const at::Tensor& lr,

@@ -58,20 +50,15 @@ struct FusedAdamEncodingFunctor {
                  const double beta2,
                  const double weight_decay,
                  const double eps,
                  const bool maximize
  ) const {
                  const bool maximize) const {
    float beta1_lv = beta1;
    float beta2_lv = beta2;
    float weight_decay_lv = weight_decay;
    float eps_lv = eps;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBuffer(computeEncoder, lr, 2);
    mtl_setBytes(computeEncoder, beta1_lv, 3);
    mtl_setBytes(computeEncoder, beta2_lv, 4);

@@ -86,8 +73,7 @@ struct FusedSgdEncodingFunctor {};

template <>
struct FusedSgdEncodingFunctor<true> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,

@@ -96,8 +82,7 @@ struct FusedSgdEncodingFunctor<true> {
                  const double dampening,
                  const bool nesterov,
                  const bool maximize,
                  const bool is_first_step
  ) const {
                  const bool is_first_step) const {
    float weight_decay_lv = weight_decay;
    float momentum_lv = momentum;
    float lr_lv = lr;

@@ -106,12 +91,8 @@ struct FusedSgdEncodingFunctor<true> {
    uint8_t maximize_lv = maximize;
    uint8_t is_first_step_lv = is_first_step;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, momentum_lv, 3);
    mtl_setBytes(computeEncoder, lr_lv, 4);

@@ -121,8 +102,7 @@ struct FusedSgdEncodingFunctor<true> {
    mtl_setBytes(computeEncoder, is_first_step_lv, 8);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,

@@ -131,8 +111,7 @@ struct FusedSgdEncodingFunctor<true> {
                  const double dampening,
                  const bool nesterov,
                  const bool maximize,
                  const bool is_first_step
  ) const {
                  const bool is_first_step) const {
    float weight_decay_lv = weight_decay;
    float momentum_lv = momentum;
    float dampening_lv = dampening;

@@ -140,12 +119,8 @@ struct FusedSgdEncodingFunctor<true> {
    uint8_t maximize_lv = maximize;
    uint8_t is_first_step_lv = is_first_step;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, momentum_lv, 3);
    mtl_setBuffer(computeEncoder, lr, 4);

@@ -158,46 +133,34 @@ struct FusedSgdEncodingFunctor<true> {

template <>
struct FusedSgdEncodingFunctor<false> {
  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,
                  const double lr,
                  const bool maximize
  ) const {
                  const bool maximize) const {
    float weight_decay_lv = weight_decay;
    float lr_lv = lr;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBytes(computeEncoder, lr_lv, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);
  }

  void operator()(
      id<MTLComputeCommandEncoder>& computeEncoder,
  void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
                  id<MTLBuffer>& tensorArgumentBuffer,
                  const MetadataArguments& metadata_arguments,
                  const double weight_decay,
                  const at::Tensor& lr,
                  const bool maximize
  ) const {
                  const bool maximize) const {
    float weight_decay_lv = weight_decay;
    uint8_t maximize_lv = maximize;

    [computeEncoder setBuffer:tensorArgumentBuffer
                       offset:0
                      atIndex:0];
    [computeEncoder setBytes:&metadata_arguments
                      length:sizeof(MetadataArguments)
                     atIndex:1];
    [computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
    [computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
    mtl_setBytes(computeEncoder, weight_decay_lv, 2);
    mtl_setBuffer(computeEncoder, lr, 3);
    mtl_setBytes(computeEncoder, maximize_lv, 4);

@@ -205,25 +168,22 @@ struct FusedSgdEncodingFunctor<false> {
};

template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
    const std::string& kernel_name,
static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    encoder_func_t encode,
    ArgTypes... args
) {
    ArgTypes... args) {
  const auto num_tensors = tensor_lists[0].size();

  if (num_tensors == 0) {
    return;
  }

  TORCH_CHECK(
      tensor_lists.size() == depth,
      "Number of tensor lists has to match the depth");
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
  for (const auto& d : c10::irange(depth)) {
    TORCH_CHECK(
        tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
    TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
                    tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
                "Only float and half are supported");
  }

  id<MTLDevice> device = MPSDevice::getInstance()->device();
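A hedged sketch of a call into this dispatcher follows; the kernel name, depth, thread-group size, and argument values are placeholders rather than the real fused-optimizer call sites, and params, grads, weight_decay, lr, and maximize are assumed to be in scope.

// depth = number of tensor lists (here params and grads); 64 is an illustrative
// thread-group size, not necessarily the value used by the real kernels.
std::vector<std::vector<at::Tensor>> tensor_lists = {params.vec(), grads.vec()};
multi_tensor_apply_for_fused_optimizer<2, 64>("fused_sgd_example",
                                              tensor_lists,
                                              /*state_steps=*/{},
                                              FusedSgdEncodingFunctor<false>(),
                                              weight_decay,
                                              lr,
                                              maximize);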
@@ -251,7 +211,8 @@ static void multi_tensor_apply_for_fused_optimizer(

  // BufferIndex is the index in the kernel function
  auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
  id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
  id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                            options:0] autorelease];
  [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

  int64_t tensor_loc = 0;

@@ -266,7 +227,8 @@ static void multi_tensor_apply_for_fused_optimizer(

  for (const auto& d : c10::irange(depth)) {
    mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
    [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
    [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                          usage:MTLResourceUsageRead | MTLResourceUsageWrite];
  }
  if (state_steps.size() > 0) {
    mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);

@@ -302,18 +264,21 @@ static void multi_tensor_apply_for_fused_optimizer(
  if (chunk == chunks - 1) {
    // last chunk
    tensor_loc = 0;
    tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
    tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                options:0] autorelease];
    [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
  } else {
    // reuse the current tensor since the current one isn't done.
    metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];

    tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
    tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
                                                options:0] autorelease];
    [tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];

    for (const auto& d : c10::irange(depth)) {
      mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
      [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
      [computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
                            usage:MTLResourceUsageWrite | MTLResourceUsageRead];
    }
    if (state_steps.size() > 0) {
      mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);

@@ -334,7 +299,6 @@ static void multi_tensor_apply_for_fused_optimizer(
      }

      getMPSProfiler().endProfileKernel(fusedOptimizerPSO);

    }
  });
}