Untested: cross-stream CPU events
The test suite hangs.
I have the impression this was already happening before this commit.
lukstafi committed Nov 20, 2024
1 parent aee62bc commit 72bf7ec
Showing 3 changed files with 47 additions and 17 deletions.
4 changes: 2 additions & 2 deletions arrayjit/lib/backend_impl.ml
@@ -216,8 +216,8 @@ module type No_buffer_retrieval_or_syncing = sig
[Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
end

(** A compilation-agnostic backend API -- {!Lowered_backend} instantiates it, but
    {!Lowered_no_device_backend} backends are also converted to its instantiations. *)
(** An intermediate stage for converting {!Lowered_no_device_backend} backends into
{!Lowered_backend}. *)
module type With_scheduler = sig
include Backend_device_common

1 change: 0 additions & 1 deletion arrayjit/lib/backend_intf.ml
@@ -283,7 +283,6 @@ module type With_buffer_retrieval_and_syncing = sig
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
(* FIXME: update the syncing comment. *)
end

module type Backend = sig
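The doc comment above prescribes a specific call order when streaming through the merge buffer. Below is a minimal sketch of that order, written as a functor over `Backend`; `schedule_consume` and `schedule_modify` are hypothetical stand-ins for the caller's own scheduling steps, while `will_wait_for`, `all_work`, and `get_ctx_stream` are the interface's functions:

module Sync_protocol (B : Backend) = struct
  (* [dst] consumes [tn] through the merge buffer; before [src] schedules
     anything that modifies [tn], it waits for all work recorded so far on
     [dst]'s stream, per the doc comment above. *)
  let consume_then_modify ~src ~dst ~schedule_consume ~schedule_modify =
    schedule_consume dst;
    B.will_wait_for src (B.all_work (B.get_ctx_stream dst));
    schedule_modify src
end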
59 changes: 45 additions & 14 deletions arrayjit/lib/schedulers.ml
@@ -45,7 +45,11 @@ module Multicore (Backend : For_add_scheduler) :
mut : (Mut.t[@sexp.opaque]);
host_wait_for_idle : (Stdlib.Condition.t[@sexp.opaque]);
dev_wait_for_work : (Stdlib.Condition.t[@sexp.opaque]);
clock_tick : (Stdlib.Condition.t[@sexp.opaque]);
mutable is_ready : bool;
mutable schedule_clock : int;
run_clock : Utils.atomic_int;
stream_id : int;
}
[@@deriving sexp_of]

@@ -54,7 +58,7 @@ module Multicore (Backend : For_add_scheduler) :
let sexp_of_domain (d : domain) = Sexp.Atom ("domain-" ^ Int.to_string (Domain.get_id d :> int))

type runner = { state : stream_state; domain : domain } [@@deriving sexp_of]
type event = Not_implemented_yet [@@deriving sexp_of]
type event = { stream_state : stream_state; target_clock : int } [@@deriving sexp_of]

let name = "multicore_" ^ Backend.name
end
@@ -63,15 +67,14 @@ module Multicore (Backend : For_add_scheduler) :
include Device (Device_types) (Alloc_buffer_ignore_stream (Device_types) (Backend))
open Device_config

(** TODO: Blocks till the event completes, if it's not done already. *)
let sync Not_implemented_yet = ()

(** TODO: Whether the event completed. *)
let is_done Not_implemented_yet = true

(** TODO: Schedules waiting for the given event on the context's stream. *)
let will_wait_for _ctx Not_implemented_yet = ()
let sync { stream_state; target_clock } =
Mut.lock stream_state.mut;
while Atomic.get stream_state.run_clock < target_clock do
Condition.wait stream_state.clock_tick stream_state.mut
done;
Mut.unlock stream_state.mut

let is_done { stream_state; target_clock } = Atomic.get stream_state.run_clock >= target_clock
let get_used_memory _device = get_used_memory ()
let is_dev_queue_empty state = Queue.size state.queue = 0
let is_idle stream = is_dev_queue_empty stream.runner.state && stream.runner.state.is_ready
@@ -89,10 +92,6 @@ module Multicore (Backend : For_add_scheduler) :
Mut.unlock d.mut;
Option.iter d.stream_error ~f:(fun e -> Exn.reraise e @@ get_name stream))

(** TODO: Returns the event indicating if any currently running or scheduled computations on the
stream have completed. *)
let all_work _stream = Not_implemented_yet

let%track3_l_sexp schedule_task stream task =
assert (Domain.is_main_domain ());
[%log_result "schedule_task", Task.describe task, get_name stream];
@@ -110,6 +109,34 @@ module Multicore (Backend : For_add_scheduler) :
let global_run_no = ref 0
let device : device = make_device CPU ~ordinal:0

let will_wait_for ctx ({ stream_state; target_clock } as event) =
let work () = sync event in
let task =
Task.Task
{
context_lifetime = ();
description =
[%string
"wait on %{get_name ctx.stream} till clock %{target_clock#Int} on \
0:%{stream_state.stream_id#Int}"];
work;
}
in
schedule_task ctx.stream task

let all_work stream =
assert (Domain.is_main_domain ());
let stream_state = stream.runner.state in
stream_state.schedule_clock <- stream_state.schedule_clock + 1;
let work () =
Atomic.incr stream_state.run_clock;
Mut.lock stream_state.mut;
Stdlib.Condition.broadcast stream_state.clock_tick;
Mut.unlock stream_state.mut
in
schedule_task stream @@ Task { context_lifetime = (); description = "clock tick"; work };
{ stream_state; target_clock = stream_state.schedule_clock }

let%track3_l_sexp spinup_stream ~stream_id : stream =
Int.incr global_run_no;
let state =
@@ -121,6 +148,10 @@ module Multicore (Backend : For_add_scheduler) :
is_ready = false;
host_wait_for_idle = Stdlib.Condition.create ();
dev_wait_for_work = Stdlib.Condition.create ();
clock_tick = Stdlib.Condition.create ();
schedule_clock = 0;
run_clock = Atomic.make 0;
stream_id;
}
in
let%track3_l_sexp worker (() : unit) : unit =
@@ -154,7 +185,7 @@ module Multicore (Backend : For_add_scheduler) :
let num_devices () = 1
let suggested_num_streams _device = Domain.recommended_domain_count () - 1

let cleanup_stream stream =
let cleanup_stream (stream : stream) =
assert (Domain.is_main_domain ());
await stream;
let r = stream.runner in
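The event mechanism added to schedulers.ml above rests on two counters per stream: the main domain bumps `schedule_clock` when it records an event (`all_work`), and the worker domain bumps `run_clock` when it reaches the corresponding "clock tick" task. Below is a distilled, self-contained model of that pattern using only the OCaml standard library (`Mutex`, `Condition`, `Atomic`); the names mirror the diff, but this is a sketch of the idea, not the scheduler itself:

type clocked_stream = {
  mut : Mutex.t;
  clock_tick : Condition.t;
  mutable schedule_clock : int; (* advanced by the main domain *)
  run_clock : int Atomic.t; (* advanced by the worker domain *)
}

type event = { state : clocked_stream; target_clock : int }

(* Main domain: record an event; the caller must also enqueue a task that
   runs [tick state] on the worker, as [all_work] does in the diff. *)
let record state =
  state.schedule_clock <- state.schedule_clock + 1;
  { state; target_clock = state.schedule_clock }

(* Worker domain: the body of the "clock tick" task. *)
let tick state =
  Atomic.incr state.run_clock;
  Mutex.lock state.mut;
  Condition.broadcast state.clock_tick;
  Mutex.unlock state.mut

(* Block until the worker's clock reaches the event's target. *)
let sync { state; target_clock } =
  Mutex.lock state.mut;
  while Atomic.get state.run_clock < target_clock do
    Condition.wait state.clock_tick state.mut
  done;
  Mutex.unlock state.mut

let is_done { state; target_clock } = Atomic.get state.run_clock >= target_clock

Cross-stream waiting (`will_wait_for` in the diff) then reduces to scheduling a task on the waiting stream whose work is simply `sync event`.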
