Skip to content

Commit 29625e8

Browse files
Metal backend: split et_metal_ops.mm (#17353)
Split et_metal_ops.mm into multiple files under runtime/ops.
1 parent 429925d commit 29625e8

File tree

10 files changed

+3355
-3306
lines changed

10 files changed

+3355
-3306
lines changed

backends/apple/metal/CMakeLists.txt

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,15 @@ set(_aoti_metal_sources
3939
runtime/stats.cpp
4040
runtime/shims/memory.cpp
4141
runtime/shims/et_metal.mm
42-
runtime/shims/et_metal_ops.mm
4342
runtime/shims/shim_mps.mm
4443
runtime/shims/tensor_attribute.cpp
4544
runtime/shims/utils.cpp
45+
runtime/ops/common.mm
46+
runtime/ops/op_bmm.mm
47+
runtime/ops/op_convolution.mm
48+
runtime/ops/op_linear_4bit.mm
49+
runtime/ops/op_mm.mm
50+
runtime/ops/op_sdpa.mm
4651
)
4752

4853
add_library(metal_backend STATIC ${_aoti_metal_sources})
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#pragma once
10+
11+
#ifdef __OBJC__
12+
#import <Foundation/Foundation.h>
13+
#import <Metal/Metal.h>
14+
#import <MetalPerformanceShaders/MetalPerformanceShaders.h>
15+
#import <MetalPerformanceShadersGraph/MetalPerformanceShadersGraph.h>
16+
typedef id<MTLBuffer> MTLBuffer_t;
17+
typedef MPSGraph* MPSGraph_t;
18+
typedef MPSGraphTensor* MPSGraphTensor_t;
19+
typedef void (^dispatch_block_t)();
20+
#else
21+
typedef void* MTLBuffer_t;
22+
typedef void* MPSGraph_t;
23+
typedef void* MPSGraphTensor_t;
24+
typedef void* dispatch_block_t;
25+
#endif
26+
27+
#include <executorch/backends/apple/metal/runtime/shims/et_metal.h>
28+
#include <executorch/backends/apple/metal/runtime/shims/memory.h>
29+
#include <executorch/backends/apple/metal/runtime/shims/shim_mps.h>
30+
#include <executorch/backends/apple/metal/runtime/shims/utils.h>
31+
#include <executorch/runtime/core/exec_aten/exec_aten.h>
32+
#include <executorch/runtime/platform/log.h>
33+
#include <functional>
34+
#include <memory>
35+
#include <unordered_map>
36+
37+
namespace executorch {
38+
namespace backends {
39+
namespace metal {
40+
41+
using executorch::runtime::etensor::Tensor;
42+
43+
void dispatch_sync_with_rethrow(dispatch_queue_t queue, dispatch_block_t block);
44+
45+
extern std::unordered_map<void*, MTLBuffer_t> ptr_to_mtl_buffer;
46+
47+
struct GraphCacheKey {
48+
std::string op_name;
49+
std::vector<int64_t> shape_params;
50+
int32_t dtype;
51+
bool transpose_flag;
52+
53+
bool operator==(const GraphCacheKey& other) const {
54+
return op_name == other.op_name && shape_params == other.shape_params &&
55+
dtype == other.dtype && transpose_flag == other.transpose_flag;
56+
}
57+
};
58+
59+
struct GraphCacheKeyHash {
60+
std::size_t operator()(const GraphCacheKey& key) const {
61+
std::size_t hash = std::hash<std::string>{}(key.op_name);
62+
for (auto val : key.shape_params) {
63+
hash ^=
64+
std::hash<int64_t>{}(val) + 0x9e3779b9 + (hash << 6) + (hash >> 2);
65+
}
66+
hash ^= std::hash<int32_t>{}(key.dtype) + 0x9e3779b9 + (hash << 6) +
67+
(hash >> 2);
68+
hash ^= std::hash<bool>{}(key.transpose_flag) + 0x9e3779b9 + (hash << 6) +
69+
(hash >> 2);
70+
return hash;
71+
}
72+
};
73+
74+
struct CachedGraph {
75+
MPSGraph_t graph;
76+
MPSGraphTensor_t input1;
77+
MPSGraphTensor_t input2;
78+
MPSGraphTensor_t input3;
79+
MPSGraphTensor_t output;
80+
};
81+
82+
struct CacheStats {
83+
size_t hits = 0;
84+
size_t misses = 0;
85+
86+
void logStats() {
87+
if ((hits + misses) % 100 == 0 && (hits + misses) > 0) {
88+
double hit_rate = 100.0 * hits / (hits + misses);
89+
ET_LOG(
90+
Debug,
91+
"MPSGraph cache stats: %zu hits, %zu misses (%.1f%% hit rate)",
92+
hits,
93+
misses,
94+
hit_rate);
95+
}
96+
}
97+
};
98+
99+
extern std::unordered_map<GraphCacheKey, CachedGraph, GraphCacheKeyHash>
100+
graph_cache;
101+
extern CacheStats cache_stats;
102+
103+
MTLBuffer_t
104+
get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name);
105+
MTLBuffer_t allocate_mtl_buffer(void** data_ptr, size_t size_bytes);
106+
107+
} // namespace metal
108+
} // namespace backends
109+
} // namespace executorch
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
#include <executorch/backends/apple/metal/runtime/ops/common.h>
10+
11+
namespace executorch {
12+
namespace backends {
13+
namespace metal {
14+
15+
std::unordered_map<GraphCacheKey, CachedGraph, GraphCacheKeyHash> graph_cache;
16+
CacheStats cache_stats;
17+
18+
id<MTLBuffer> get_mtl_buffer(Tensor* tensor, const char* op_name, const char* tensor_name) {
19+
void* data_ptr = tensor->mutable_data_ptr();
20+
auto it = ptr_to_mtl_buffer.find(data_ptr);
21+
if (it == ptr_to_mtl_buffer.end()) {
22+
ET_LOG(Error, "%s: %s tensor not found in Metal buffer mapping", op_name, tensor_name);
23+
throw std::runtime_error(std::string(tensor_name) + " tensor not found in Metal buffer mapping");
24+
}
25+
return it->second;
26+
}
27+
28+
id<MTLBuffer> allocate_mtl_buffer(void** data_ptr, size_t size_bytes) {
29+
AOTITorchError malloc_err = aoti_torch_mps_malloc(data_ptr, size_bytes);
30+
if (malloc_err != Error::Ok) {
31+
ET_LOG(Error, "allocate_and_register_mtl_buffer: Failed to allocate Metal buffer via aoti_torch_mps_malloc");
32+
throw std::runtime_error("Failed to allocate output Metal buffer");
33+
}
34+
35+
auto it = ptr_to_mtl_buffer.find(*data_ptr);
36+
if (it == ptr_to_mtl_buffer.end()) {
37+
ET_LOG(Error, "allocate_and_register_mtl_buffer: aoti_torch_mps_malloc did not register buffer in map");
38+
throw std::runtime_error("Failed to look up allocated Metal buffer");
39+
}
40+
return it->second;
41+
}
42+
43+
} // namespace metal
44+
} // namespace backends
45+
} // namespace executorch

0 commit comments

Comments
 (0)