Skip to content

Commit a5387ca

Browse files
authored
Merge pull request #676 from JuliaParallel/jps/stencil-reflect
stencils: Add Reflect boundary condition
2 parents d0081e7 + 951f661 commit a5387ca

File tree

3 files changed

+234
-6
lines changed

3 files changed

+234
-6
lines changed

docs/src/stencils.md

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ The fundamental structure of a `@stencil` block involves iterating over an impli
88

99
```julia
1010
using Dagger
11-
import Dagger: @stencil, Wrap, Pad
11+
import Dagger: @stencil, Wrap, Pad, Reflect
1212

1313
# Initialize a DArray
1414
A = zeros(Blocks(2, 2), Int, 4, 4)
@@ -35,6 +35,9 @@ The true power of stencils comes from accessing neighboring elements. The `@neig
3535
- `boundary_condition`: Defines how to handle accesses beyond the array boundaries. Available conditions are:
3636
- `Wrap()`: Wraps around to the other side of the array.
3737
- `Pad(value)`: Pads with a specified `value`.
38+
- `Reflect(symmetric)`: Reflects values back into the array at boundaries. The `symmetric` boolean controls whether the edge element is included in the reflection:
39+
- `Reflect(true)` (symmetric): Edge element IS repeated. For array `[a,b,c,d]`, extends as `[...,c,b,a,a,b,c,d,d,c,b,...]`.
40+
- `Reflect(false)` (mirror): Edge element NOT repeated. For array `[a,b,c,d]`, extends as `[...,d,c,b,a,b,c,d,c,b,a,...]`.
3841

3942
### Example: Averaging Neighbors with `Wrap`
4043

@@ -93,6 +96,59 @@ expected_B_padded = [
9396
@assert collect(B) == expected_B_padded
9497
```
9598

99+
### Example: Smoothing with `Reflect`
100+
101+
The `Reflect` boundary condition mirrors values at the edges, which is useful for operations like smoothing or image processing where you want to avoid artificial discontinuities at boundaries.
102+
103+
#### Symmetric Reflection (`Reflect(true)`)
104+
105+
With symmetric reflection, the edge element is repeated in the reflection:
106+
107+
```julia
108+
import Dagger: Reflect
109+
110+
# Simple 1D example to illustrate symmetric reflection
111+
# Array [1, 2, 3, 4] extends as [..., 3, 2, 1, 1, 2, 3, 4, 4, 3, 2, ...]
112+
# ^edge^ ^edge^
113+
A = DArray([1, 2, 3, 4], Blocks(2))
114+
B = zeros(Blocks(2), Int, 4)
115+
116+
Dagger.spawn_datadeps() do
117+
@stencil begin
118+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(true)))
119+
end
120+
end
121+
122+
# B[1]: indices 0,1,2 -> 0 reflects to 1, so [1,1,2] = 4
123+
# B[2]: indices 1,2,3 -> all in bounds, [1,2,3] = 6
124+
# B[3]: indices 2,3,4 -> all in bounds, [2,3,4] = 9
125+
# B[4]: indices 3,4,5 -> 5 reflects to 4, so [3,4,4] = 11
126+
@assert collect(B) == [4, 6, 9, 11]
127+
```
128+
129+
#### Mirror Reflection (`Reflect(false)`)
130+
131+
With mirror reflection, the edge element is NOT repeated:
132+
133+
```julia
134+
# Array [1, 2, 3, 4] extends as [..., 4, 3, 2, 1, 2, 3, 4, 3, 2, 1, ...]
135+
# ^edge^ ^edge^
136+
A = DArray([1, 2, 3, 4], Blocks(2))
137+
B = zeros(Blocks(2), Int, 4)
138+
139+
Dagger.spawn_datadeps() do
140+
@stencil begin
141+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(false)))
142+
end
143+
end
144+
145+
# B[1]: indices 0,1,2 -> 0 reflects to 2, so [2,1,2] = 5
146+
# B[2]: indices 1,2,3 -> all in bounds, [1,2,3] = 6
147+
# B[3]: indices 2,3,4 -> all in bounds, [2,3,4] = 9
148+
# B[4]: indices 3,4,5 -> 5 reflects to 3, so [3,4,3] = 10
149+
@assert collect(B) == [5, 6, 9, 10]
150+
```
151+
96152
## Sequential Semantics
97153

98154
Expressions within a `@stencil` block are executed sequentially in terms of their effect on the data. This means that the result of one statement is visible to the subsequent statements, as if they were applied "all at once" across all indices before the next statement begins.

src/stencil.jl

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,6 @@ end
2121
get_neigh_dist(neigh_dist::Integer, i::Int) = neigh_dist
2222
get_neigh_dist(neigh_dist::Tuple, i::Int) = neigh_dist[i]
2323

24-
2524
# Load a halo region from a neighboring chunk
2625
# region_code: N-tuple where each element is -1 (low), 0 (full extent), or +1 (high)
2726
# For dimensions with code 0, we take the full extent of the array
@@ -72,13 +71,17 @@ function select_neighborhood_chunks(chunks, idx, neigh_dist, boundary)
7271
new_idx = idx + chunk_offset
7372

7473
if is_past_boundary(size(chunks), new_idx)
74+
# Compute which dimensions are actually past boundary
75+
boundary_dims = ntuple(N) do d
76+
new_idx[d] < 1 || new_idx[d] > size(chunks)[d]
77+
end
7578
if boundary_has_transition(boundary)
7679
new_idx = boundary_transition(boundary, new_idx, size(chunks))
7780
else
7881
new_idx = idx
7982
end
8083
chunk = chunks[new_idx]
81-
push!(accesses, Dagger.@spawn load_boundary_region(boundary, chunk, region_code, neigh_dist))
84+
push!(accesses, Dagger.@spawn load_boundary_region(boundary, chunk, region_code, neigh_dist, boundary_dims))
8285
else
8386
chunk = chunks[new_idx]
8487
push!(accesses, Dagger.@spawn load_neighbor_region(chunk, region_code, neigh_dist))
@@ -113,17 +116,23 @@ end
113116

114117
is_past_boundary(size, idx) = any(ntuple(i -> idx[i] < 1 || idx[i] > size[i], length(size)))
115118

119+
"""
120+
Wrap boundary condition. Non-local accesses wrap around to the other side of the array.
121+
"""
116122
struct Wrap end
117123
boundary_has_transition(::Wrap) = true
118124
boundary_transition(::Wrap, idx, size) =
119125
CartesianIndex(ntuple(i -> mod1(idx[i], size[i]), length(size)))
120-
load_boundary_region(::Wrap, arr, region_code, neigh_dist) = load_neighbor_region(arr, region_code, neigh_dist)
126+
load_boundary_region(::Wrap, arr, region_code, neigh_dist, boundary_dims) = load_neighbor_region(arr, region_code, neigh_dist)
121127

128+
"""
129+
Pad boundary condition. Non-local accesses are padded with a specified value.
130+
"""
122131
struct Pad{T}
123132
padval::T
124133
end
125134
boundary_has_transition(::Pad) = false
126-
function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_dist) where N
135+
function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where N
127136
# Compute the size of this halo region
128137
# For dimensions with code 0, use full array size
129138
# For dimensions with code -1 or +1, use neigh_dist
@@ -134,6 +143,76 @@ function load_boundary_region(pad::Pad, arr, region_code::NTuple{N,Int}, neigh_d
134143
return move(task_processor(), fill(pad.padval, region_size))
135144
end
136145

146+
"""
147+
Reflect boundary condition. Non-local accesses are reflected back into the array.
148+
If `symm` is true, the reflected values include the nearest center elements.
149+
If `symm` is false, the reflected values do not include the nearest center elements.
150+
"""
151+
struct Reflect{Symmetric} end
152+
Reflect(symm::Bool) = Reflect{symm}()
153+
boundary_has_transition(::Reflect) = true
154+
# Clamp to valid chunk indices - we stay at the boundary chunk
155+
boundary_transition(::Reflect, idx, size) =
156+
CartesianIndex(ntuple(i -> clamp(idx[i], 1, size[i]), length(size)))
157+
function load_boundary_region(::Reflect{Symm}, arr, region_code::NTuple{N,Int}, neigh_dist, boundary_dims::NTuple{N,Bool}) where {N, Symm}
158+
# Only flip region_code for dimensions that are BOTH:
159+
# 1. Non-zero in region_code (we're accessing a neighbor in that dimension)
160+
# 2. Actually past boundary (boundary_dims[i] is true)
161+
# For dimensions not past boundary, keep the original region_code behavior
162+
flipped_code = ntuple(N) do i
163+
if region_code[i] != 0 && boundary_dims[i]
164+
# This dimension needs reflection - flip the code
165+
-region_code[i]
166+
else
167+
# Keep original code (either 0, or not past boundary)
168+
region_code[i]
169+
end
170+
end
171+
172+
# For non-symmetric (mirror), skip 1 element to exclude the edge
173+
# For symmetric, include the edge element (skip = 0)
174+
# Only apply skip to dimensions that are being reflected
175+
skip = Symm ? 0 : 1
176+
177+
# Compute region indices
178+
start_idx = CartesianIndex(ntuple(N) do i
179+
needs_skip = boundary_dims[i] && region_code[i] != 0
180+
actual_skip = needs_skip ? skip : 0
181+
if flipped_code[i] == -1
182+
# Taking from end (high side)
183+
lastindex(arr, i) - get_neigh_dist(neigh_dist, i) + 1 - actual_skip
184+
elseif flipped_code[i] == +1
185+
# Taking from start (low side)
186+
firstindex(arr, i) + actual_skip
187+
else
188+
firstindex(arr, i)
189+
end
190+
end)
191+
stop_idx = CartesianIndex(ntuple(N) do i
192+
needs_skip = boundary_dims[i] && region_code[i] != 0
193+
actual_skip = needs_skip ? skip : 0
194+
if flipped_code[i] == +1
195+
firstindex(arr, i) + get_neigh_dist(neigh_dist, i) - 1 + actual_skip
196+
elseif flipped_code[i] == -1
197+
lastindex(arr, i) - actual_skip
198+
else
199+
lastindex(arr, i)
200+
end
201+
end)
202+
203+
region = move(task_processor(), collect(@view arr[start_idx:stop_idx]))
204+
205+
# Reverse only along dimensions that are actually being reflected
206+
# (both non-zero in region_code AND past boundary)
207+
for i in 1:N
208+
if region_code[i] != 0 && boundary_dims[i]
209+
region = reverse(region, dims=i)
210+
end
211+
end
212+
213+
return region
214+
end
215+
137216
"""
138217
@stencil begin body end
139218

test/array/stencil.jl

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import Dagger: @stencil, Wrap, Pad
1+
import Dagger: @stencil, Wrap, Pad, Reflect
22

33
function test_stencil()
44
@testset "Simple assignment" begin
@@ -67,6 +67,99 @@ function test_stencil()
6767
@test collect(B) == expected_B_pad
6868
end
6969

70+
@testset "Reflect boundary (symmetric)" begin
71+
# Test symmetric reflection (edge element IS included/repeated)
72+
# For A = [1, 2, 3, 4] with Reflect(true):
73+
# idx=0 → 1, idx=-1 → 2 (reflection includes edge)
74+
# idx=5 → 4, idx=6 → 3 (reflection includes edge)
75+
A = DArray([1, 2, 3, 4], Blocks(2))
76+
B = zeros(Blocks(2), Int, 4)
77+
Dagger.spawn_datadeps() do
78+
@stencil begin
79+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(true)))
80+
end
81+
end
82+
# B[1]: neighbors at indices 0, 1, 2 -> reflected 0 becomes 1, so [1, 1, 2] = 4
83+
# B[2]: neighbors at indices 1, 2, 3 -> [1, 2, 3] = 6
84+
# B[3]: neighbors at indices 2, 3, 4 -> [2, 3, 4] = 9
85+
# B[4]: neighbors at indices 3, 4, 5 -> reflected 5 becomes 4, so [3, 4, 4] = 11
86+
expected_B_symm = [4, 6, 9, 11]
87+
@test collect(B) == expected_B_symm
88+
end
89+
90+
@testset "Reflect boundary (mirror)" begin
91+
# Test mirror reflection (edge element NOT included/repeated)
92+
# For A = [1, 2, 3, 4] with Reflect(false):
93+
# idx=0 → 2, idx=-1 → 3 (reflection skips edge)
94+
# idx=5 → 3, idx=6 → 2 (reflection skips edge)
95+
A = DArray([1, 2, 3, 4], Blocks(2))
96+
B = zeros(Blocks(2), Int, 4)
97+
Dagger.spawn_datadeps() do
98+
@stencil begin
99+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(false)))
100+
end
101+
end
102+
# B[1]: neighbors at indices 0, 1, 2 -> reflected 0 becomes 2, so [2, 1, 2] = 5
103+
# B[2]: neighbors at indices 1, 2, 3 -> [1, 2, 3] = 6
104+
# B[3]: neighbors at indices 2, 3, 4 -> [2, 3, 4] = 9
105+
# B[4]: neighbors at indices 3, 4, 5 -> reflected 5 becomes 3, so [3, 4, 3] = 10
106+
expected_B_mirror = [5, 6, 9, 10]
107+
@test collect(B) == expected_B_mirror
108+
end
109+
110+
@testset "Reflect boundary 2D (symmetric)" begin
111+
# Test 2D symmetric reflection with a gradient pattern
112+
A = DArray(reshape(1:16, 4, 4), Blocks(2, 2))
113+
B = zeros(Blocks(2, 2), Int, 4, 4)
114+
Dagger.spawn_datadeps() do
115+
@stencil begin
116+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(true)))
117+
end
118+
end
119+
# Symmetric: idx < 1 → 1 - idx, idx > size → 2*size + 1 - idx
120+
A_collected = collect(A)
121+
expected_B_symm = zeros(Int, 4, 4)
122+
for i in 1:4, j in 1:4
123+
sum_val = 0
124+
for di in -1:1, dj in -1:1
125+
ni, nj = i + di, j + dj
126+
# Apply symmetric reflection logic
127+
# For symmetric: idx < 1 → 1 - idx, idx > size → 2*size + 1 - idx
128+
ni = ni < 1 ? 1 - ni : (ni > 4 ? 2*4 + 1 - ni : ni)
129+
nj = nj < 1 ? 1 - nj : (nj > 4 ? 2*4 + 1 - nj : nj)
130+
sum_val += A_collected[ni, nj]
131+
end
132+
expected_B_symm[i, j] = sum_val
133+
end
134+
@test collect(B) == expected_B_symm
135+
end
136+
137+
@testset "Reflect boundary 2D (mirror)" begin
138+
# Test 2D mirror reflection with a gradient pattern
139+
A = DArray(reshape(1:16, 4, 4), Blocks(2, 2))
140+
B = zeros(Blocks(2, 2), Int, 4, 4)
141+
Dagger.spawn_datadeps() do
142+
@stencil begin
143+
B[idx] = sum(@neighbors(A[idx], 1, Reflect(false)))
144+
end
145+
end
146+
# Mirror: idx < 1 → 2 - idx, idx > size → 2*size - idx
147+
A_collected = collect(A)
148+
expected_B_mirror = zeros(Int, 4, 4)
149+
for i in 1:4, j in 1:4
150+
sum_val = 0
151+
for di in -1:1, dj in -1:1
152+
ni, nj = i + di, j + dj
153+
# Apply mirror reflection logic
154+
ni = ni < 1 ? 2 - ni : (ni > 4 ? 2*4 - ni : ni)
155+
nj = nj < 1 ? 2 - nj : (nj > 4 ? 2*4 - nj : nj)
156+
sum_val += A_collected[ni, nj]
157+
end
158+
expected_B_mirror[i, j] = sum_val
159+
end
160+
@test collect(B) == expected_B_mirror
161+
end
162+
70163
@testset "Multiple expressions" begin
71164
A = zeros(Blocks(2, 2), Int, 4, 4)
72165
B = zeros(Blocks(2, 2), Int, 4, 4)

0 commit comments

Comments
 (0)