11#include " HLSLCompiler/MetalShaderConverter.h"
22
3+ #include " ShaderReflection/ShaderReflection.h"
34#include " Utilities/Logging.h"
45#include " Utilities/NotReached.h"
56
@@ -27,12 +28,88 @@ IRShaderStage GetShaderStage(ShaderType type)
2728 }
2829}
2930
31+ IRDescriptorRangeType GetRangeType (ViewType view_type)
32+ {
33+ switch (view_type) {
34+ case ViewType::kTexture :
35+ case ViewType::kBuffer :
36+ case ViewType::kStructuredBuffer :
37+ case ViewType::kByteAddressBuffer :
38+ case ViewType::kAccelerationStructure :
39+ return IRDescriptorRangeTypeSRV;
40+ case ViewType::kRWTexture :
41+ case ViewType::kRWBuffer :
42+ case ViewType::kRWStructuredBuffer :
43+ case ViewType::kRWByteAddressBuffer :
44+ return IRDescriptorRangeTypeUAV;
45+ case ViewType::kConstantBuffer :
46+ return IRDescriptorRangeTypeCBV;
47+ case ViewType::kSampler :
48+ return IRDescriptorRangeTypeSampler;
49+ default :
50+ NOTREACHED ();
51+ }
52+ }
53+
54+ IRRootSignature* CreateIRRootSignature (const std::vector<BindKey>& bind_keys)
55+ {
56+ uint32_t spaces = 0 ;
57+ for (const auto & bind_key : bind_keys) {
58+ spaces = std::max (spaces, bind_key.space + 1 );
59+ }
60+
61+ std::vector<std::vector<IRDescriptorRange1>> descriptor_table_ranges (spaces);
62+ for (const auto & bind_key : bind_keys) {
63+ auto & range = descriptor_table_ranges[bind_key.space ].emplace_back ();
64+ range.RangeType = GetRangeType (bind_key.view_type );
65+ range.NumDescriptors = bind_key.count ;
66+ range.BaseShaderRegister = bind_key.slot ;
67+ range.RegisterSpace = bind_key.space ;
68+ range.Flags = IRDescriptorRangeFlagNone;
69+ range.OffsetInDescriptorsFromTableStart = bind_key.slot ;
70+ }
71+
72+ std::vector<IRRootParameter1> root_parameters;
73+ auto add_root_table = [&](size_t range_count, IRDescriptorRange1* ranges) {
74+ IRRootDescriptorTable1 descriptor_table = {};
75+ descriptor_table.NumDescriptorRanges = range_count;
76+ descriptor_table.pDescriptorRanges = ranges;
77+
78+ IRRootParameter1& root_parameter = root_parameters.emplace_back ();
79+ root_parameter.ParameterType = IRRootParameterTypeDescriptorTable;
80+ root_parameter.DescriptorTable = descriptor_table;
81+ root_parameter.ShaderVisibility = IRShaderVisibilityAll;
82+ };
83+
84+ for (auto & descriptor_table_range : descriptor_table_ranges) {
85+ add_root_table (descriptor_table_range.size (), descriptor_table_range.data ());
86+ }
87+
88+ IRRootSignatureFlags root_signature_flags = static_cast <IRRootSignatureFlags>(
89+ IRRootSignatureFlagAllowInputAssemblerInputLayout | IRRootSignatureFlagDenyHullShaderRootAccess |
90+ IRRootSignatureFlagDenyDomainShaderRootAccess);
91+
92+ IRVersionedRootSignatureDescriptor root_signature_desc = {};
93+ root_signature_desc.version = IRRootSignatureVersion_1_1;
94+ root_signature_desc.desc_1_1 .Flags = root_signature_flags;
95+ root_signature_desc.desc_1_1 .NumParameters = root_parameters.size ();
96+ root_signature_desc.desc_1_1 .pParameters = root_parameters.data ();
97+
98+ IRError* error = nullptr ;
99+ IRRootSignature* root_signature = IRRootSignatureCreateFromDescriptor (&root_signature_desc, &error);
100+ if (!root_signature) {
101+ Logging::Println (" IRRootSignatureCreateFromDescriptor failed: {}" , IRErrorGetCode (error));
102+ IRErrorDestroy (error);
103+ }
104+
105+ return root_signature;
106+ }
107+
30108} // namespace
31109
32110std::vector<uint8_t > ConvertToMetalLibBytecode (ShaderType shader_type,
33111 const std::vector<uint8_t >& blob,
34- std::string& entry_point,
35- std::map<std::pair<uint32_t , uint32_t >, uint32_t >& binding_offsets)
112+ std::string& entry_point)
36113{
37114 IRCompiler* compiler = IRCompilerCreate ();
38115 IRObject* dxil_obj = IRObjectCreateFromDXIL (blob.data (), blob.size (), IRBytecodeOwnershipNone);
@@ -41,6 +118,21 @@ std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
41118 IRCompilerSetStageInGenerationMode (compiler, IRStageInCodeGenerationModeUseMetalVertexFetch);
42119 }
43120
121+ std::vector<BindKey> bind_keys;
122+ auto dxil_reflection = CreateShaderReflection (ShaderBlobType::kDXIL , blob.data (), blob.size ());
123+ for (const auto & binding : dxil_reflection->GetBindings ()) {
124+ BindKey bind_key = {
125+ .shader_type = shader_type,
126+ .view_type = binding.type ,
127+ .slot = binding.slot ,
128+ .space = binding.space ,
129+ .count = binding.count ,
130+ };
131+ bind_keys.push_back (bind_key);
132+ }
133+ IRRootSignature* root_signature = CreateIRRootSignature (bind_keys);
134+ IRCompilerSetGlobalRootSignature (compiler, root_signature);
135+
44136 IRError* error = nullptr ;
45137 IRObject* metal_ir = IRCompilerAllocCompileAndLink (compiler, nullptr , dxil_obj, &error);
46138 if (!metal_ir) {
@@ -62,16 +154,9 @@ std::vector<uint8_t> ConvertToMetalLibBytecode(ShaderType shader_type,
62154 IRObjectGetReflection (metal_ir, GetShaderStage (shader_type), reflection);
63155 entry_point = IRShaderReflectionGetEntryPointFunctionName (reflection);
64156
65- size_t resource_count = IRShaderReflectionGetResourceCount (reflection);
66- std::vector<IRResourceLocation> resources (resource_count);
67- IRShaderReflectionGetResourceLocations (reflection, resources.data ());
68- for (const auto & resource : resources) {
69- binding_offsets[{ resource.slot , resource.space }] = resource.topLevelOffset ;
70- assert (resource.sizeBytes == 3 * sizeof (uint64_t ));
71- }
72-
73157 IRShaderReflectionDestroy (reflection);
74158 IRMetalLibBinaryDestroy (metal_lib);
159+ IRRootSignatureDestroy (root_signature);
75160 IRObjectDestroy (metal_ir);
76161 IRObjectDestroy (dxil_obj);
77162 IRCompilerDestroy (compiler);
0 commit comments