Skip to content

Commit cda46bd

Browse files
committed
Add ByteAddressBuffer/RWByteAddressBuffer support
1 parent d3f8eab commit cda46bd

File tree

13 files changed

+105
-16
lines changed

13 files changed

+105
-16
lines changed

assets/shaders/BufferViewTest/PixelShader.hlsl

Lines changed: 33 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct VsOutput {
44
};
55

66
static const uint kIndexCount = 64;
7-
static const uint kRows = 4;
7+
static const uint kRows = 6;
88
static const uint kColumns = 4;
99
static const float4 kPass = float4(0.0, 1.0, 0.0, 1.0);
1010
static const float4 kFail = float4(1.0, 0.0, 0.0, 1.0);
@@ -180,6 +180,28 @@ float4 RunRWStructuredBufferTest4() {
180180
return kPass;
181181
}
182182

183+
ByteAddressBuffer byte_address_buffer;
184+
185+
float4 RunByteAddressBufferTest() {
186+
for (uint u = 0; u < kIndexCount; ++u) {
187+
if (byte_address_buffer.Load(u * 4) != MakeU32(u)) {
188+
return kFail;
189+
}
190+
}
191+
return kPass;
192+
}
193+
194+
RWByteAddressBuffer rwbyte_address_buffer;
195+
196+
float4 RunRWByteAddressBufferTest() {
197+
for (uint u = 0; u < kIndexCount; ++u) {
198+
if (rwbyte_address_buffer.Load(u * 4) != MakeU32(u)) {
199+
return kFail;
200+
}
201+
}
202+
return kPass;
203+
}
204+
183205
float4 RunTest(uint index) {
184206
switch (index) {
185207
case 0:
@@ -214,6 +236,16 @@ float4 RunTest(uint index) {
214236
return RunRWStructuredBufferTest3();
215237
case 15:
216238
return RunRWStructuredBufferTest4();
239+
case 16:
240+
case 17:
241+
case 18:
242+
case 19:
243+
return RunByteAddressBufferTest();
244+
case 20:
245+
case 21:
246+
case 22:
247+
case 23:
248+
return RunRWByteAddressBufferTest();
217249
default:
218250
return kIgnore;
219251
}

src/Apps/BufferViewTest/main.cpp

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
#include "AppSettings/ArgsParser.h"
44
#include "Instance/Instance.h"
55
#include "Utilities/Asset.h"
6+
#include "Utilities/Check.h"
67
#include "Utilities/NotReached.h"
78

89
#include <array>
@@ -33,13 +34,20 @@ std::string GetBufferPrefix(ViewType view_type)
3334
return "structured_buffer";
3435
case ViewType::kRWStructuredBuffer:
3536
return "rwstructured_buffer";
37+
case ViewType::kByteAddressBuffer:
38+
return "byte_address_buffer";
39+
case ViewType::kRWByteAddressBuffer:
40+
return "rwbyte_address_buffer";
3641
default:
3742
NOTREACHED();
3843
}
3944
}
4045

4146
bool IsSupported(ViewType view_type, gli::format format)
4247
{
48+
if (view_type == ViewType::kByteAddressBuffer || view_type == ViewType::kRWByteAddressBuffer) {
49+
return format == gli::format::FORMAT_R32_UINT_PACK32;
50+
}
4351
if (format == gli::format::FORMAT_RGB32_UINT_PACK32) {
4452
return view_type != ViewType::kBuffer && view_type != ViewType::kRWBuffer;
4553
}
@@ -110,6 +118,8 @@ BufferViewTestRenderer::BufferViewTestRenderer(const Settings& settings)
110118
ViewType::kRWBuffer,
111119
ViewType::kStructuredBuffer,
112120
ViewType::kRWStructuredBuffer,
121+
ViewType::kByteAddressBuffer,
122+
ViewType::kRWByteAddressBuffer,
113123
});
114124
const auto formats = std::to_array({
115125
gli::format::FORMAT_R32_UINT_PACK32,
@@ -131,7 +141,8 @@ BufferViewTestRenderer::BufferViewTestRenderer(const Settings& settings)
131141
if (buffer_desc.size % structure_stride != 0) {
132142
buffer_desc.size += structure_stride - (buffer_desc.size % structure_stride);
133143
}
134-
if (view_type == ViewType::kBuffer || view_type == ViewType::kStructuredBuffer) {
144+
if (view_type == ViewType::kBuffer || view_type == ViewType::kStructuredBuffer ||
145+
view_type == ViewType::kByteAddressBuffer) {
135146
buffer_desc.usage = BindFlag::kShaderResource;
136147
} else {
137148
buffer_desc.usage = BindFlag::kCopySource;
@@ -144,7 +155,8 @@ BufferViewTestRenderer::BufferViewTestRenderer(const Settings& settings)
144155
}
145156
buffer->Unmap();
146157

147-
if (view_type == ViewType::kRWBuffer || view_type == ViewType::kRWStructuredBuffer) {
158+
if (view_type == ViewType::kRWBuffer || view_type == ViewType::kRWStructuredBuffer ||
159+
view_type == ViewType::kRWByteAddressBuffer) {
148160
upload_buffers_.push_back(std::move(buffer));
149161
buffer_desc.usage = BindFlag::kUnorderedAccess | BindFlag::kCopyDest;
150162
buffer = device_->CreateBuffer(MemoryType::kDefault, buffer_desc);
@@ -168,9 +180,12 @@ BufferViewTestRenderer::BufferViewTestRenderer(const Settings& settings)
168180
}
169181
std::shared_ptr<View> buffer_view = device_->CreateView(buffer, buffer_view_desc);
170182

171-
std::string bind_key_name = GetBufferPrefix(view_type) + "_uint";
172-
bind_key_name += std::to_string(structure_stride / 4);
183+
std::string bind_key_name = GetBufferPrefix(view_type);
184+
if (view_type != ViewType::kByteAddressBuffer && view_type != ViewType::kRWByteAddressBuffer) {
185+
bind_key_name += "_uint" + std::to_string(structure_stride / 4);
186+
}
173187
BindKey bind_key = pixel_shader_->GetBindKey(bind_key_name);
188+
DCHECK(bind_key.view_type == view_type);
174189
binding_set_layout_desc.bind_keys.push_back(bind_key);
175190
write_bindings_desc.bindings.emplace_back(bind_key, buffer_view);
176191

src/FlyCube/BindingSet/MTBindingSet.mm

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ void SetView(id<MTL4ArgumentTable> argument_table, const std::shared_ptr<MTView>
3737
case ViewType::kConstantBuffer:
3838
case ViewType::kStructuredBuffer:
3939
case ViewType::kRWStructuredBuffer:
40+
case ViewType::kByteAddressBuffer:
41+
case ViewType::kRWByteAddressBuffer:
4042
SetBuffer(argument_table, view->GetBuffer(), view->GetViewDesc().offset, index);
4143
break;
4244
case ViewType::kSampler:

src/FlyCube/BindingSetLayout/DXBindingSetLayout.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,11 +45,13 @@ D3D12_DESCRIPTOR_RANGE_TYPE GetRangeType(ViewType view_type)
4545
case ViewType::kTexture:
4646
case ViewType::kBuffer:
4747
case ViewType::kStructuredBuffer:
48+
case ViewType::kByteAddressBuffer:
4849
case ViewType::kAccelerationStructure:
4950
return D3D12_DESCRIPTOR_RANGE_TYPE_SRV;
5051
case ViewType::kRWTexture:
5152
case ViewType::kRWBuffer:
5253
case ViewType::kRWStructuredBuffer:
54+
case ViewType::kRWByteAddressBuffer:
5355
return D3D12_DESCRIPTOR_RANGE_TYPE_UAV;
5456
case ViewType::kConstantBuffer:
5557
return D3D12_DESCRIPTOR_RANGE_TYPE_CBV;

src/FlyCube/BindingSetLayout/VKBindingSetLayout.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,10 @@ vk::DescriptorType GetDescriptorType(ViewType view_type)
2424
return vk::DescriptorType::eStorageBuffer;
2525
case ViewType::kRWStructuredBuffer:
2626
return vk::DescriptorType::eStorageBuffer;
27+
case ViewType::kByteAddressBuffer:
28+
return vk::DescriptorType::eStorageBuffer;
29+
case ViewType::kRWByteAddressBuffer:
30+
return vk::DescriptorType::eStorageBuffer;
2731
case ViewType::kAccelerationStructure:
2832
return vk::DescriptorType::eAccelerationStructureKHR;
2933
default:

src/FlyCube/CPUDescriptorPool/DXCPUDescriptorPool.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ DXCPUDescriptorPoolTyped& DXCPUDescriptorPool::SelectHeap(ViewType view_type)
3131
case ViewType::kRWBuffer:
3232
case ViewType::kStructuredBuffer:
3333
case ViewType::kRWStructuredBuffer:
34+
case ViewType::kByteAddressBuffer:
35+
case ViewType::kRWByteAddressBuffer:
3436
return resource_;
3537
case ViewType::kSampler:
3638
return sampler_;

src/FlyCube/Instance/BaseTypes.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ enum class ViewType {
9898
kRWBuffer,
9999
kStructuredBuffer,
100100
kRWStructuredBuffer,
101+
kByteAddressBuffer,
102+
kRWByteAddressBuffer,
101103
kAccelerationStructure,
102104
kShadingRateSource,
103105
kRenderTarget,

src/FlyCube/ShaderReflection/DXReflection.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,10 @@ ViewType GetViewType(const D3D12_SHADER_INPUT_BIND_DESC& bind_desc)
6262
return ViewType::kRWTexture;
6363
}
6464
}
65+
case D3D_SIT_BYTEADDRESS:
66+
return ViewType::kByteAddressBuffer;
67+
case D3D_SIT_UAV_RWBYTEADDRESS:
68+
return ViewType::kRWByteAddressBuffer;
6569
default:
6670
NOTREACHED();
6771
}

src/FlyCube/ShaderReflection/SPIRVReflection.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,18 @@ ViewType GetViewType(const spirv_cross::Compiler& compiler, const spirv_cross::S
156156
if (type.storage == spv::StorageClassStorageBuffer) {
157157
spirv_cross::Bitset flags = compiler.get_buffer_block_flags(resource_id);
158158
bool is_readonly = flags.get(spv::DecorationNonWritable);
159-
if (is_readonly) {
160-
return ViewType::kStructuredBuffer;
159+
bool is_byte_address_buffer = false;
160+
if (compiler.has_decoration(resource_id, spv::DecorationUserTypeGOOGLE)) {
161+
decltype(auto) user_type = compiler.get_decoration_string(resource_id, spv::DecorationUserTypeGOOGLE);
162+
is_byte_address_buffer = user_type.find("byteaddressbuffer") != std::string::npos;
161163
} else {
162-
return ViewType::kRWStructuredBuffer;
164+
is_byte_address_buffer =
165+
compiler.get_name(type.parent_type).find("ByteAddressBuffer") != std::string::npos;
166+
}
167+
if (is_byte_address_buffer) {
168+
return is_readonly ? ViewType::kByteAddressBuffer : ViewType::kRWByteAddressBuffer;
163169
}
170+
return is_readonly ? ViewType::kStructuredBuffer : ViewType::kRWStructuredBuffer;
164171
} else if (type.storage == spv::StorageClassPushConstant || type.storage == spv::StorageClassUniform) {
165172
return ViewType::kConstantBuffer;
166173
}

src/FlyCube/View/DXBindlessTypedViewPool.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ DXBindlessTypedViewPool::DXBindlessTypedViewPool(DXDevice& device, ViewType view
1616
case ViewType::kRWBuffer:
1717
case ViewType::kStructuredBuffer:
1818
case ViewType::kRWStructuredBuffer:
19+
case ViewType::kByteAddressBuffer:
20+
case ViewType::kRWByteAddressBuffer:
1921
case ViewType::kAccelerationStructure: {
2022
range_ = std::make_shared<DXGPUDescriptorPoolRange>(
2123
device.GetGPUDescriptorPool().Allocate(D3D12_DESCRIPTOR_HEAP_TYPE_CBV_SRV_UAV, view_count));

0 commit comments

Comments
 (0)