
Commit c8c9599

datadeps: Make schedulers programmable with dispatch
1 parent fd13b3a commit c8c9599

File tree

6 files changed: +249 / -184 lines


docs/src/datadeps.md

Lines changed: 60 additions & 0 deletions
@@ -221,6 +221,66 @@ function Dagger.move!(dep_mod::Any, from_space::Dagger.MemorySpace, to_space::Da
 end
 ```
 
+## Custom Schedulers
+
+The `spawn_datadeps` function accepts an optional `scheduler` keyword argument that controls how tasks are assigned to processors. By default, `spawn_datadeps` uses `RoundRobinScheduler()`, which cycles through available processors in a round-robin fashion.
+
+### Built-in Schedulers
+
+- **`RoundRobinScheduler()`** (default): Assigns tasks to processors in round-robin order. This is a simple and effective scheduler for most use cases.
+- **`NaiveScheduler()`**: Uses the main Dagger scheduler's cost estimation to select processors. (Currently experimental)
+- **`UltraScheduler()`**: An advanced scheduler that tracks task completion times and tries to minimize overall execution time. (Currently experimental)
+
+### Using a Different Scheduler
+
+You can pass a scheduler to `spawn_datadeps` like so:
+
+```julia
+Dagger.spawn_datadeps(; scheduler=Dagger.RoundRobinScheduler()) do
+    Dagger.@spawn my_task!(InOut(A))
+    Dagger.@spawn another_task!(In(B))
+end
+```
+
+### Writing Your Own Scheduler
+
+You can implement a custom scheduler by:
+1. Defining a struct that subtypes `Dagger.DataDepsScheduler`
+2. Implementing the `Dagger.datadeps_schedule_task` method for your scheduler
+
+The scheduler's job is to select which processor should execute a given task. Here's a simple example that randomly selects a processor:
+
+```julia
+# Define the scheduler type
+struct RandomScheduler <: Dagger.DataDepsScheduler end
+
+# Implement the scheduling function
+function Dagger.datadeps_schedule_task(::RandomScheduler, state, all_procs, all_scope, task_scope, spec, task)
+    # Reduce the available processors to the ones that are compatible with the task scope
+    compatible_procs = filter(proc->proc_in_scope(proc, task_scope), all_procs)
+    if isempty(compatible_procs)
+        throw(SchedulingException("No processors available for task $(task.uid) with scope $(task_scope)"))
+    end
+    # Simply pick a random processor from the compatible ones
+    return rand(compatible_procs)
+end
+
+# Use it
+Dagger.spawn_datadeps(; scheduler=RandomScheduler()) do
+    Dagger.@spawn my_task!(InOut(A))
+end
+```
+
+The `datadeps_schedule_task` function receives:
+- `state`: Internal datadeps state (typically not needed for simple schedulers)
+- `all_procs`: Vector of all available processors
+- `all_scope`: The combined scope of all processors
+- `task_scope`: The scope constraint for this specific task
+- `spec`: The task specification
+- `task`: The DTask being scheduled
+
+The function must return a processor from `all_procs` that is compatible with `task_scope`.
+
 ## Chunk and DTask slicing with `view`
 
 The `view` function allows you to efficiently create a "view" of a `Chunk` or `DTask` that contains an array. This enables operations on specific parts of your distributed data using standard Julia array slicing, without needing to materialize the entire array.
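
As a supplement to the documentation added above, here is a minimal sketch (not part of this commit) of a scheduler that carries its own state: it counts how many tasks each processor has received and always picks the least-loaded compatible processor. It assumes only the `Dagger.DataDepsScheduler`/`Dagger.datadeps_schedule_task` interface documented above, plus the same `proc_in_scope` and `SchedulingException` names used in the `RandomScheduler` example; the `LeastAssignedScheduler` name itself is hypothetical.

```julia
# Hypothetical least-assigned scheduler: a sketch only, assuming the interface
# documented above and the proc_in_scope helper from the RandomScheduler example.
struct LeastAssignedScheduler <: Dagger.DataDepsScheduler
    # Per-scheduler bookkeeping: how many tasks each processor has been assigned
    counts::Dict{Dagger.Processor,Int}
end
LeastAssignedScheduler() = LeastAssignedScheduler(Dict{Dagger.Processor,Int}())

function Dagger.datadeps_schedule_task(sched::LeastAssignedScheduler, state, all_procs, all_scope, task_scope, spec, task)
    # Keep only processors compatible with this task's scope
    compatible_procs = filter(proc->proc_in_scope(proc, task_scope), all_procs)
    if isempty(compatible_procs)
        throw(SchedulingException("No processors available for scope $(task_scope)"))
    end
    # Pick the processor with the fewest assignments so far, then record the assignment
    our_proc = argmin(proc->get(sched.counts, proc, 0), compatible_procs)
    sched.counts[our_proc] = get(sched.counts, our_proc, 0) + 1
    return our_proc
end

# Usage mirrors the examples above
Dagger.spawn_datadeps(; scheduler=LeastAssignedScheduler()) do
    # Dagger.@spawn my_task!(InOut(A))
end
```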

src/Dagger.jl

Lines changed: 1 addition & 0 deletions
@@ -91,6 +91,7 @@ include("sch/Sch.jl"); using .Sch
 include("datadeps/aliasing.jl")
 include("datadeps/chunkview.jl")
 include("datadeps/remainders.jl")
+include("datadeps/scheduling.jl")
 include("datadeps/queue.jl")
 
 # Stencils

src/datadeps/aliasing.jl

Lines changed: 1 addition & 23 deletions
@@ -393,11 +393,7 @@ struct DataDepsState
     ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}
     ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}
 
-    function DataDepsState(aliasing::Bool)
-        if !aliasing
-            @warn "aliasing=false is no longer supported, aliasing is now always enabled" maxlog=1
-        end
-
+    function DataDepsState()
         arg_to_chunk = IdDict{Any,Chunk}()
         arg_origin = IdDict{Any,MemorySpace}()
         remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
@@ -832,21 +828,3 @@ move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor,
 move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Symbol) = x
 move_rewrap(cache::AliasedObjectCache, from_proc::Processor, to_proc::Processor, from_space::MemorySpace, to_space::MemorySpace, x::Type) = x
 =#
-
-struct DataDepsSchedulerState
-    task_to_spec::Dict{DTask,DTaskSpec}
-    assignments::Dict{DTask,MemorySpace}
-    dependencies::Dict{DTask,Set{DTask}}
-    task_completions::Dict{DTask,UInt64}
-    space_completions::Dict{MemorySpace,UInt64}
-    capacities::Dict{MemorySpace,Int}
-
-    function DataDepsSchedulerState()
-        return new(Dict{DTask,DTaskSpec}(),
-                   Dict{DTask,MemorySpace}(),
-                   Dict{DTask,Set{DTask}}(),
-                   Dict{DTask,UInt64}(),
-                   Dict{MemorySpace,UInt64}(),
-                   Dict{MemorySpace,Int}())
-    end
-end
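
With dispatch-based schedulers, bookkeeping like the removed `DataDepsSchedulerState` no longer needs to live alongside `DataDepsState`; a scheduler object can carry equivalent state itself. The sketch below illustrates this under that assumption; the field types mirror the deleted struct, while the `CompletionTrackingScheduler` name is hypothetical and the commit's actual replacement lives in the new `src/datadeps/scheduling.jl`, which is not shown in this section.

```julia
# Sketch only: a scheduler object carrying the same kind of bookkeeping the removed
# DataDepsSchedulerState held, now owned by the scheduler rather than by DataDepsState.
struct CompletionTrackingScheduler <: Dagger.DataDepsScheduler
    task_completions::Dict{Dagger.DTask,UInt64}         # estimated finish time per task
    space_completions::Dict{Dagger.MemorySpace,UInt64}   # estimated availability per memory space
    capacities::Dict{Dagger.MemorySpace,Int}             # processor count per memory space
end
CompletionTrackingScheduler() =
    CompletionTrackingScheduler(Dict{Dagger.DTask,UInt64}(),
                                Dict{Dagger.MemorySpace,UInt64}(),
                                Dict{Dagger.MemorySpace,Int}())
# Such a scheduler would implement Dagger.datadeps_schedule_task to update and consult
# these tables, much as the UltraScheduler described in the docs presumably does.
```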

src/datadeps/queue.jl

Lines changed: 15 additions & 161 deletions
@@ -1,4 +1,4 @@
-struct DataDepsTaskQueue <: AbstractTaskQueue
+struct DataDepsTaskQueue{Scheduler<:DataDepsScheduler} <: AbstractTaskQueue
     # The queue above us
     upper_queue::AbstractTaskQueue
     # The set of tasks that have already been seen
@@ -8,13 +8,13 @@ struct DataDepsTaskQueue <: AbstractTaskQueue
     # The mapping from task to graph ID
     task_to_id::Union{Dict{DTask,Int},Nothing}
     # Which scheduler to use to assign tasks to processors
-    scheduler::Symbol
+    scheduler::Scheduler
 
-    function DataDepsTaskQueue(upper_queue; scheduler::Symbol)
+    function DataDepsTaskQueue(upper_queue; scheduler::DataDepsScheduler)
        seen_tasks = DTaskPair[]
        g = SimpleDiGraph()
        task_to_id = Dict{DTask,Int}()
-       return new(upper_queue, seen_tasks, g, task_to_id, scheduler)
+       return new{typeof(scheduler)}(upper_queue, seen_tasks, g, task_to_id, scheduler)
     end
 end
 
@@ -56,7 +56,7 @@ returned from `spawn_datadeps`.
 """
 function spawn_datadeps(f::Base.Callable; static::Bool=true,
                         traversal::Symbol=:inorder,
-                        scheduler::Union{Symbol,Nothing}=nothing,
+                        scheduler::Union{DataDepsScheduler,Nothing}=nothing,
                         aliasing::Bool=true,
                         launch_wait::Union{Bool,Nothing}=nothing)
     if !static
@@ -69,7 +69,7 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
         throw(ArgumentError("Aliasing analysis is no longer optional"))
     end
     wait_all(; check_errors=true) do
-        scheduler = something(scheduler, DATADEPS_SCHEDULER[], :roundrobin)::Symbol
+        scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler())
         launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
         if launch_wait
             result = spawn_bulk() do
@@ -85,7 +85,7 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
         return result
     end
 end
-const DATADEPS_SCHEDULER = ScopedValue{Union{Symbol,Nothing}}(nothing)
+const DATADEPS_SCHEDULER = ScopedValue{Union{DataDepsScheduler,Nothing}}(nothing)
 const DATADEPS_LAUNCH_WAIT = ScopedValue{Union{Bool,Nothing}}(nothing)
 
 function distribute_tasks!(queue::DataDepsTaskQueue)
@@ -107,7 +107,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     if isempty(all_procs)
         throw(Sch.SchedulingException("No processors available, try widening scope"))
     end
-    scope = UnionScope(map(ExactScope, all_procs))
+    all_scope = UnionScope(map(ExactScope, all_procs))
     exec_spaces = unique(vcat(map(proc->collect(memory_spaces(proc)), all_procs)...))
     if !all(space->space isa CPURAMMemorySpace, exec_spaces) && !all(space->root_worker_id(space) == myid(), exec_spaces)
         @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
@@ -116,23 +116,14 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     # Round-robin assign tasks to processors
     upper_queue = get_options(:task_queue)
 
-    state = DataDepsState(queue.aliasing)
-    sstate = DataDepsSchedulerState()
-    for proc in all_procs
-        space = only(memory_spaces(proc))
-        get!(()->0, sstate.capacities, space)
-        sstate.capacities[space] += 1
-    end
-
     # Start launching tasks and necessary copies
+    state = DataDepsState()
     write_num = 1
-    proc_idx = 1
-    #pressures = Dict{Processor,Int}()
     proc_to_scope_lfu = BasicLFUCache{Processor,AbstractScope}(1024)
     for pair in queue.seen_tasks
         spec = pair.spec
         task = pair.task
-        write_num, proc_idx = distribute_task!(queue, state, all_procs, scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num, proc_idx)
+        write_num = distribute_task!(queue, state, all_procs, all_scope, spec, task, spec.fargs, proc_to_scope_lfu, write_num)
     end
 
     # Copy args from remote to local
@@ -180,7 +171,7 @@ struct TypedDataDepsTaskArgument{T,N}
 end
 map_or_ntuple(f, xs::Vector) = map(f, 1:length(xs))
 @inline map_or_ntuple(@specialize(f), xs::NTuple{N,T}) where {N,T} = ntuple(f, Val(N))
-function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int, proc_idx::Int) where typed
+function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_procs, all_scope, spec::DTaskSpec{typed}, task::DTask, fargs, proc_to_scope_lfu, write_num::Int) where typed
     @specialize spec fargs
 
     if typed
@@ -191,153 +182,17 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr
 
     task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope())
     scheduler = queue.scheduler
-    if scheduler == :naive
-        raw_args = map(arg->tochunk(value(arg)), spec.fargs)
-        our_proc = remotecall_fetch(1, all_procs, raw_args) do all_procs, raw_args
-            Sch.init_eager()
-            sch_state = Sch.EAGER_STATE[]
-
-            @lock sch_state.lock begin
-                # Calculate costs per processor and select the most optimal
-                # FIXME: This should consider any already-allocated slots,
-                # whether they are up-to-date, and if not, the cost of moving
-                # data to them
-                procs, costs = Sch.estimate_task_costs(sch_state, all_procs, nothing, raw_args)
-                return first(procs)
-            end
-        end
-    elseif scheduler == :smart
-        raw_args = map(filter(arg->haskey(state.data_locality, value(arg)), spec.fargs)) do arg
-            arg_chunk = tochunk(value(arg))
-            # Only the owned slot is valid
-            # FIXME: Track up-to-date copies and pass all of those
-            return arg_chunk => data_locality[arg]
-        end
-        f_chunk = tochunk(value(spec.fargs[1]))
-        our_proc, task_pressure = remotecall_fetch(1, all_procs, pressures, f_chunk, raw_args) do all_procs, pressures, f, chunks_locality
-            Sch.init_eager()
-            sch_state = Sch.EAGER_STATE[]
-
-            @lock sch_state.lock begin
-                tx_rate = sch_state.transfer_rate[]
-
-                costs = Dict{Processor,Float64}()
-                for proc in all_procs
-                    # Filter out chunks that are already local
-                    chunks_filt = Iterators.filter(((chunk, space)=chunk_locality)->!(proc in processors(space)), chunks_locality)
-
-                    # Estimate network transfer costs based on data size
-                    # N.B. `affinity(x)` really means "data size of `x`"
-                    # N.B. We treat same-worker transfers as having zero transfer cost
-                    tx_cost = Sch.impute_sum(affinity(chunk)[2] for chunk in chunks_filt)
-
-                    # Estimate total cost to move data and get task running after currently-scheduled tasks
-                    est_time_util = get(pressures, proc, UInt64(0))
-                    costs[proc] = est_time_util + (tx_cost/tx_rate)
-                end
-
-                # Look up estimated task cost
-                sig = Sch.signature(sch_state, f, map(first, chunks_locality))
-                task_pressure = get(sch_state.signature_time_cost, sig, 1000^3)
-
-                # Shuffle procs around, so equally-costly procs are equally considered
-                P = randperm(length(all_procs))
-                procs = getindex.(Ref(all_procs), P)
-
-                # Sort by lowest cost first
-                sort!(procs, by=p->costs[p])
-
-                best_proc = first(procs)
-                return best_proc, task_pressure
-            end
-        end
-        # FIXME: Pressure should be decreased by pressure of syncdeps on same processor
-        pressures[our_proc] = get(pressures, our_proc, UInt64(0)) + task_pressure
-    elseif scheduler == :ultra
-        args = Base.mapany(spec.fargs) do arg
-            pos, data = arg
-            data, _ = unwrap_inout(data)
-            if data isa DTask
-                data = fetch(data; move_value=false, unwrap=false)
-            end
-            return pos => tochunk(data)
-        end
-        f_chunk = tochunk(value(spec.fargs[1]))
-        task_time = remotecall_fetch(1, f_chunk, args) do f, args
-            Sch.init_eager()
-            sch_state = Sch.EAGER_STATE[]
-            return @lock sch_state.lock begin
-                sig = Sch.signature(sch_state, f, args)
-                return get(sch_state.signature_time_cost, sig, 1000^3)
-            end
-        end
-
-        # FIXME: Copy deps are computed eagerly
-        deps = @something(spec.options.syncdeps, Set{Any}())
-
-        # Find latest time-to-completion of all syncdeps
-        deps_completed = UInt64(0)
-        for dep in deps
-            haskey(sstate.task_completions, dep) || continue # copy deps aren't recorded
-            deps_completed = max(deps_completed, sstate.task_completions[dep])
-        end
-
-        # Find latest time-to-completion of each memory space
-        # FIXME: Figure out space completions based on optimal packing
-        spaces_completed = Dict{MemorySpace,UInt64}()
-        for space in exec_spaces
-            completed = UInt64(0)
-            for (task, other_space) in sstate.assignments
-                space == other_space || continue
-                completed = max(completed, sstate.task_completions[task])
-            end
-            spaces_completed[space] = completed
-        end
-
-        # Choose the earliest-available memory space and processor
-        # FIXME: Consider move time
-        move_time = UInt64(0)
-        local our_space_completed
-        while true
-            our_space_completed, our_space = findmin(spaces_completed)
-            our_space_procs = filter(proc->proc in all_procs, processors(our_space))
-            if isempty(our_space_procs)
-                delete!(spaces_completed, our_space)
-                continue
-            end
-            our_proc = rand(our_space_procs)
-            break
-        end
-
-        sstate.task_to_spec[task] = spec
-        sstate.assignments[task] = our_space
-        sstate.task_completions[task] = our_space_completed + move_time + task_time
-    elseif scheduler == :roundrobin
-        our_proc = all_procs[proc_idx]
-        if task_scope == scope
-            # all_procs is already limited to scope
-        else
-            if isa(constrain(task_scope, scope), InvalidScope)
-                throw(Sch.SchedulingException("Scopes are not compatible: $(scope), $(task_scope)"))
-            end
-            while !proc_in_scope(our_proc, task_scope)
-                proc_idx = mod1(proc_idx + 1, length(all_procs))
-                our_proc = all_procs[proc_idx]
-            end
-        end
-    else
-        error("Invalid scheduler: $sched")
-    end
+    our_proc = datadeps_schedule_task(scheduler, state, all_procs, all_scope, task_scope, spec, task)
     @assert our_proc in all_procs
     our_space = only(memory_spaces(our_proc))
 
     # Find the scope for this task (and its copies)
     task_scope = @something(spec.options.compute_scope, spec.options.scope, DefaultScope())
-    if task_scope == scope
+    if task_scope == all_scope
         # Optimize for the common case, cache the proc=>scope mapping
         our_scope = get!(proc_to_scope_lfu, our_proc) do
             our_procs = filter(proc->proc in all_procs, collect(processors(our_space)))
-            return constrain(UnionScope(map(ExactScope, our_procs)...), scope)
+            return constrain(UnionScope(map(ExactScope, our_procs)...), all_scope)
         end
     else
         # Use the provided scope and constrain it to the available processors
@@ -495,7 +350,6 @@ function distribute_task!(queue::DataDepsTaskQueue, state::DataDepsState, all_pr
     end
 
     write_num += 1
-    proc_idx = mod1(proc_idx + 1, length(all_procs))
 
-    return write_num, proc_idx
+    return write_num
 end
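
Following this change, call sites that previously passed a `Symbol` scheduler would now pass a scheduler object, and because `DATADEPS_SCHEDULER` is a `ScopedValue`, a default scheduler can presumably also be set for a whole dynamic scope rather than per call. The snippet below is a sketch under that assumption, relying on standard `ScopedValues` semantics (`Base.ScopedValues` on Julia 1.11+, or the ScopedValues.jl package on older versions); it is not taken from the commit itself.

```julia
using Dagger
using Base.ScopedValues: with  # assumption: standard ScopedValues `with` works on Dagger's constant

# Before this commit (Symbol-based):  Dagger.spawn_datadeps(f; scheduler=:roundrobin)
# After this commit (dispatch-based): Dagger.spawn_datadeps(f; scheduler=Dagger.RoundRobinScheduler())

# Set a default scheduler for every spawn_datadeps region in this dynamic scope,
# without passing `scheduler=` at each call site.
with(Dagger.DATADEPS_SCHEDULER => Dagger.RoundRobinScheduler()) do
    Dagger.spawn_datadeps() do
        # Dagger.@spawn my_task!(InOut(A))
    end
end
```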

0 commit comments
