Skip to content

Commit 87bcb17

Browse files
committed
fixup! stencils: Support mixed boundaries, add Clamp and LinearExtrapolate
1 parent c5f5df7 commit 87bcb17

File tree

2 files changed

+100
-86
lines changed

2 files changed

+100
-86
lines changed

src/Dagger.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,8 @@ end
4646

4747
import MacroTools: @capture, prewalk
4848

49+
import KernelAbstractions, Adapt
50+
4951
include("lib/util.jl")
5052
include("utils/dagdebug.jl")
5153

@@ -122,8 +124,6 @@ include("array/mul.jl")
122124
include("array/cholesky.jl")
123125
include("array/lu.jl")
124126

125-
import KernelAbstractions, Adapt
126-
127127
# GPU
128128
include("gpu.jl")
129129

src/stencil.jl

Lines changed: 98 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -180,6 +180,34 @@ boundary_has_transition(::Clamp) = true
180180
boundary_transition(::Clamp, idx, size) =
181181
CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size)))
182182

183+
KernelAbstractions.@kernel function load_boundary_region_kernel(::Clamp, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
184+
raw_idx = KernelAbstractions.@index(Global)
185+
186+
# Convert linear index to Cartesian index
187+
idx = CartesianIndices(result)[raw_idx]
188+
189+
# Compute source index for each dimension
190+
src_idx = CartesianIndex(ntuple(N) do i
191+
nd = get_neigh_dist(neigh_dist, i)
192+
if boundary_dims[i] && region_code[i] == -1
193+
# Low boundary - clamp to first element
194+
firstindex(arr, i)
195+
elseif boundary_dims[i] && region_code[i] == +1
196+
# High boundary - clamp to last element
197+
lastindex(arr, i)
198+
elseif region_code[i] == -1
199+
# Not at boundary but loading from low side of neighbor
200+
lastindex(arr, i) - nd + idx[i]
201+
elseif region_code[i] == +1
202+
# Not at boundary but loading from high side of neighbor
203+
firstindex(arr, i) + idx[i] - 1
204+
else
205+
# Full extent
206+
idx[i]
207+
end
208+
end)
209+
result[idx] = arr[src_idx]
210+
end
183211
function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
184212
# Compute the size of this halo region
185213
region_size = ntuple(N) do i
@@ -188,29 +216,7 @@ function load_boundary_region(::Clamp, arr, region_code::NTuple{N,Int}, neigh_di
188216

189217
result = similar(arr, region_size)
190218

191-
for idx in CartesianIndices(result)
192-
# Compute source index for each dimension
193-
src_idx = CartesianIndex(ntuple(N) do i
194-
nd = get_neigh_dist(neigh_dist, i)
195-
if boundary_dims[i] && region_code[i] == -1
196-
# Low boundary - clamp to first element
197-
firstindex(arr, i)
198-
elseif boundary_dims[i] && region_code[i] == +1
199-
# High boundary - clamp to last element
200-
lastindex(arr, i)
201-
elseif region_code[i] == -1
202-
# Not at boundary but loading from low side of neighbor
203-
lastindex(arr, i) - nd + idx[i]
204-
elseif region_code[i] == +1
205-
# Not at boundary but loading from high side of neighbor
206-
firstindex(arr, i) + idx[i] - 1
207-
else
208-
# Full extent
209-
idx[i]
210-
end
211-
end)
212-
result[idx] = arr[src_idx]
213-
end
219+
Kernel(load_boundary_region_kernel)(Clamp(), result, arr, region_code, neigh_dist, boundary_dims; ndrange=size(result))
214220

215221
return move(task_processor(), result)
216222
end
@@ -244,6 +250,66 @@ boundary_has_transition(::LinearExtrapolate) = true
244250
boundary_transition(::LinearExtrapolate, idx, size) =
245251
CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size)))
246252

253+
KernelAbstractions.@kernel function load_boundary_region_kernel(::LinearExtrapolate, result, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}, ::Val{extrap_dim}, ::Val{nd}) where {N,extrap_dim,nd}
254+
raw_idx = KernelAbstractions.@index(Global)
255+
256+
# Convert linear index to Cartesian index
257+
idx = CartesianIndices(result)[raw_idx]
258+
259+
if extrap_dim == 0
260+
# No boundary dimensions - normal neighbor access
261+
src_idx = CartesianIndex(ntuple(Val(N)) do i
262+
ndi = get_neigh_dist(neigh_dist, i)::Int
263+
if region_code[i] == -1
264+
lastindex(arr, i) - ndi + idx[i]
265+
elseif region_code[i] == +1
266+
firstindex(arr, i) + idx[i] - 1
267+
else
268+
idx[i]
269+
end
270+
end)
271+
result[idx] = arr[src_idx]
272+
else
273+
# Extrapolate along extrap_dim, clamp other boundary dimensions
274+
#nd = get_neigh_dist(neigh_dist, extrap_dim)::Int
275+
276+
# Compute base index (for other dimensions, clamp if at boundary)
277+
base_idx = ntuple(Val(N)) do i
278+
ndi = get_neigh_dist(neigh_dist, i)
279+
if i == extrap_dim
280+
# Will be set for slope computation
281+
region_code[i] == -1 ? firstindex(arr, i) : lastindex(arr, i)
282+
elseif boundary_dims[i] && region_code[i] == -1
283+
firstindex(arr, i)
284+
elseif boundary_dims[i] && region_code[i] == +1
285+
lastindex(arr, i)
286+
elseif region_code[i] == -1
287+
lastindex(arr, i) - ndi + idx[i]
288+
elseif region_code[i] == +1
289+
firstindex(arr, i) + idx[i] - 1
290+
else
291+
idx[i]
292+
end
293+
end
294+
295+
# Compute slope at boundary
296+
if region_code[extrap_dim] == -1
297+
# Low boundary: slope = arr[2] - arr[1]
298+
idx1 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) : base_idx[i], Val(N))
299+
idx2 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) + 1 : base_idx[i], Val(N))
300+
slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)]
301+
dist = -(nd - idx[extrap_dim] + 1)
302+
result[idx] = arr[CartesianIndex(idx1)] + slope * dist
303+
else
304+
# High boundary: slope = arr[end] - arr[end-1]
305+
idx1 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) - 1 : base_idx[i], Val(N))
306+
idx2 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) : base_idx[i], Val(N))
307+
slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)]
308+
dist = idx[extrap_dim]
309+
result[idx] = arr[CartesianIndex(idx2)] + slope * dist
310+
end
311+
end
312+
end
247313
function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {T<:Real,N}
248314
# Compute the size of this halo region
249315
region_size = ntuple(N) do i
@@ -252,70 +318,18 @@ function load_boundary_region(::LinearExtrapolate, arr::AbstractArray{T}, region
252318

253319
result = similar(arr, region_size)
254320

255-
for idx in CartesianIndices(result)
256-
# Find the first boundary dimension that needs extrapolation
257-
extrap_dim = 0
258-
for d in 1:N
259-
if boundary_dims[d] && region_code[d] != 0
260-
extrap_dim = d
261-
break
262-
end
321+
# Find the first boundary dimension that needs extrapolation
322+
extrap_dim = 0
323+
for d in 1:N
324+
if boundary_dims[d] && region_code[d] != 0
325+
extrap_dim = d
326+
break
263327
end
328+
end
264329

265-
if extrap_dim == 0
266-
# No boundary dimensions - normal neighbor access
267-
src_idx = CartesianIndex(ntuple(N) do i
268-
nd = get_neigh_dist(neigh_dist, i)
269-
if region_code[i] == -1
270-
lastindex(arr, i) - nd + idx[i]
271-
elseif region_code[i] == +1
272-
firstindex(arr, i) + idx[i] - 1
273-
else
274-
idx[i]
275-
end
276-
end)
277-
result[idx] = arr[src_idx]
278-
else
279-
# Extrapolate along extrap_dim, clamp other boundary dimensions
280-
nd = get_neigh_dist(neigh_dist, extrap_dim)
281-
282-
# Compute base index (for other dimensions, clamp if at boundary)
283-
base_idx = ntuple(N) do i
284-
ndi = get_neigh_dist(neigh_dist, i)
285-
if i == extrap_dim
286-
# Will be set for slope computation
287-
region_code[i] == -1 ? firstindex(arr, i) : lastindex(arr, i)
288-
elseif boundary_dims[i] && region_code[i] == -1
289-
firstindex(arr, i)
290-
elseif boundary_dims[i] && region_code[i] == +1
291-
lastindex(arr, i)
292-
elseif region_code[i] == -1
293-
lastindex(arr, i) - ndi + idx[i]
294-
elseif region_code[i] == +1
295-
firstindex(arr, i) + idx[i] - 1
296-
else
297-
idx[i]
298-
end
299-
end
330+
nd = get_neigh_dist(neigh_dist, extrap_dim)
300331

301-
# Compute slope at boundary
302-
if region_code[extrap_dim] == -1
303-
# Low boundary: slope = arr[2] - arr[1]
304-
idx1 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) : base_idx[i], N)
305-
idx2 = ntuple(i -> i == extrap_dim ? firstindex(arr, i) + 1 : base_idx[i], N)
306-
slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)]
307-
dist = -(nd - idx[extrap_dim] + 1)
308-
result[idx] = arr[CartesianIndex(idx1)] + slope * dist
309-
else
310-
# High boundary: slope = arr[end] - arr[end-1]
311-
idx1 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) - 1 : base_idx[i], N)
312-
idx2 = ntuple(i -> i == extrap_dim ? lastindex(arr, i) : base_idx[i], N)
313-
slope = arr[CartesianIndex(idx2)] - arr[CartesianIndex(idx1)]
314-
dist = idx[extrap_dim]
315-
result[idx] = arr[CartesianIndex(idx2)] + slope * dist
316-
end
317-
end
318-
end
332+
Kernel(load_boundary_region_kernel)(LinearExtrapolate(), result, arr, region_code, neigh_dist, boundary_dims, Val(extrap_dim), Val(nd); ndrange=size(result))
319333

320334
return move(task_processor(), result)
321335
end

0 commit comments

Comments
 (0)