Skip to content

Commit e07b1ad

Browse files
committed
[Metal Shader Converter] Use explicit layout via IRRootSignature
1 parent 112914d commit e07b1ad

File tree

8 files changed

+181
-54
lines changed

8 files changed

+181
-54
lines changed

src/FlyCube/BindingSet/MTBindingSet.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,6 @@ class MTBindingSet : public BindingSetBase {
2525
MTDevice& device_;
2626
std::set<BindKey> bindless_bind_keys_;
2727
std::map<BindKey, std::shared_ptr<View>> direct_bindings_;
28-
std::map<std::pair<Shader*, ShaderType>, id<MTLBuffer>> argument_buffers_;
28+
id<MTLBuffer> argument_buffer_;
29+
std::map<uint32_t, id<MTLBuffer>> bindings_by_space_;
2930
};

src/FlyCube/BindingSet/MTBindingSet.mm

Lines changed: 58 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,63 @@
2222
}
2323
}
2424

25+
#if defined(USE_METAL_SHADER_CONVERTER)
26+
uint32_t spaces = 0;
27+
std::map<uint32_t, uint32_t> slots;
28+
for (const auto& bind_key : layout->GetBindKeys()) {
29+
spaces = std::max(spaces, bind_key.space + 1);
30+
if (bind_key.count != kBindlessCount) {
31+
slots[bind_key.space] = std::max(slots[bind_key.space], bind_key.slot + 1);
32+
}
33+
}
34+
for (const auto& [bind_key, _] : layout->GetConstants()) {
35+
DCHECK(bind_key.count == 1);
36+
spaces = std::max(spaces, bind_key.space + 1);
37+
slots[bind_key.space] = std::max(slots[bind_key.space], bind_key.slot + 1);
38+
}
39+
40+
if (spaces > 0) {
41+
argument_buffer_ = [device_.GetDevice() newBufferWithLength:spaces * sizeof(uint64_t)
42+
options:MTLResourceStorageModeShared];
43+
uint64_t* argument_buffer_data = reinterpret_cast<uint64_t*>(argument_buffer_.contents);
44+
for (size_t i = 0; i < spaces; ++i) {
45+
if (slots[i] == 0) {
46+
continue;
47+
}
48+
bindings_by_space_[i] = [device_.GetDevice() newBufferWithLength:slots[i] * sizeof(IRDescriptorTableEntry)
49+
options:MTLResourceStorageModeShared];
50+
argument_buffer_data[i] = bindings_by_space_[i].gpuAddress;
51+
}
52+
53+
id<MTLBuffer> buffer = device_.GetBindlessArgumentBuffer().GetArgumentBuffer();
54+
for (const auto& bind_key : bindless_bind_keys_) {
55+
argument_buffer_data[bind_key.space] = buffer.gpuAddress;
56+
}
57+
}
58+
#endif
59+
2560
CreateConstantsFallbackBuffer(device_, layout->GetConstants());
61+
std::vector<BindingDesc> fallback_constants_bindings;
62+
fallback_constants_bindings.reserve(fallback_constants_buffer_views_.size());
2663
for (const auto& [bind_key, view] : fallback_constants_buffer_views_) {
27-
direct_bindings_.insert_or_assign(bind_key, view);
64+
fallback_constants_bindings.emplace_back(bind_key, view);
2865
}
66+
WriteBindings({ .bindings = fallback_constants_bindings });
2967
}
3068

3169
void MTBindingSet::WriteBindings(const WriteBindingsDesc& desc)
3270
{
71+
#if defined(USE_METAL_SHADER_CONVERTER)
3372
for (const auto& [bind_key, view] : desc.bindings) {
34-
assert(bind_key.count != kBindlessCount);
73+
IRDescriptorTableEntry* entries =
74+
static_cast<IRDescriptorTableEntry*>(bindings_by_space_[bind_key.space].contents);
75+
auto* mt_view = CastToImpl<MTView>(view);
76+
mt_view->BindView(&entries[bind_key.slot]);
77+
}
78+
#endif
79+
80+
for (const auto& [bind_key, view] : desc.bindings) {
81+
DCHECK(bind_key.count != kBindlessCount);
3582
direct_bindings_.insert_or_assign(bind_key, view);
3683
}
3784
for (const auto& [bind_key, data] : desc.constants) {
@@ -45,36 +92,22 @@
4592
id<MTLResidencySet> residency_set)
4693
{
4794
#if defined(USE_METAL_SHADER_CONVERTER)
48-
std::map<std::pair<Shader*, ShaderType>, uint64_t> argument_buffers_size;
49-
for (const auto& [bind_key, view] : direct_bindings_) {
50-
decltype(auto) shader = CastToImpl<MTPipeline>(pipeline)->GetShader(bind_key.shader_type);
51-
uint32_t offset = CastToImpl<MTShader>(shader)->GetBindingOffset({ bind_key.slot, bind_key.space });
52-
argument_buffers_size[{ shader.get(), bind_key.shader_type }] = std::max<uint64_t>(
53-
argument_buffers_size[{ shader.get(), bind_key.shader_type }], offset + 3 * sizeof(uint64_t));
54-
}
55-
for (const auto& [shader_type, size] : argument_buffers_size) {
56-
if (!argument_buffers_[shader_type]) {
57-
argument_buffers_[shader_type] = [device_.GetDevice() newBufferWithLength:size
58-
options:MTLResourceStorageModeShared];
59-
}
60-
}
6195
for (const auto& [bind_key, view] : direct_bindings_) {
6296
auto* mt_view = CastToImpl<MTView>(view);
63-
DCHECK(mt_view);
64-
decltype(auto) shader = CastToImpl<MTPipeline>(pipeline)->GetShader(bind_key.shader_type);
65-
uint32_t offset = CastToImpl<MTShader>(shader)->GetBindingOffset({ bind_key.slot, bind_key.space });
66-
auto* ptr = static_cast<uint8_t*>(argument_buffers_[{ shader.get(), bind_key.shader_type }].contents);
67-
IRDescriptorTableEntry* entry = reinterpret_cast<IRDescriptorTableEntry*>(ptr + offset);
68-
mt_view->BindView(entry);
6997
id<MTLResource> allocation = mt_view->GetAllocation();
7098
if (allocation) {
7199
[residency_set addAllocation:allocation];
72100
}
101+
102+
decltype(auto) argument_table = argument_tables.at(bind_key.shader_type);
103+
[argument_table setAddress:argument_buffer_.gpuAddress atIndex:kIRArgumentBufferBindPoint];
104+
[residency_set addAllocation:argument_buffer_];
105+
[residency_set addAllocation:bindings_by_space_[bind_key.space]];
73106
}
74-
for (const auto& [shader_type, size] : argument_buffers_size) {
75-
decltype(auto) argument_table = argument_tables.at(shader_type.second);
76-
[argument_table setAddress:argument_buffers_[shader_type].gpuAddress atIndex:kIRArgumentBufferBindPoint];
77-
[residency_set addAllocation:argument_buffers_[shader_type]];
107+
108+
if (!bindless_bind_keys_.empty()) {
109+
id<MTLBuffer> buffer = device_.GetBindlessArgumentBuffer().GetArgumentBuffer();
110+
[residency_set addAllocation:buffer];
78111
}
79112
#else
80113
for (const auto& [bind_key, view] : direct_bindings_) {

src/FlyCube/BindlessTypedViewPool/MTBindlessTypedViewPool.mm

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,13 @@
55
#include "Utilities/Check.h"
66
#include "View/MTView.h"
77

8+
#if defined(USE_METAL_SHADER_CONVERTER)
9+
#include <metal_irconverter_runtime.h>
10+
using EntryType = IRDescriptorTableEntry;
11+
#else
12+
using EntryType = uint64_t;
13+
#endif
14+
815
MTBindlessTypedViewPool::MTBindlessTypedViewPool(MTDevice& device, ViewType view_type, uint32_t view_count)
916
: view_count_(view_count)
1017
{
@@ -31,8 +38,12 @@
3138
{
3239
DCHECK(index < view_count_);
3340
auto* mt_view = CastToImpl<MTView>(view);
34-
uint64_t* arguments = static_cast<uint64_t*>(range_->GetArgumentBuffer().contents);
41+
EntryType* arguments = static_cast<EntryType*>(range_->GetArgumentBuffer().contents);
3542
const uint32_t offset = range_->GetOffset() + index;
43+
#if defined(USE_METAL_SHADER_CONVERTER)
44+
mt_view->BindView(&arguments[offset]);
45+
#else
3646
arguments[offset] = mt_view->GetGpuAddress();
47+
#endif
3748
range_->AddAllocation(offset, mt_view->GetAllocation());
3849
}

src/FlyCube/GPUDescriptorPool/MTGPUBindlessArgumentBuffer.mm

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,13 @@
22

33
#include "Device/MTDevice.h"
44

5+
#if defined(USE_METAL_SHADER_CONVERTER)
6+
#include <metal_irconverter_runtime.h>
7+
using EntryType = IRDescriptorTableEntry;
8+
#else
9+
using EntryType = uint64_t;
10+
#endif
11+
512
MTGPUBindlessArgumentBuffer::MTGPUBindlessArgumentBuffer(MTDevice& device)
613
: device_(device)
714
, residency_set_(device.CreateResidencySet())
@@ -14,10 +21,10 @@
1421
return;
1522
}
1623

17-
id<MTLBuffer> buffer = [device_.GetDevice() newBufferWithLength:req_size * sizeof(uint64_t)
24+
id<MTLBuffer> buffer = [device_.GetDevice() newBufferWithLength:req_size * sizeof(EntryType)
1825
options:MTLResourceStorageModeShared];
1926
if (size_ && buffer_) {
20-
memcpy(buffer.contents, buffer_.contents, size_ * sizeof(uint64_t));
27+
memcpy(buffer.contents, buffer_.contents, size_ * sizeof(EntryType));
2128
}
2229

2330
size_ = req_size;

src/FlyCube/HLSLCompiler/MetalShaderConverter.cpp

Lines changed: 95 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#include "HLSLCompiler/MetalShaderConverter.h"
22

3+
#include "ShaderReflection/ShaderReflection.h"
34
#include "Utilities/Logging.h"
45
#include "Utilities/NotReached.h"
56

@@ -27,12 +28,88 @@ IRShaderStage GetShaderStage(ShaderType type)
2728
}
2829
}
2930

31+
// Translates a ViewType into the IRDescriptorRangeType used when describing
// descriptor ranges of a root signature.
//
//   texture / buffer / structured buffer / byte-address buffer /
//   acceleration structure             -> IRDescriptorRangeTypeSRV
//   RW texture / RW buffer / RW structured buffer / RW byte-address buffer
//                                      -> IRDescriptorRangeTypeUAV
//   constant buffer                    -> IRDescriptorRangeTypeCBV
//   sampler                            -> IRDescriptorRangeTypeSampler
//
// Any other value reaches NOTREACHED().
IRDescriptorRangeType GetRangeType(ViewType view_type)
{
    const bool maps_to_srv = view_type == ViewType::kTexture || view_type == ViewType::kBuffer ||
                             view_type == ViewType::kStructuredBuffer ||
                             view_type == ViewType::kByteAddressBuffer ||
                             view_type == ViewType::kAccelerationStructure;
    if (maps_to_srv) {
        return IRDescriptorRangeTypeSRV;
    }

    const bool maps_to_uav = view_type == ViewType::kRWTexture || view_type == ViewType::kRWBuffer ||
                             view_type == ViewType::kRWStructuredBuffer ||
                             view_type == ViewType::kRWByteAddressBuffer;
    if (maps_to_uav) {
        return IRDescriptorRangeTypeUAV;
    }

    if (view_type == ViewType::kConstantBuffer) {
        return IRDescriptorRangeTypeCBV;
    }

    if (view_type == ViewType::kSampler) {
        return IRDescriptorRangeTypeSampler;
    }

    // No descriptor range type is defined for the remaining view types.
    NOTREACHED();
}
53+
54+
IRRootSignature* CreateIRRootSignature(const std::vector<BindKey>& bind_keys)
55+
{
56+
uint32_t spaces = 0;
57+
for (const auto& bind_key : bind_keys) {
58+
spaces = std::max(spaces, bind_key.space + 1);
59+
}
60+
61+
std::vector<std::vector<IRDescriptorRange1>> descriptor_table_ranges(spaces);
62+
for (const auto& bind_key : bind_keys) {
63+
auto& range = descriptor_table_ranges[bind_key.space].emplace_back();
64+
range.RangeType = GetRangeType(bind_key.view_type);
65+
range.NumDescriptors = bind_key.count;
66+
range.BaseShaderRegister = bind_key.slot;
67+
range.RegisterSpace = bind_key.space;
68+
range.Flags = IRDescriptorRangeFlagNone;
69+
range.OffsetInDescriptorsFromTableStart = bind_key.slot;
70+
}
71+
72+
std::vector<IRRootParameter1> root_parameters;
73+
auto add_root_table = [&](size_t range_count, IRDescriptorRange1* ranges) {
74+
IRRootDescriptorTable1 descriptor_table = {};
75+
descriptor_table.NumDescriptorRanges = range_count;
76+
descriptor_table.pDescriptorRanges = ranges;
77+
78+
IRRootParameter1& root_parameter = root_parameters.emplace_back();
79+
root_parameter.ParameterType = IRRootParameterTypeDescriptorTable;
80+
root_parameter.DescriptorTable = descriptor_table;
81+
root_parameter.ShaderVisibility = IRShaderVisibilityAll;
82+
};
83+
84+
for (auto& descriptor_table_range : descriptor_table_ranges) {
85+
add_root_table(descriptor_table_range.size(), descriptor_table_range.data());
86+
}
87+
88+
IRRootSignatureFlags root_signature_flags = static_cast<IRRootSignatureFlags>(
89+
IRRootSignatureFlagAllowInputAssemblerInputLayout | IRRootSignatureFlagDenyHullShaderRootAccess |
90+
IRRootSignatureFlagDenyDomainShaderRootAccess);
91+
92+
IRVersionedRootSignatureDescriptor root_signature_desc = {};
93+
root_signature_desc.version = IRRootSignatureVersion_1_1;
94+
root_signature_desc.desc_1_1.Flags = root_signature_flags;
95+
root_signature_desc.desc_1_1.NumParameters = root_parameters.size();
96+
root_signature_desc.desc_1_1.pParameters = root_parameters.data();
97+
98+
IRError* error = nullptr;
99+
IRRootSignature* root_signature = IRRootSignatureCreateFromDescriptor(&root_signature_desc, &error);
100+
if (!root_signature) {
101+
Logging::Println("IRRootSignatureCreateFromDescriptor failed: {}", IRErrorGetCode(error));
102+
IRErrorDestroy(error);
103+
}
104+
105+
return root_signature;
106+
}
107+
30108
} // namespace
31109

32110
std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
33111
const std::vector<uint8_t>& blob,
34-
std::string& entry_point,
35-
std::map<std::pair<uint32_t, uint32_t>, uint32_t>& binding_offsets)
112+
std::string& entry_point)
36113
{
37114
IRCompiler* compiler = IRCompilerCreate();
38115
IRObject* dxil_obj = IRObjectCreateFromDXIL(blob.data(), blob.size(), IRBytecodeOwnershipNone);
@@ -41,6 +118,21 @@ std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
41118
IRCompilerSetStageInGenerationMode(compiler, IRStageInCodeGenerationModeUseMetalVertexFetch);
42119
}
43120

121+
std::vector<BindKey> bind_keys;
122+
auto dxil_reflection = CreateShaderReflection(ShaderBlobType::kDXIL, blob.data(), blob.size());
123+
for (const auto& binding : dxil_reflection->GetBindings()) {
124+
BindKey bind_key = {
125+
.shader_type = shader_type,
126+
.view_type = binding.type,
127+
.slot = binding.slot,
128+
.space = binding.space,
129+
.count = binding.count,
130+
};
131+
bind_keys.push_back(bind_key);
132+
}
133+
IRRootSignature* root_signature = CreateIRRootSignature(bind_keys);
134+
IRCompilerSetGlobalRootSignature(compiler, root_signature);
135+
44136
IRError* error = nullptr;
45137
IRObject* metal_ir = IRCompilerAllocCompileAndLink(compiler, nullptr, dxil_obj, &error);
46138
if (!metal_ir) {
@@ -62,16 +154,9 @@ std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
62154
IRObjectGetReflection(metal_ir, GetShaderStage(shader_type), reflection);
63155
entry_point = IRShaderReflectionGetEntryPointFunctionName(reflection);
64156

65-
size_t resource_count = IRShaderReflectionGetResourceCount(reflection);
66-
std::vector<IRResourceLocation> resources(resource_count);
67-
IRShaderReflectionGetResourceLocations(reflection, resources.data());
68-
for (const auto& resource : resources) {
69-
binding_offsets[{ resource.slot, resource.space }] = resource.topLevelOffset;
70-
assert(resource.sizeBytes == 3 * sizeof(uint64_t));
71-
}
72-
73157
IRShaderReflectionDestroy(reflection);
74158
IRMetalLibBinaryDestroy(metal_lib);
159+
IRRootSignatureDestroy(root_signature);
75160
IRObjectDestroy(metal_ir);
76161
IRObjectDestroy(dxil_obj);
77162
IRCompilerDestroy(compiler);

src/FlyCube/HLSLCompiler/MetalShaderConverter.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,4 @@
33

44
std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
55
const std::vector<uint8_t>& blob,
6-
std::string& entry_point,
7-
std::map<std::pair<uint32_t, uint32_t>, uint32_t>& binding_offsets);
6+
std::string& entry_point);

src/FlyCube/Shader/MTShader.h

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,15 @@ class MTShader : public ShaderBase {
1313
public:
1414
MTShader(MTDevice& device, const std::vector<uint8_t>& blob, ShaderBlobType blob_type, ShaderType shader_type);
1515

16-
#if defined(USE_METAL_SHADER_CONVERTER)
17-
uint32_t GetBindingOffset(const std::pair<uint32_t, uint32_t>& slot_space) const;
18-
#else
16+
#if !defined(USE_METAL_SHADER_CONVERTER)
1917
uint32_t GetIndex(BindKey bind_key) const;
2018
#endif
2119

2220
MTL4LibraryFunctionDescriptor* GetFunctionDescriptor();
2321

2422
private:
2523
MTL4LibraryFunctionDescriptor* function_descriptor_ = nullptr;
26-
#if defined(USE_METAL_SHADER_CONVERTER)
27-
std::map<std::pair<uint32_t, uint32_t>, uint32_t> binding_offsets_;
28-
#else
24+
#if !defined(USE_METAL_SHADER_CONVERTER)
2925
std::map<BindKey, uint32_t> slot_remapping_;
3026
#endif
3127
};

src/FlyCube/Shader/MTShader.mm

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
{
1515
#if defined(USE_METAL_SHADER_CONVERTER)
1616
std::string entry_point;
17-
auto metal_lib_bytecode = ConvertToMetalLibBytecode(shader_type, blob, entry_point, binding_offsets_);
17+
auto metal_lib_bytecode = ConvertToMetalLibBytecode(shader_type, blob, entry_point);
1818
dispatch_data_t metal_lib_data = dispatch_data_create(metal_lib_bytecode.data(), metal_lib_bytecode.size(), nullptr,
1919
DISPATCH_DATA_DESTRUCTOR_DEFAULT);
2020
NSError* error = nullptr;
@@ -40,12 +40,7 @@
4040
function_descriptor_.name = [NSString stringWithUTF8String:entry_point.c_str()];
4141
}
4242

43-
#if defined(USE_METAL_SHADER_CONVERTER)
44-
uint32_t MTShader::GetBindingOffset(const std::pair<uint32_t, uint32_t>& slot_space) const
45-
{
46-
return binding_offsets_.at(slot_space);
47-
}
48-
#else
43+
#if !defined(USE_METAL_SHADER_CONVERTER)
4944
uint32_t MTShader::GetIndex(BindKey bind_key) const
5045
{
5146
return slot_remapping_.at(bind_key);

0 commit comments

Comments
 (0)