[BE][MPS] Apply clang-format to mps headers (#140906)

It was a mistake to amiss them in the past

All changes in this PR except ones to .lintrunner.toml are generated by running
`lintrunner -a --take CLANGFORMAT --all-files`
Pull Request resolved: https://github.com/pytorch/pytorch/pull/140906
Approved by: https://github.com/Skylion007
This commit is contained in:
Nikita Shulga 2024-11-17 10:13:01 -08:00 committed by PyTorch MergeBot
parent 5a7e147ef3
commit 99014a297c
24 changed files with 896 additions and 787 deletions

View File

@ -56,10 +56,12 @@ code = 'CLANGFORMAT'
include_patterns = [
'aten/src/ATen/*.h',
'aten/src/ATen/mps/**/*.mm',
'aten/src/ATen/mps/**/*.h',
'aten/src/ATen/xpu/**/*.h',
'aten/src/ATen/xpu/**/*.cpp',
'aten/src/ATen/native/mps/**/*.metal',
'aten/src/ATen/native/mps/**/*.mm',
'aten/src/ATen/native/mps/**/*.h',
'aten/src/ATen/native/vulkan/**/*.h',
'aten/src/ATen/native/vulkan/**/*.cpp',
'aten/src/ATen/native/cuda/MultiTensorApply.cuh',

View File

@ -12,8 +12,7 @@ C10_EXPORT TensorBase empty_mps(
std::optional<Device> device_opt,
std::optional<bool> pin_memory_opt,
std::optional<c10::MemoryFormat> memory_format_opt);
C10_EXPORT TensorBase empty_mps(
IntArrayRef size, const TensorOptions &options);
C10_EXPORT TensorBase empty_mps(IntArrayRef size, const TensorOptions& options);
C10_EXPORT TensorBase empty_strided_mps(
IntArrayRef size,

View File

@ -6,12 +6,12 @@
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <c10/util/flat_hash_map.h>
#include <mach/vm_page_size.h>
#include <cstdio>
#include <mutex>
#include <set>
#include <unordered_set>
#include <mach/vm_page_size.h>
#include <c10/util/flat_hash_map.h>
// this implementation is based on CUDACachingAllocator.
// It utilizes Metal Heaps to improve the performance with buffer allocation.
@ -24,8 +24,10 @@ static const size_t kMinLargeAlloc = MB(10); // allocations between 1 and 10 M
static const size_t kRoundLarge = MB(2); // round up large allocations to 2 MiB
static const size_t kSmallHeap = MB(8); // "small" allocations are packed in 8 MiB heaps
static const size_t kLargeHeap = MB(32); // "large" allocations may be packed in 32 MiB heaps
static const size_t kXLargeHeapD = MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU = MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kXLargeHeapD =
MB(128); // "extra large" allocations on Discrete devices may be packed in 128 MiB heaps
static const size_t kXLargeHeapU =
MB(1024); // "extra large" allocations on Unified devices may be packed in 1 GiB heaps
static const size_t kMaxScalarAlloc = (sizeof(int64_t)); // largest "scalar" allocation
// buffer pools could be customized with a combination of usage flags
@ -67,10 +69,8 @@ struct BufferBlock {
// Metal events used to sync GPU/CPU operations on the shared-storage buffers
MPSEventPtr event;
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr,
HeapBlock* Heap = nullptr) :
buffer(Buffer), size(Size), requested_size(RequestedSize),
heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) { }
BufferBlock(size_t Size, size_t RequestedSize = 0, const id<MTLBuffer> Buffer = nullptr, HeapBlock* Heap = nullptr)
: buffer(Buffer), size(Size), requested_size(RequestedSize), heap(Heap), buf_id(Buffer ? ++buffer_counter : 0) {}
static bool Comparator(const BufferBlock* a, const BufferBlock* b) {
return (a->size != b->size) ? a->size < b->size : (uintptr_t)a->buffer < (uintptr_t)b->buffer;
@ -79,15 +79,19 @@ struct BufferBlock {
assert(((Alignment - 1) & Alignment) == 0);
return ((Size + Alignment - 1) & ~(Alignment - 1));
}
uint32_t retainCount() const { return [buffer retainCount]; }
uint32_t retainCount() const {
return [buffer retainCount];
}
};
typedef bool (*BufferComparison)(const BufferBlock*, const BufferBlock*);
struct BufferPool;
struct AllocParams {
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool) :
search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) { }
size_t size() const { return search_key.size; }
AllocParams(size_t Alloc_Size, size_t Requested_Size, BufferPool* Pool)
: search_key(Alloc_Size), pool(Pool), requested_size(Requested_Size) {}
size_t size() const {
return search_key.size;
}
BufferBlock search_key;
BufferPool* pool;
@ -102,7 +106,9 @@ struct AllocParams {
struct HeapBlock {
id<MTLHeap> heap;
struct { size_t total, available; } size;
struct {
size_t total, available;
} size;
BufferPool* pool;
unsigned int n_buffers = 0;
id_t heap_id;
@ -111,9 +117,12 @@ struct HeapBlock {
// counter to assign unique ids to heap blocks
static uint64_t heap_counter;
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool *Pool = nullptr) :
heap(Heap), size({.total = Size, .available = Size}), pool(Pool),
heap_id(Heap ? ++heap_counter : 0), is_split(true) { }
HeapBlock(size_t Size, const id<MTLHeap> Heap = nullptr, BufferPool* Pool = nullptr)
: heap(Heap),
size({.total = Size, .available = Size}),
pool(Pool),
heap_id(Heap ? ++heap_counter : 0),
is_split(true) {}
static MTLResourceOptions getOptions(uint32_t usage) {
// TODO: check the caching performance of write-combined mode
@ -126,7 +135,8 @@ struct HeapBlock {
else
options |= MTLResourceStorageModePrivate;
options |= (usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
options |=
(usage & UsageFlags::HAZARD) ? MTLResourceHazardTrackingModeTracked : MTLResourceHazardTrackingModeUntracked;
return options;
}
@ -152,7 +162,8 @@ struct HeapBlock {
d.cpuCacheMode = MTLCPUCacheModeDefaultCache;
// this automatically handles Metal buffer access synchronizations at the
// cost of slightly lower performance.
d.hazardTrackingMode = (usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.hazardTrackingMode =
(usage & UsageFlags::HAZARD) ? MTLHazardTrackingModeTracked : MTLHazardTrackingModeUntracked;
d.resourceOptions = getOptions(usage);
d.type = MTLHeapTypeAutomatic;
id<MTLHeap> heap = [device newHeapWithDescriptor:d];
@ -169,8 +180,8 @@ struct HeapBlock {
return heapBlock;
}
static bool Comparator(const HeapBlock* a, const HeapBlock* b) {
return (a->size.available != b->size.available) ? a->size.available < b->size.available :
(uintptr_t)a->heap < (uintptr_t)b->heap;
return (a->size.available != b->size.available) ? a->size.available < b->size.available
: (uintptr_t)a->heap < (uintptr_t)b->heap;
}
static NSUInteger heapAvailableSize(id<MTLHeap> heap, size_t Alignment = vm_page_size) {
return [heap maxAvailableSizeWithAlignment:Alignment];
@ -205,8 +216,12 @@ struct HeapBlock {
size.available = 0;
return retainCount;
}
uint32_t retainCount() const { return [heap retainCount]; }
void updateAvailableSize() { size.available = heapAvailableSize(heap); }
uint32_t retainCount() const {
return [heap retainCount];
}
void updateAvailableSize() {
size.available = heapAvailableSize(heap);
}
};
typedef bool (*HeapComparison)(const HeapBlock*, const HeapBlock*);
@ -219,9 +234,8 @@ struct BufferPool {
SCALAR,
};
BufferPool(const id<MTLDevice> Device, uint32_t Usage) :
device(Device), usage(Usage),
heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) { }
BufferPool(const id<MTLDevice> Device, uint32_t Usage)
: device(Device), usage(Usage), heaps(HeapBlock::Comparator), available_buffers(BufferBlock::Comparator) {}
const id<MTLDevice> device;
// usage flags to customize the pool for various purposes (see UsageFlags enum)
@ -249,8 +263,8 @@ struct BufferPool {
class MPSHeapAllocatorImpl {
public:
explicit MPSHeapAllocatorImpl() :
m_device(at::mps::MPSDevice::getInstance()->device()),
explicit MPSHeapAllocatorImpl()
: m_device(at::mps::MPSDevice::getInstance()->device()),
m_max_buffer_size([m_device maxBufferLength]),
m_stream(getDefaultMPSStream()),
m_event_pool(getMPSEventPool()) {
@ -298,22 +312,38 @@ public:
// (see m_high_watermark_ratio for description)
void setHighWatermarkRatio(double ratio);
// (see m_low_watermark_limit for description)
size_t getLowWatermarkLimit() const { return m_low_watermark_limit; }
size_t getLowWatermarkLimit() const {
return m_low_watermark_limit;
}
// (see m_max_total_allowed_size for description)
size_t getHighWatermarkLimit() const { return m_max_total_allowed_size; }
size_t getHighWatermarkLimit() const {
return m_max_total_allowed_size;
}
// (see m_total_allocated_memory for description)
size_t getTotalAllocatedMemory() const { return m_total_allocated_memory; }
size_t getTotalAllocatedMemory() const {
return m_total_allocated_memory;
}
// (see m_current_allocated_memory for description)
size_t getCurrentAllocatedMemory() const { return m_current_allocated_memory; }
size_t getCurrentAllocatedMemory() const {
return m_current_allocated_memory;
}
// total GPU memory allocated in the process by Metal driver; including
// implicit allocations from MPS/MPSGraph frameworks and MPSHeapAllocatorImpl.
size_t getDriverAllocatedMemory() const { return current_allocated_size(); }
size_t getDriverAllocatedMemory() const {
return current_allocated_size();
}
// recommended Max memory for Metal
size_t getRecommendedMaxMemory() const { return max_device_size(); }
size_t getRecommendedMaxMemory() const {
return max_device_size();
}
// (see enum DebugVerbosity for description)
uint32_t getDebugVerbosity() const { return m_debug_verbosity; }
uint32_t getDebugVerbosity() const {
return m_debug_verbosity;
}
// returns the device that we allocate from
inline id<MTLDevice> Device() const { return m_device; }
inline id<MTLDevice> Device() const {
return m_device;
}
// TODO: make a common function to do size unit conversions in PyTorch.
inline std::string format_size(uint64_t size) const;
@ -387,14 +417,19 @@ private:
size_t get_allocation_size(size_t size, uint32_t usage) const;
// maximum size of device memory available for allocation in current process
// Note: the recommendedMaxWorkingSetSize is typically 75% of the total system memory.
size_t max_device_size() const { return [m_device recommendedMaxWorkingSetSize]; }
size_t max_device_size() const {
return [m_device recommendedMaxWorkingSetSize];
}
// there are implicit allocations from MPS backend, so we need to query the 'device' for
// total allocated size instead of manually tracking in MPSAllocator
size_t current_allocated_size() const { return [m_device currentAllocatedSize]; }
size_t current_allocated_size() const {
return [m_device currentAllocatedSize];
}
bool trigger_memory_callbacks(BufferBlock* buffer_block, IMpsAllocatorCallback::EventType event) const {
for (const auto& name : MPSAllocatorCallbacksRegistry()->Keys()) {
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(buffer_block ? buffer_block->buffer : nullptr, event);
MPSAllocatorCallbacksRegistry()->Create(name)->executeMPSAllocatorCallback(
buffer_block ? buffer_block->buffer : nullptr, event);
}
return true;
}

View File

@ -2,9 +2,9 @@
#pragma once
#include <ATen/core/ATen_fwd.h>
#include <c10/core/Allocator.h>
#include <c10/util/Registry.h>
#include <ATen/core/ATen_fwd.h>
#define MB(x) (x * 1048576UL)
@ -20,10 +20,12 @@ public:
virtual ssize_t getUnalignedBufferSize(const void* ptr) const = 0;
virtual IntArrayRef getBufferShape(const void* ptr) const = 0;
virtual id_t getBufferId(const void* ptr) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape) const = 0;
virtual void setBufferShape(const void* ptr, const IntArrayRef& shape)
const = 0;
virtual bool isSharedBuffer(const void* ptr) const = 0;
virtual bool isSharedStorageSupported() const = 0;
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size) const = 0;
virtual c10::DataPtr allocScalarBufferWithValue(void* value, size_t size)
const = 0;
virtual std::string formatSize(size_t size) const = 0;
virtual void setLowWatermarkRatio(double ratio) const = 0;
virtual void setHighWatermarkRatio(double ratio) const = 0;
@ -34,7 +36,8 @@ public:
virtual size_t getCurrentAllocatedMemory() const = 0;
virtual size_t getDriverAllocatedMemory() const = 0;
virtual size_t getRecommendedMaxMemory() const = 0;
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(const void* ptr) const = 0;
virtual std::pair<const void*, uint32_t> getSharedBufferPtr(
const void* ptr) const = 0;
virtual bool recordEvents(c10::ArrayRef<const void*> buffers) const = 0;
virtual bool waitForEvents(c10::ArrayRef<const void*> buffers) const = 0;
};
@ -52,7 +55,8 @@ class IMpsAllocatorCallback {
virtual void executeMPSAllocatorCallback(void* ptr, EventType event) = 0;
};
// MPS allocator will execute every registered callback when a block of memory is freed.
// MPS allocator will execute every registered callback when a block of memory
// is freed.
TORCH_DECLARE_REGISTRY(MPSAllocatorCallbacksRegistry, IMpsAllocatorCallback);
#define REGISTER_MPS_ALLOCATOR_CALLBACK(name, ...) \
C10_REGISTER_CLASS(MPSAllocatorCallbacksRegistry, name, __VA_ARGS__)

View File

@ -5,7 +5,6 @@
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
#include <Metal/Metal.h>

View File

@ -26,12 +26,17 @@ public:
// blocks the CPU thread until all the GPU work that were scheduled
// prior to recording this event are completed.
bool synchronize();
// resets this event with new parameters in case it gets reused from the event pool
// resets this event with new parameters in case it gets reused from the event
// pool
void reset(MPSStream* stream, bool enable_timing);
// returns the unique ID of the event instance
id_t getID() const { return m_id; }
id_t getID() const {
return m_id;
}
// returns the completion timestamp of the event
uint64_t getCompletionTime() const { return m_completion_time; }
uint64_t getCompletionTime() const {
return m_completion_time;
}
// if already recorded, waits for cpu_sync_cv to be signaled
void waitForCpuSync();

View File

@ -17,7 +17,8 @@ struct rng_data_pod {
};
TORCH_API const Generator& getDefaultMPSGenerator();
TORCH_API Generator createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
TORCH_API Generator
createMPSGenerator(uint64_t seed_val = default_rng_seed_val);
} // namespace mps::detail
@ -37,10 +38,18 @@ struct TORCH_API MPSGeneratorImpl : public c10::GeneratorImpl {
c10::intrusive_ptr<c10::TensorImpl> get_state() const override;
void update_philox_counters();
void set_engine(at::Philox4_32 engine) { engine_ = engine; }
at::Philox4_32 engine() { return engine_; }
uint32_t* state_data() { return data_.state.data(); }
static DeviceType device_type() { return DeviceType::MPS; }
void set_engine(at::Philox4_32 engine) {
engine_ = engine;
}
at::Philox4_32 engine() {
return engine_;
}
uint32_t* state_data() {
return data_.state.data();
}
static DeviceType device_type() {
return DeviceType::MPS;
}
private:
mps::detail::rng_data_pod data_;

View File

@ -1,12 +1,12 @@
// Copyright © 2022 Apple Inc.
#pragma once
#include <ATen/Context.h>
#include <ATen/mps/MPSEvent.h>
#include <ATen/mps/MPSStream.h>
#include <c10/core/impl/DeviceGuardImplInterface.h>
#include <c10/macros/Macros.h>
#include <c10/util/Exception.h>
#include <ATen/Context.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSEvent.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
@ -18,11 +18,10 @@
#include <c10/core/MemoryFormat.h>
#include <c10/core/Storage.h>
#include <c10/core/TensorImpl.h>
#include <sys/_types/_size_t.h>
#include <memory>
#include <c10/core/UndefinedTensorImpl.h>
#include <c10/util/intrusive_ptr.h>
#include <sys/_types/_size_t.h>
#include <memory>
namespace at::mps {
@ -30,7 +29,8 @@ typedef MPSEvent* mpsEvent_t;
// TODO: Move the MPSGuardImpl to inherit from NoOpDeviceGuardImpl
// https://github.com/pytorch/pytorch/issues/77170
struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface {
struct TORCH_API MPSGuardImpl final
: public c10::impl::DeviceGuardImplInterface {
static constexpr c10::DeviceType static_type = c10::DeviceType::MPS;
// constructor
@ -91,13 +91,10 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
}
// Event-related functions
void createEvent(
mpsEvent_t* event,
const EventFlag flag) const;
void createEvent(mpsEvent_t* event, const EventFlag flag) const;
void destroyEvent(
void* event,
const DeviceIndex device_index) const noexcept override;
void destroyEvent(void* event, const DeviceIndex device_index)
const noexcept override;
void record(
void** event,
@ -105,14 +102,11 @@ struct TORCH_API MPSGuardImpl final : public c10::impl::DeviceGuardImplInterface
const DeviceIndex device_index,
const EventFlag flag) const override;
void block(
void* event,
const Stream& stream) const override;
void block(void* event, const Stream& stream) const override;
bool queryEvent(void* event) const override;
void synchronizeDevice(const DeviceIndex device_index) const override;
};
/// A variant of OptionalDeviceGuard that is specialized for MPS.
@ -175,7 +169,6 @@ struct OptionalMPSGuard {
c10::impl::InlineOptionalDeviceGuard<MPSGuardImpl> guard_;
};
C10_REGISTER_GUARD_IMPL(MPS, MPSGuardImpl)
} // namespace at::mps

View File

@ -2,8 +2,8 @@
#pragma once
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/Generator.h>
#include <ATen/detail/MPSHooksInterface.h>
#include <ATen/mps/MPSEvent.h>
#include <optional>
@ -38,7 +38,8 @@ struct MPSHooks : public at::MPSHooksInterface {
Allocator* getPinnedMemoryAllocator() const override;
// MPSProfiler interface
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted) const override;
void profilerStartTrace(const std::string& mode, bool waitUntilCompleted)
const override;
void profilerStopTrace() const override;
// MPSEvent interface
@ -48,7 +49,8 @@ struct MPSHooks : public at::MPSHooksInterface {
void waitForEvent(uint32_t event_id) const override;
void synchronizeEvent(uint32_t event_id) const override;
bool queryEvent(uint32_t event_id) const override;
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id) const override;
double elapsedTimeOfEvents(uint32_t start_event_id, uint32_t end_event_id)
const override;
// Compatibility with Accelerator API
bool hasPrimaryContext(DeviceIndex device_index) const override {

View File

@ -3,11 +3,11 @@
#pragma once
#include <ATen/Tensor.h>
#include <ATen/mps/MPSStream.h>
#include <ATen/mps/MPSAllocatorInterface.h>
#include <ATen/mps/MPSStream.h>
#include <os/signpost.h>
#include <os/log.h>
#include <os/signpost.h>
#include <atomic>
#include <ctime>
@ -29,8 +29,8 @@ struct BaseInfo {
CPU_FALLBACK,
};
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle) :
type(infoType), profileId(Id), handle(Handle) { }
BaseInfo(Type infoType, uint64_t Id, const uintptr_t Handle)
: type(infoType), profileId(Id), handle(Handle) {}
virtual ~BaseInfo() = default;
// type of profiling info
@ -41,30 +41,36 @@ struct BaseInfo {
// since it's possible to use event and interval-based signposts at the
// same time, we need separate IDs for each.
os_signpost_id_t eventSignpostId = 0, intervalSignpostId = 0;
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime - GPUStartTime")
// accumulated GPU time in ms (obtained from CompletionHandler's "GPUEndTime -
// GPUStartTime")
std::atomic<double> totalGpuTime{0.0};
// accumulated Scheduling time in ms (obtained from CompletionHandler's "KernelEndTime - KernelStartTime")
// accumulated Scheduling time in ms (obtained from CompletionHandler's
// "KernelEndTime - KernelStartTime")
std::atomic<double> totalSchedulingTime{0.0};
// indicates if the operation or copy execution has completed
std::atomic_bool completed{false};
// handle used to identify the profile info's instance (usually the pointer)
const uintptr_t handle;
virtual const std::string toString(double gpuTime = 0, double schedulingTime = 0) const;
virtual const std::string toString(
double gpuTime = 0,
double schedulingTime = 0) const;
// builds a string for a tensor (format: Device:ScalarType[tensor.sizes()])
static std::string buildTensorString(const Tensor& tensor, bool includeBufferId = false) {
static std::string buildTensorString(
const Tensor& tensor,
bool includeBufferId = false) {
if (tensor.defined()) {
std::stringstream tensorStr;
auto deviceType = tensor.device().type();
tensorStr << c10::DeviceTypeName(deviceType);
// see comments for INCLUDE_BUFFER_ID
if (includeBufferId && deviceType == at::kMPS) {
id<MTLBuffer> buffer = __builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer))
<< ":" << buffer.retainCount << ")";
id<MTLBuffer> buffer =
__builtin_bit_cast(id<MTLBuffer>, tensor.storage().data());
tensorStr << "(buf#" << (getIMPSAllocator()->getBufferId(buffer)) << ":"
<< buffer.retainCount << ")";
}
tensorStr << ":"
<< tensor.scalar_type() << tensor.sizes();
tensorStr << ":" << tensor.scalar_type() << tensor.sizes();
return tensorStr.str();
} else {
return "undefined";
@ -76,16 +82,23 @@ struct BaseInfo {
};
struct OperationInfo : BaseInfo {
OperationInfo(const void* Handle, bool IsGraph, uint64_t Id, const std::string& StrKey) :
BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)), strKey(StrKey) { }
OperationInfo(
const void* Handle,
bool IsGraph,
uint64_t Id,
const std::string& StrKey)
: BaseInfo(IsGraph ? Type::GRAPH : Type::KERNEL, Id, uintptr_t(Handle)),
strKey(StrKey) {}
uint64_t runCount = 0;
std::string strKey;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
// builds a string for a kernel
static std::string buildKernelString(const std::string& kernelName,
static std::string buildKernelString(
const std::string& kernelName,
const TensorList& tensors,
bool includeBufferId = false) {
std::stringstream kernelStr;
@ -98,19 +111,20 @@ struct OperationInfo : BaseInfo {
};
struct CpuFbInfo : BaseInfo {
CpuFbInfo(uint64_t Id, const std::string& OpName) :
BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) { }
CpuFbInfo(uint64_t Id, const std::string& OpName)
: BaseInfo(Type::CPU_FALLBACK, Id, 0), opName(OpName) {}
uint64_t runCount = 0;
// the current and total overhead of copies in bytes required to convert the Op's
// input tensors from MPS to CPU and then output from CPU back to MPS
// the current and total overhead of copies in bytes required to convert the
// Op's input tensors from MPS to CPU and then output from CPU back to MPS
size_t currentCopyOverhead = 0;
size_t totalCopyOverhead = 0;
std::string opName;
std::string strKey;
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
void updateCopyOverhead(const TensorList& tensors) {
currentCopyOverhead = 0;
@ -130,9 +144,17 @@ struct CopyInfo : BaseInfo {
CPU_TO_MPS,
};
CopyInfo(const void* Handle, size_t Length, uint64_t Id, bool IsNonBlocking, bool UsesBlitter) :
BaseInfo(Type::COPY, Id, uintptr_t(Handle)), kind(Kind::MPS_TO_MPS),
length(Length), isNonBlocking(IsNonBlocking), usesBlitter(UsesBlitter) { }
CopyInfo(
const void* Handle,
size_t Length,
uint64_t Id,
bool IsNonBlocking,
bool UsesBlitter)
: BaseInfo(Type::COPY, Id, uintptr_t(Handle)),
kind(Kind::MPS_TO_MPS),
length(Length),
isNonBlocking(IsNonBlocking),
usesBlitter(UsesBlitter) {}
Kind kind;
size_t length;
@ -143,11 +165,17 @@ struct CopyInfo : BaseInfo {
// for copies that don't use blitters, we measure CPU time
uint64_t startTime = 0;
const std::string toString(double gpuTime = 0, double schedulingTime = 0) const override;
const std::string toString(double gpuTime = 0, double schedulingTime = 0)
const override;
static std::string buildTensorString(const void* buffer, const OptionalTensorRef tensor, bool includeBufferId = false);
static std::string buildTensorString(
const void* buffer,
const OptionalTensorRef tensor,
bool includeBufferId = false);
static bool isStorageOnMPS(const void* buffer, const OptionalTensorRef tensor) {
static bool isStorageOnMPS(
const void* buffer,
const OptionalTensorRef tensor) {
if (tensor.has_value()) {
return tensor->device().type() == at::kMPS;
}
@ -156,8 +184,11 @@ struct CopyInfo : BaseInfo {
return getIMPSAllocator()->getUnalignedBufferSize(buffer) >= 0;
}
static Kind getCopyKind(const void* srcBuffer, const void* dstBuffer,
const OptionalTensorRef srcTensor, const OptionalTensorRef dstTensor) {
static Kind getCopyKind(
const void* srcBuffer,
const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor) {
const bool isSrcOnMPS = isStorageOnMPS(srcBuffer, srcTensor);
const bool isDstOnMPS = isStorageOnMPS(dstBuffer, dstTensor);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(isSrcOnMPS || isDstOnMPS);
@ -171,8 +202,9 @@ struct CopyInfo : BaseInfo {
};
struct CopyStat : CopyInfo {
explicit CopyStat(std::string CopyKindStr) :
CopyInfo(nullptr, 0, 0, false, false), kindStr(std::move(CopyKindStr)) {}
explicit CopyStat(std::string CopyKindStr)
: CopyInfo(nullptr, 0, 0, false, false),
kindStr(std::move(CopyKindStr)) {}
// total number of copies
size_t totalCount = 0;
// number of Scalar copies (i.e., less than sizeof(int64))
@ -192,8 +224,8 @@ public:
// lower 16 bits used for profiler options
enum ProfileOptions : uint32_t {
OPTIONS_NONE = 0,
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK, etc.)
// (used for convenience to not compute bit flags by OR-ing manually)
// ALL_* means, all signpost types (RUN_OPERATION|BLIT_COPY|CPU_FALLBACK,
// etc.) (used for convenience to not compute bit flags by OR-ing manually)
// trace all signpost types using events
ALL_SIGNPOST_EVENTS = (1 << 0),
// trace all signpost types using intervals
@ -206,8 +238,8 @@ public:
// and not schedule time.
INCLUDE_SCHEDULE_INTERVAL = (1 << 3),
// use these if you need to trace signposts types individually (rarely required)
// trace signpost using intervals
// use these if you need to trace signposts types individually (rarely
// required) trace signpost using intervals
USE_INTERVALS = (1 << 4),
// trace signpost by emitting events
USE_EVENTS = (1 << 5),
@ -238,17 +270,21 @@ public:
OPERATION_INFO = (1 << 0),
// prints copy info (src/dst tensors/buffers, size, etc.) during execution
COPY_INFO = (1 << 1),
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during execution
// prints CPU Fallback info (id/runCount/opName/copyOverhead) during
// execution
CPU_FALLBACK_INFO = (1 << 2),
// Profiling Statistics logging options when process terminates
// ------------------------------------------------------------
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before process terminates
// this is convenient to not combine following stats bit flags manually
// prints all stats (OPERATION_STATS, COPY_STATS, CPU_FALLBACK_STATS) before
// process terminates this is convenient to not combine following stats bit
// flags manually
ALL_STATS = (1 << 3),
// prints operation stats (GPU times, run count, etc.) before process terminates
// prints operation stats (GPU times, run count, etc.) before process
// terminates
OPERATION_STATS = (1 << 4),
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process terminates
// prints copies stats (GPU times, copy kinds, sizes, etc.) before process
// terminates
COPY_STATS = (1 << 5),
// prints CPU Fallback stats (CPU times, run times, size of MPS<->CPU copies
// for tensors, etc.) before process terminates
@ -256,17 +292,18 @@ public:
// Metadata format options when logging the info
// ---------------------------------------------
// if enabled, includes GPU run time in metadata (i.e., GPUEndTime-GPUStartTime
// from Metal Command Buffers) (e.g., [GPU=0.324 ms])
// if enabled, includes GPU run time in metadata (i.e.,
// GPUEndTime-GPUStartTime from Metal Command Buffers) (e.g., [GPU=0.324
// ms])
INCLUDE_GPU_TIME = (1 << 7),
// if enabled, includes GPU scheduling time in metadata separately
// (i.e., KernelEndTime-KernelStartTime from Metal Command Buffers)
// e.g., [GPU=0.324 ms, KRNL=0.036 ms]
INCLUDE_KERNEL_TIME = (1 << 8),
// if enabled, includes the unique buffer ID in metadata for the storage
// of a tensor that was allocated on MPSAllocator. This is useful (along with
// the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are involved
// with various operations.
// of a tensor that was allocated on MPSAllocator. This is useful (along
// with the EV "PYTORCH_DEBUG_MPS_ALLOCATOR") to identify buffers that are
// involved with various operations.
INCLUDE_BUFFER_ID = (1 << 9),
// used for sanity check (Change this when new option added)
@ -276,15 +313,28 @@ public:
explicit MPSProfiler();
~MPSProfiler();
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// the beginProfile*() functions return a profileId which is unique per graph/kernel/copy
uint64_t beginProfileKernel(const void* handle, const std::string& strKey, bool isGraph);
uint64_t beginProfileKernel(const void* handle, const std::string& kernelName, const TensorList& tensors);
uint64_t beginProfileCopy(const void* srcBuffer, const void* dstBuffer,
// the handle is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal
// Kernels the beginProfile*() functions return a profileId which is unique
// per graph/kernel/copy
uint64_t beginProfileKernel(
const void* handle,
const std::string& strKey,
bool isGraph);
uint64_t beginProfileKernel(
const void* handle,
const std::string& kernelName,
const TensorList& tensors);
uint64_t beginProfileCopy(
const void* srcBuffer,
const void* dstBuffer,
const OptionalTensorRef srcTensor,
const OptionalTensorRef dstTensor,
size_t length, bool isNonBlocking, bool usesBlitter = true);
uint64_t beginProfileCPUFallback(const std::string& opName, const TensorList& tensors);
size_t length,
bool isNonBlocking,
bool usesBlitter = true);
uint64_t beginProfileCPUFallback(
const std::string& opName,
const TensorList& tensors);
void beginProfileGPUInterval(const void* handle);
void endProfileCopy(uint64_t profileId, SyncType syncType);
@ -309,7 +359,8 @@ public:
// logging are enabled for the SignpostTypes
bool isOperationProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::RUN_OPERATION) ||
(m_log_options & (LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
(m_log_options &
(LogOptions::OPERATION_INFO | LogOptions::OPERATION_STATS));
}
bool isCopyProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::BLIT_COPY) ||
@ -317,14 +368,16 @@ public:
}
bool isCPUFallbackProfilingEnabled() const {
return (m_signpost_types & SignpostTypes::CPU_FALLBACK) ||
(m_log_options & (LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
(m_log_options &
(LogOptions::CPU_FALLBACK_INFO | LogOptions::CPU_FALLBACK_STATS));
}
bool isSignpostTracingEnabled() const {
return (m_signpost_types != SignpostTypes::SIGNPOST_NONE);
}
private:
// indicates what type of signpost types are enabled and traced by MPS profiler.
// indicates what type of signpost types are enabled and traced by MPS
// profiler.
uint32_t m_signpost_types = 0;
uint32_t m_profile_options = 0;
uint32_t m_log_options = 0;
@ -332,14 +385,15 @@ public:
uint64_t m_graph_counter = 0;
uint64_t m_cpu_fb_counter = 0;
uint64_t m_copy_counter = 0;
// technically, it's possible to trace both events and intervals at the same time
// so we use separate os_log categories for them
// technically, it's possible to trace both events and intervals at the same
// time so we use separate os_log categories for them
os_log_t m_os_log_events;
os_log_t m_os_log_intervals;
// stats logging could run either from destructor or signal handler
// so this is used to check if logging has already started.
std::atomic_bool hasLoggedStats{false};
// indicates there are pending completionHandler callbacks that haven't been called yet.
// indicates there are pending completionHandler callbacks that haven't been
// called yet.
std::atomic_bool hasPendingCompletionHandlers{false};
// used to capture sigint signal to log profiling stats
static struct sigaction currentSigint, previousSigint;
@ -347,40 +401,62 @@ public:
// We use the following lists for two reasons:
// 1- for interval-based signposts the "begin" point won't be in same function
// as the "end" point where we need to be able to retrieve signpost's info
// 2- if Operations info need to be logged when process ends using LogOptions::OPERATION_INFO.
// 2- if Operations info need to be logged when process ends using
// LogOptions::OPERATION_INFO.
// the pointer key for this map is either "MPSGraph*" or "id<MTLComputePipelineState>" for Metal Kernels
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>> m_op_info_list{};
// the string key for this map is the op name that we fall back to execute on CPU
// this list is retained and could be logged along with aggregate profiling numbers when the process ends.
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>> m_cpu_fb_info_list{};
// the pointer key for this map is either "MPSGraph*" or
// "id<MTLComputePipelineState>" for Metal Kernels this list is retained and
// could be logged along with aggregate profiling numbers when the process
// ends.
std::unordered_map<uintptr_t, std::unique_ptr<OperationInfo>>
m_op_info_list{};
// the string key for this map is the op name that we fall back to execute on
// CPU this list is retained and could be logged along with aggregate
// profiling numbers when the process ends.
std::unordered_map<std::string, std::unique_ptr<CpuFbInfo>>
m_cpu_fb_info_list{};
// this list contains the info for copies, and its key is the unique profileId
// which is generated from m_copy_counter
// The copyInfo list is not retained.
std::unordered_map<uint64_t, std::unique_ptr<CopyInfo>> m_copy_info_list{};
// a short list that contains copy stats
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>> m_copy_stat_list{};
std::unordered_map<CopyInfo::Kind, std::unique_ptr<CopyStat>>
m_copy_stat_list{};
mutable MTLCaptureManager* captureManager = nil;
unsigned captureCount = 0;
void initialize();
void beginProfileExecution(BaseInfo& info, bool cpuExecution = false);
void endProfileExecution(BaseInfo& info, os_signpost_id_t event_signpost_id,
void endProfileExecution(
BaseInfo& info,
os_signpost_id_t event_signpost_id,
os_signpost_id_t interval_signpost_id,
double gpuTime, double schedulingTime);
double gpuTime,
double schedulingTime);
void addProfilerScheduledHandler(BaseInfo& info);
void addProfilerCompletedHandler(BaseInfo& info, SyncType syncType);
void emitSignpostEvent(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
void emitSignpostEvent(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id,
const std::string& msg) const;
void beginSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id,
void beginSignpostInterval(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id,
const std::string& msg) const;
void endSignpostInterval(SignpostTypes signpost_type, os_signpost_id_t signpost_id) const;
void endSignpostInterval(
SignpostTypes signpost_type,
os_signpost_id_t signpost_id) const;
void updateCopyStats(const CopyInfo& copyInfo, double gpuTime, double schedulingTime);
// returns true if logging the profiling info "during the execution" is enabled
bool isProfileInfoLoggingEnabled(BaseInfo::Type infoType, bool isExecutionEnded);
void updateCopyStats(
const CopyInfo& copyInfo,
double gpuTime,
double schedulingTime);
// returns true if logging the profiling info "during the execution" is
// enabled
bool isProfileInfoLoggingEnabled(
BaseInfo::Type infoType,
bool isExecutionEnded);
// logs all the profiling stats that are enabled
void logProfilingStats();
// logs kernel profiling stats when the process ends.
@ -390,7 +466,9 @@ public:
// logs copy profiling stats when the process ends.
void logCopyProfilingStats(std::FILE* f) const;
os_signpost_id_t generateSignpostId(os_signpost_type_t signpostType, const void* ptr = nullptr);
os_signpost_id_t generateSignpostId(
os_signpost_type_t signpostType,
const void* ptr = nullptr);
static SignpostTypes getSignpostType(BaseInfo::Type infoType);
static void handleIntSignal(int signal);
};

View File

@ -5,10 +5,10 @@
#include <cstdint>
#include <utility>
#include <c10/core/DeviceGuard.h>
#include <c10/util/Exception.h>
#include <c10/core/Stream.h>
#include <ATen/mps/MPSDevice.h>
#include <c10/core/DeviceGuard.h>
#include <c10/core/Stream.h>
#include <c10/util/Exception.h>
#ifdef __OBJC__
#include <Foundation/Foundation.h>
@ -32,7 +32,6 @@ typedef void* MTLDevice_t;
#define nil NULL;
#endif
namespace at::mps {
//-----------------------------------------------------------------
@ -47,8 +46,7 @@ enum class SyncType {
COMMIT_ADAPTIVE, // commit adaptively based on available memory
};
class TORCH_API MPSStream
{
class TORCH_API MPSStream {
public:
enum Unchecked { UNCHECKED };
@ -57,32 +55,55 @@ public:
explicit MPSStream(Stream stream);
~MPSStream();
MTLCommandQueue_t commandQueue() const { return _commandQueue; };
dispatch_queue_t queue() const { return _serialQueue; }
MTLCommandQueue_t commandQueue() const {
return _commandQueue;
};
dispatch_queue_t queue() const {
return _serialQueue;
}
MPSCommandBuffer* commandBuffer();
MTLComputeCommandEncoder_t commandEncoder();
void endKernelCoalescing();
void synchronize(SyncType syncType);
void fill(id<MTLBuffer> buffer, uint8_t value, size_t length, size_t offset, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
uint64_t profileId, SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer, id<MTLBuffer> dstBuffer,
size_t length, size_t srcOffset, size_t dstOffset,
bool non_blocking, uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results, SyncType syncType = SyncType::NONE);
void copy(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
uint64_t profileId,
SyncType syncType = SyncType::NONE);
void copy_and_sync(id<MTLBuffer> srcBuffer,
id<MTLBuffer> dstBuffer,
size_t length,
size_t srcOffset,
size_t dstOffset,
bool non_blocking,
uint64_t profileId);
void executeMPSGraph(MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results,
SyncType syncType = SyncType::NONE);
void addCompletedHandler(MTLCommandBufferHandler block);
/// Get the MPS device index that this stream is associated with.
c10::DeviceIndex device_index() const { return _stream.device_index(); }
c10::DeviceIndex device_index() const {
return _stream.device_index();
}
MTLCommandQueue_t stream() const { return _commandQueue; };
MTLCommandQueue_t stream() const {
return _commandQueue;
};
MTLDevice_t device() const { return [_commandQueue device];}
MTLDevice_t device() const {
return [_commandQueue device];
}
/// Explicit conversion to Stream.
Stream unwrap() const { return _stream; }
Stream unwrap() const {
return _stream;
}
private:
Stream _stream;
@ -117,8 +138,7 @@ TORCH_API MPSStream* getDefaultMPSStream();
// MPSStreamImpl
//-----------------------------------------------------------------
class TORCH_API MPSStreamImpl
{
class TORCH_API MPSStreamImpl {
public:
/**
* Gets single instance of the MPSStream.

View File

@ -2,8 +2,7 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#if !defined(__MAC_15_0) && \
(!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
#if !defined(__MAC_15_0) && (!defined(MAC_OS_X_VERSION_15_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_15_0))
@interface MPSNDArrayIdentity : MPSNDArrayUnaryKernel
- (MPSNDArray* __nullable)reshapeWithCommandBuffer:(__nullable id<MTLCommandBuffer>)cmdBuf
@ -20,19 +19,16 @@
- (nonnull instancetype)initWithBuffer:(id<MTLBuffer> _Nonnull)buffer
offset:(NSUInteger)offset
descriptor:(MPSNDArrayDescriptor* _Nonnull)descriptor;
-(MPSNDArray * __nullable) arrayViewWithShape:(MPSShape * _Nullable) shape
strides:(MPSShape * _Nonnull) strides;
- (MPSNDArray* __nullable)arrayViewWithShape:(MPSShape* _Nullable)shape strides:(MPSShape* _Nonnull)strides;
@end
typedef NS_ENUM(NSInteger, MTLMathMode)
{
typedef NS_ENUM(NSInteger, MTLMathMode) {
MTLMathModeSafe = 0,
MTLMathModeRelaxed = 1,
MTLMathModeFast = 2,
};
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions)
{
typedef NS_ENUM(NSInteger, MTLMathFloatingPointFunctions) {
MTLMathFloatingPointFunctionsFast = 0,
MTLMathFloatingPointFunctionsPrecise = 1,
};

View File

@ -2,11 +2,9 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
#if !defined(__MAC_14_0) && \
(!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
#if !defined(__MAC_14_0) && (!defined(MAC_OS_X_VERSION_14_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_14_0))
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
{
typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode) {
MPSGraphFFTScalingModeNone = 0L,
MPSGraphFFTScalingModeSize = 1L,
MPSGraphFFTScalingModeUnitary = 2L,
@ -22,12 +20,9 @@ typedef NS_ENUM(NSUInteger, MPSGraphFFTScalingMode)
@compatibility_alias MPSGraphFFTDescriptor FakeMPSGraphFFTDescriptor;
@interface MPSGraph (SonomaOps)
-(MPSGraphTensor * _Nonnull) conjugateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
-(MPSGraphTensor * _Nonnull) realPartOfTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)conjugateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)realPartOfTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)fastFourierTransformWithTensor:(MPSGraphTensor* _Nonnull)tensor
axes:(NSArray<NSNumber*>* _Nonnull)axes

View File

@ -2,8 +2,7 @@
#include <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
// TODO: Remove me when moved to MacOS 13
#if !defined(__MAC_13_2) && \
(!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
#if !defined(__MAC_13_2) && (!defined(MAC_OS_X_VERSION_13_2) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_2))
@interface FakeMPSGraphConvolution3DOpDescriptor : NSObject<NSCopying>
@ -35,11 +34,9 @@
@interface MPSGraph (VenturaOps)
#if !defined(__MAC_13_0) && \
(!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
#if !defined(__MAC_13_0) && (!defined(MAC_OS_X_VERSION_13_0) || (MAC_OS_X_VERSION_MIN_REQUIRED < MAC_OS_X_VERSION_13_0))
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
{
typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode) {
MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0L,
MPSGraphResizeNearestRoundingModeRoundPreferFloor = 1L,
MPSGraphResizeNearestRoundingModeCeil = 2L,
@ -59,16 +56,20 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
descriptor:(MPSGraphConvolution3DOpDescriptor* _Nonnull)descriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
- (MPSGraphTensor* _Nonnull)
convolution3DDataGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
weightsTensor:(MPSGraphTensor* _Nonnull)weights
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor * _Nonnull) incomingGradient
- (MPSGraphTensor* _Nonnull)
convolution3DWeightsGradientWithIncomingGradientTensor:(MPSGraphTensor* _Nonnull)incomingGradient
sourceTensor:(MPSGraphTensor* _Nonnull)source
outputShape:(MPSShape* _Nonnull)outputShape
forwardConvolutionDescriptor:(MPSGraphConvolution3DOpDescriptor * _Nonnull) forwardConvolutionDescriptor
forwardConvolutionDescriptor:
(MPSGraphConvolution3DOpDescriptor* _Nonnull)forwardConvolutionDescriptor
name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)cumulativeSumWithTensor:(MPSGraphTensor* _Nonnull)tensor
@ -111,8 +112,7 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
axisTensor:(MPSGraphTensor* _Nonnull)axisTensor
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull)inverseOfTensor:(MPSGraphTensor * _Nonnull) inputTensor
name:(NSString * _Nullable)name;
- (MPSGraphTensor* _Nonnull)inverseOfTensor:(MPSGraphTensor* _Nonnull)inputTensor name:(NSString* _Nullable)name;
- (MPSGraphTensor* _Nonnull)resizeNearestWithTensor:(MPSGraphTensor* _Nonnull)imagesTensor
sizeTensor:(MPSGraphTensor* _Nonnull)size
@ -191,7 +191,6 @@ typedef NS_ENUM(NSUInteger, MPSGraphResizeNearestRoundingMode)
nearestRoundingMode:(MPSGraphResizeNearestRoundingMode)nearestRoundingMode
constantValue:(double)constantValue
name:(NSString* _Nullable)name;
- (MPSGraphTensor * _Nonnull) truncateWithTensor:(MPSGraphTensor * _Nonnull) tensor
name:(NSString * _Nullable) name;
- (MPSGraphTensor* _Nonnull)truncateWithTensor:(MPSGraphTensor* _Nonnull)tensor name:(NSString* _Nullable)name;
@end

View File

@ -35,7 +35,9 @@ namespace at::native::mps {
void dispatch_sync_with_rethrow(dispatch_queue_t queue, void (^block)());
struct MPSScalar {
id<MTLBuffer> getMTLBuffer() const { return __builtin_bit_cast(id<MTLBuffer>, buffer.get()); }
id<MTLBuffer> getMTLBuffer() const {
return __builtin_bit_cast(id<MTLBuffer>, buffer.get());
}
size_t size = 0;
ScalarType type = ScalarType::Undefined;
@ -51,10 +53,7 @@ struct MPSScalar {
} value{};
};
void runMPSGraph(MPSStream* mpsStream,
MPSGraph* mpsGraph,
NSDictionary* feeds,
NSDictionary* results);
void runMPSGraph(MPSStream* mpsStream, MPSGraph* mpsGraph, NSDictionary* feeds, NSDictionary* results);
MPSDataType getMPSDataType(ScalarType scalar_type);
static inline MPSDataType getMPSDataType(const TensorBase& t) {
@ -82,9 +81,17 @@ std::string getArrayRefString(const IntArrayRef s);
Tensor gatherViewTensor(const Tensor& src, Tensor& dst);
Tensor& scatterViewTensor(const Tensor& src, Tensor& output);
bool canSliceViewTensor(const TensorBase& src, MPSShape* mpsShape);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src, MPSShape *mpsShape, const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor, const TensorBase& input, bool includesInt64 = false);
MPSGraphTensorData* getMPSGraphTensorDataForView(const TensorBase& src,
MPSShape* mpsShape,
const MPSDataType mpsDataType);
MPSGraphTensor* castToIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSGraphTensor* castFromIHFTypes(MPSGraph* mpsGraph,
MPSGraphTensor* inputTensor,
const TensorBase& input,
bool includesInt64 = false);
MPSNDArray* getMPSNDArray(const TensorBase& t, const IntArrayRef& sizes = {}, const IntArrayRef& strides = {});
MPSNDArray* getMPSNDArray(const TensorBase& t, MPSShape* sizes = nil, MPSShape* strides = nil);
@ -102,8 +109,12 @@ class Placeholder {
Placeholder() : _placeholder(nullptr), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor) : _placeholder(mpsGraphTensor), _value(nullptr), _tensor(Tensor()) {}
Placeholder(MPSGraphTensor* mpsGraphTensor, MPSNDArray* mpsNDArray);
Placeholder(MPSGraphTensor* mpsGraphTensor, const Tensor& self, MPSShape *mpsShape = nullptr,
bool gatherTensorData = true, MPSDataType dataType = MPSDataTypeInvalid, bool useMPSStridedAPI = true);
Placeholder(MPSGraphTensor* mpsGraphTensor,
const Tensor& self,
MPSShape* mpsShape = nullptr,
bool gatherTensorData = true,
MPSDataType dataType = MPSDataTypeInvalid,
bool useMPSStridedAPI = true);
MPSGraphTensor* getMPSGraphTensor() {
return _placeholder;
}
@ -145,8 +156,7 @@ using MPSCacheKey = uint64_t;
// derive this class to cache a graph and its inputs/outputs
// can be used to store any NSObject
struct MPSCachedGraph
{
struct MPSCachedGraph {
MPSCachedGraph(NSObject* object) : _object([object retain]) {}
virtual ~MPSCachedGraph() {
[_object release];
@ -158,21 +168,24 @@ struct MPSCachedGraph
return static_cast<T*>(this);
}
MPSGraph *graph() const { return (MPSGraph *)_object; }
NSObject *object() const { return _object; }
MPSGraph* graph() const {
return (MPSGraph*)_object;
}
NSObject* object() const {
return _object;
}
private:
NSObject* _object = nullptr;
};
struct MPSUnaryCachedGraph : public MPSCachedGraph
{
struct MPSUnaryCachedGraph : public MPSCachedGraph {
MPSUnaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
struct MPSUnaryGradCachedGraph : public MPSCachedGraph
{
struct MPSUnaryGradCachedGraph : public MPSCachedGraph {
MPSUnaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
@ -180,16 +193,14 @@ struct MPSUnaryGradCachedGraph : public MPSCachedGraph
MPSGraphTensor* gradInputTensor_ = nil;
};
struct MPSBinaryCachedGraph : public MPSCachedGraph
{
struct MPSBinaryCachedGraph : public MPSCachedGraph {
MPSBinaryCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* inputTensor_ = nil;
MPSGraphTensor* otherTensor_ = nil;
MPSGraphTensor* outputTensor_ = nil;
};
struct MPSBinaryGradCachedGraph : public MPSCachedGraph
{
struct MPSBinaryGradCachedGraph : public MPSCachedGraph {
MPSBinaryGradCachedGraph(MPSGraph* graph) : MPSCachedGraph(graph) {}
MPSGraphTensor* gradOutputTensor_ = nil;
MPSGraphTensor* inputTensor_ = nil;
@ -200,8 +211,7 @@ struct MPSBinaryGradCachedGraph : public MPSCachedGraph
// TODO: Improve the overall design of MPSGraphCache.
// https://github.com/pytorch/pytorch/issues/77176
// Cache holding various keys mapped to graphs
struct MPSGraphCache
{
struct MPSGraphCache {
typedef MPSCachedGraph* (^CreateCachedGraphBlock)();
struct CacheEntry {
@ -211,7 +221,6 @@ struct MPSGraphCache
};
public:
static MPSGraphCache* getInstance() {
if (_instance_cache == nullptr) {
_instance_cache = new MPSGraphCache();
@ -232,7 +241,6 @@ struct MPSGraphCache
void operator=(const MPSGraphCache&) = delete;
MPSCachedGraph* CreateCachedGraph(const std::string& key, CreateCachedGraphBlock createCacheBlock) {
__block MPSCachedGraph* cachedGraph = nil;
MPSCacheKey hash = std::hash<std::string>{}(key);
@ -259,13 +267,11 @@ struct MPSGraphCache
}
MPSCachedGraph* LookUp(const std::string& key) const {
__block MPSCachedGraph* cachedGraph = nullptr;
MPSCacheKey hash = std::hash<std::string>{}(key);
dispatch_sync(serialQueue_, ^() {
if (cache_.count(hash) != 0) {
auto& entry = cache_.at(hash);
TORCH_INTERNAL_ASSERT_DEBUG_ONLY(key == entry.key_, "Key collision in the MPS cached graph!\n");
@ -292,7 +298,6 @@ struct MPSGraphCache
static MPSGraphCache* _instance_cache;
std::unordered_map<MPSCacheKey, CacheEntry> cache_;
dispatch_queue_t serialQueue_ = nullptr;
};
// Common template for creating graph with a specified cache if missing
@ -319,7 +324,9 @@ MPSGraphTensor* log1p(MPSGraph* mpsGraph, MPSGraphTensor* inputTensor);
#define MPS_CHECK_INT64_OP_SUPPORTED(input_tensor, mac_os_13_3_plus, op_name) \
if (!mac_os_13_3_plus && input_tensor.scalar_type() == kLong) { \
TORCH_WARN_ONCE("MPS: no support for int64 for ", op_name, \
TORCH_WARN_ONCE( \
"MPS: no support for int64 for ", \
op_name, \
", downcasting to a smaller data type (int32/float32). Native support for int64 has been added in macOS 13.3."); \
}
@ -335,17 +342,19 @@ inline bool is_dense_in_storage(const TensorBase& t) {
return compute_storage_numel_distance(t) == static_cast<size_t>(t.numel());
}
class MetalShaderLibrary {
public:
MetalShaderLibrary(const std::string& src) : shaderSource(src), nparams(0), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_): shaderSource(src), nparams(nparams_), compile_options(nullptr){}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_): shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_)
: shaderSource(src), nparams(nparams_), compile_options(nullptr) {}
MetalShaderLibrary(const std::string& src, unsigned nparams_, MTLCompileOptions* compile_options_)
: shaderSource(src), nparams(nparams_), compile_options(compile_options_) {}
MetalShaderLibrary(const MetalShaderLibrary&) = delete;
inline id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname) {
return getLibraryPipelineState(getLibrary(), fname).first;
}
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname, const std::initializer_list<std::string>& params) {
id<MTLComputePipelineState> getPipelineStateForFunc(const std::string& fname,
const std::initializer_list<std::string>& params) {
return getLibraryPipelineState(getLibrary(params), fname).first;
}
inline id<MTLFunction> getMTLFunction(const std::string& fname) {
@ -355,12 +364,15 @@ public:
return getLibraryPipelineState(getLibrary(params), fname).second;
}
static MetalShaderLibrary& getBundledLibrary();
protected:
virtual id<MTLLibrary> getLibrary();
virtual id<MTLLibrary> getLibrary(const std::initializer_list<std::string>& params);
id<MTLLibrary> library = nil;
private:
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib, const std::string& fname);
std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getLibraryPipelineState(id<MTLLibrary> lib,
const std::string& fname);
id<MTLLibrary> compileLibrary(const std::string& src);
std::string shaderSource;
@ -371,21 +383,18 @@ private:
};
template <typename encoder_t,
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> || std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
typename = std::enable_if_t<std::is_same_v<id<MTLComputeCommandEncoder>, encoder_t> ||
std::is_same_v<id<MTLArgumentEncoder>, encoder_t>>>
static inline void mtl_setBuffer(encoder_t encoder, const TensorBase& t, unsigned idx) {
[encoder setBuffer:getMTLBufferStorage(t)
offset:t.storage_offset() * t.element_size()
atIndex:idx];
[encoder setBuffer:getMTLBufferStorage(t) offset:t.storage_offset() * t.element_size() atIndex:idx];
}
template<typename T,
typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
template <typename T, typename = std::enable_if_t<std::is_integral_v<T> || std::is_same_v<T, float>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const T val, unsigned idx) {
[encoder setBytes:&val length:sizeof(T) atIndex:idx];
}
template<typename Container,
typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
template <typename Container, typename = std::enable_if_t<std::is_integral_v<typename Container::size_type>>>
static inline void mtl_setBytes(id<MTLComputeCommandEncoder> encoder, const Container& values, unsigned idx) {
[encoder setBytes:values.data() length:sizeof(typename Container::value_type) * values.size() atIndex:idx];
}
@ -400,7 +409,9 @@ static inline void mtl_dispatch1DJob(id<MTLComputeCommandEncoder> encoder,
[encoder dispatchThreads:size threadsPerThreadgroup:threadGroupSize];
}
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder, const TensorIteratorBase& iter, bool use_64bit_index = false);
id<MTLBuffer> generateKernelDataOffsets(id<MTLComputeCommandEncoder> commandEncoder,
const TensorIteratorBase& iter,
bool use_64bit_index = false);
inline NSDictionary* dictionaryFromPlaceholders(Placeholder& p1) {
return @{p1.getMPSGraphTensor() : p1.getMPSGraphTensorData()};

View File

@ -2,10 +2,11 @@
#define AT_DISPATCH_MPS_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
TYPE, \
NAME, \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) AT_DISPATCH_CASE( \
at::ScalarType::Half, \
__VA_ARGS__) AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \

View File

@ -1,5 +1,8 @@
#pragma once
namespace at::native::mps {
void complex_mul_out(const Tensor& input, const Tensor& other, const Tensor& output);
void complex_mul_out(
const Tensor& input,
const Tensor& other,
const Tensor& output);
}

View File

@ -17,8 +17,7 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adam_amsgrad_mps_impl_(
TensorList params,
@ -34,7 +33,6 @@ void _fused_adam_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<at::Tensor>& grad_scale,
const std::optional<at::Tensor>& found_inf
);
const std::optional<at::Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -16,8 +16,7 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adam_mps_impl_(
TensorList params,
@ -32,6 +31,5 @@ void _fused_adam_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -17,8 +17,7 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adamw_amsgrad_mps_impl_(
TensorList params,
@ -34,6 +33,5 @@ void _fused_adamw_amsgrad_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -16,8 +16,7 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
void _fused_adamw_mps_impl_(
TensorList params,
@ -32,7 +31,6 @@ void _fused_adamw_mps_impl_(
const double eps,
const bool maximize,
const std::optional<Tensor>& grad_scale,
const std::optional<Tensor>& found_inf
);
const std::optional<Tensor>& found_inf);
} // namespace at::native::mps

View File

@ -436,9 +436,11 @@ REGISTER_FUSED_SGD_MOMENTUM_OP(half);
)METAL";
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(const std::string& fname) {
static std::pair<id<MTLComputePipelineState>, id<MTLFunction>> getCPLState(
const std::string& fname) {
static MetalShaderLibrary lib(FUSED_ADAM_OPS, 0);
return std::make_pair(lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
return std::make_pair(
lib.getPipelineStateForFunc(fname), lib.getMTLFunction(fname));
}
} // namespace mps

View File

@ -16,8 +16,7 @@ struct MetadataArguments { // the size of this struct must be less than 4 bytes
};
struct FusedAdamEncodingFunctor {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double lr,
@ -25,9 +24,7 @@ struct FusedAdamEncodingFunctor {
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {
const bool maximize) const {
float lr_lv = lr;
float beta1_lv = beta1;
float beta2_lv = beta2;
@ -35,12 +32,8 @@ struct FusedAdamEncodingFunctor {
float eps_lv = eps;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, lr_lv, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -49,8 +42,7 @@ struct FusedAdamEncodingFunctor {
mtl_setBytes(computeEncoder, maximize_lv, 7);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const at::Tensor& lr,
@ -58,20 +50,15 @@ struct FusedAdamEncodingFunctor {
const double beta2,
const double weight_decay,
const double eps,
const bool maximize
) const {
const bool maximize) const {
float beta1_lv = beta1;
float beta2_lv = beta2;
float weight_decay_lv = weight_decay;
float eps_lv = eps;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBuffer(computeEncoder, lr, 2);
mtl_setBytes(computeEncoder, beta1_lv, 3);
mtl_setBytes(computeEncoder, beta2_lv, 4);
@ -86,8 +73,7 @@ struct FusedSgdEncodingFunctor {};
template <>
struct FusedSgdEncodingFunctor<true> {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
@ -96,8 +82,7 @@ struct FusedSgdEncodingFunctor<true> {
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step
) const {
const bool is_first_step) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float lr_lv = lr;
@ -106,12 +91,8 @@ struct FusedSgdEncodingFunctor<true> {
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBytes(computeEncoder, lr_lv, 4);
@ -121,8 +102,7 @@ struct FusedSgdEncodingFunctor<true> {
mtl_setBytes(computeEncoder, is_first_step_lv, 8);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
@ -131,8 +111,7 @@ struct FusedSgdEncodingFunctor<true> {
const double dampening,
const bool nesterov,
const bool maximize,
const bool is_first_step
) const {
const bool is_first_step) const {
float weight_decay_lv = weight_decay;
float momentum_lv = momentum;
float dampening_lv = dampening;
@ -140,12 +119,8 @@ struct FusedSgdEncodingFunctor<true> {
uint8_t maximize_lv = maximize;
uint8_t is_first_step_lv = is_first_step;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, momentum_lv, 3);
mtl_setBuffer(computeEncoder, lr, 4);
@ -158,46 +133,34 @@ struct FusedSgdEncodingFunctor<true> {
template <>
struct FusedSgdEncodingFunctor<false> {
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const double lr,
const bool maximize
) const {
const bool maximize) const {
float weight_decay_lv = weight_decay;
float lr_lv = lr;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBytes(computeEncoder, lr_lv, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
}
void operator()(
id<MTLComputeCommandEncoder>& computeEncoder,
void operator()(id<MTLComputeCommandEncoder>& computeEncoder,
id<MTLBuffer>& tensorArgumentBuffer,
const MetadataArguments& metadata_arguments,
const double weight_decay,
const at::Tensor& lr,
const bool maximize
) const {
const bool maximize) const {
float weight_decay_lv = weight_decay;
uint8_t maximize_lv = maximize;
[computeEncoder setBuffer:tensorArgumentBuffer
offset:0
atIndex:0];
[computeEncoder setBytes:&metadata_arguments
length:sizeof(MetadataArguments)
atIndex:1];
[computeEncoder setBuffer:tensorArgumentBuffer offset:0 atIndex:0];
[computeEncoder setBytes:&metadata_arguments length:sizeof(MetadataArguments) atIndex:1];
mtl_setBytes(computeEncoder, weight_decay_lv, 2);
mtl_setBuffer(computeEncoder, lr, 3);
mtl_setBytes(computeEncoder, maximize_lv, 4);
@ -205,25 +168,22 @@ struct FusedSgdEncodingFunctor<false> {
};
template <int depth, uint32_t kThreadGroupSize, typename encoder_func_t, typename... ArgTypes>
static void multi_tensor_apply_for_fused_optimizer(
const std::string& kernel_name,
static void multi_tensor_apply_for_fused_optimizer(const std::string& kernel_name,
std::vector<std::vector<at::Tensor>>& tensor_lists,
at::TensorList state_steps,
encoder_func_t encode,
ArgTypes... args
) {
ArgTypes... args) {
const auto num_tensors = tensor_lists[0].size();
if (num_tensors == 0) {
return;
}
TORCH_CHECK(
tensor_lists.size() == depth,
"Number of tensor lists has to match the depth");
TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
for (const auto& d : c10::irange(depth)) {
TORCH_CHECK(
tensor_lists[d][0].scalar_type() == at::ScalarType::Float || tensor_lists[d][0].scalar_type() == at::ScalarType::Half, "Only float and half are supported");
TORCH_CHECK(tensor_lists[d][0].scalar_type() == at::ScalarType::Float ||
tensor_lists[d][0].scalar_type() == at::ScalarType::Half,
"Only float and half are supported");
}
id<MTLDevice> device = MPSDevice::getInstance()->device();
@ -251,7 +211,8 @@ static void multi_tensor_apply_for_fused_optimizer(
// BufferIndex is the index in the kernel function
auto tensorArgumentEncoder = [[fusedOptimizerFunc newArgumentEncoderWithBufferIndex:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
id<MTLBuffer> tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
int64_t tensor_loc = 0;
@ -266,7 +227,8 @@ static void multi_tensor_apply_for_fused_optimizer(
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors + tensor_loc);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageRead | MTLResourceUsageWrite];
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageRead | MTLResourceUsageWrite];
}
if (state_steps.size() > 0) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors + tensor_loc);
@ -302,18 +264,21 @@ static void multi_tensor_apply_for_fused_optimizer(
if (chunk == chunks - 1) {
// last chunk
tensor_loc = 0;
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
} else {
// reuse the current tensor since the current one isn't done.
metadata_arguments.numels[0] = metadata_arguments.numels[tensor_loc - 1];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength options:0] autorelease];
tensorArgumentBuffer = [[device newBufferWithLength:tensorArgumentEncoder.encodedLength
options:0] autorelease];
[tensorArgumentEncoder setArgumentBuffer:tensorArgumentBuffer offset:0];
for (const auto& d : c10::irange(depth)) {
mtl_setBuffer(tensorArgumentEncoder, tensor_lists[d][tensor_index], d * kmaxTensors);
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index]) usage:MTLResourceUsageWrite | MTLResourceUsageRead];
[computeEncoder useResource:getMTLBufferStorage(tensor_lists[d][tensor_index])
usage:MTLResourceUsageWrite | MTLResourceUsageRead];
}
if (state_steps.size() > 0) {
mtl_setBuffer(tensorArgumentEncoder, state_steps[tensor_index], depth * kmaxTensors);
@ -334,7 +299,6 @@ static void multi_tensor_apply_for_fused_optimizer(
}
getMPSProfiler().endProfileKernel(fusedOptimizerPSO);
}
});
}