Skip to content

Commit

Permalink
Untested: synchronization for routines
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Nov 18, 2024
1 parent 164e9ef commit aee62bc
Show file tree
Hide file tree
Showing 4 changed files with 72 additions and 25 deletions.
37 changes: 28 additions & 9 deletions arrayjit/lib/anatomy_of_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -141,19 +141,38 @@ When using the default stream, CUDA would predictably write to the standard outp

## Synchronization and data transfers

OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization, using the fields `reader_streams` and `writer_streams` of the device record, and `updating_for` of the stream record.
OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization. About a third of the `device` fields are dedicated to synchronization:

```ocaml
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
and its readiness event. *)
shared_writer_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) a
cross-stream-shared node, and the associated update completion event. The completed events
are removed opportunistically. *)
host_reading_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from a node's on-host array. The
completed events are removed opportunistically. *)
host_writing_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been writing to a node's on-host array. The completed
events are removed opportunistically. *)
```

and so are about a third of the `stream` fields:

```ocaml
...
writer_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) the node, and the
associated update completion event. The completed events are removed opportunistically. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been reading from the node, and the associated use
completion events. The completed events are removed opportunistically. *)
...
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
this stream's context, and the associated use completion events. The completed events are
removed opportunistically. *)
```

Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Both invoking a routine and calling these copying functions will perform the necessary event creations and synchronizations to ensure that when scheduling writing into an array precedes scheduling reading from it, the actual writing also precedes the actual reading.
Expand Down
7 changes: 4 additions & 3 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,9 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device = {
mutable shared_merge_buffer : 'buffer_ptr buffer option;
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
or the one most recently scheduled. *)
mutable scheduled_shared_merge_node : Tnode.t option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
and its readiness event. *)
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
Expand Down Expand Up @@ -123,7 +124,7 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
(* The completion event for updating (writing to) a node via this stream, if any. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
Expand Down
43 changes: 30 additions & 13 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -33,12 +33,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
(* [wait_for_ready ~dst ~src tn] makes the destination stream wait for the most
   recently scheduled update of node [tn] on the source stream, if any.
   It is a no-op when the two streams are the same (FIFO scheduling within a
   stream already orders the work — see the backend contract), or when the
   update event has already completed. *)
let wait_for_ready ~dst ~src tn =
let s = src.stream in
let d = dst.stream in
(* TODO: maybe it's worthwhile to clean up s.updating_for every now and then. *)
(* Look up the pending write-completion event for [tn] on the source stream;
   only synchronize when one exists, the streams differ, and it is not done. *)
Hashtbl.find s.updating_for tn
|> Option.iter ~f:(fun upd_e ->
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)

let update_writer_event ?from s tn =
let e = Backend.all_work s in
let update_writer_event ?e ?from s tn =
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
let f l = (s, e) :: Option.value ~default:[] l in
(match (from, tn) with
| None, _ -> ()
Expand Down Expand Up @@ -102,15 +103,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr
~src);
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
[%log
"copying",
Tn.debug_name tn,
"from",
name_of src,
"at",
(s_arr : Backend.buffer_ptr),
"to",
(d_arr : Backend.buffer_ptr)];
[%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst];
true)
| Streaming when same_device ->
Backend.(
Expand All @@ -123,6 +116,26 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
true)

(* [sync_routine r] wraps routine [r]'s schedule with synchronization work:
   before the routine runs, its stream waits on events for any cross-stream
   inputs; after it runs, write-completion events are recorded for its outputs.
   The routine itself (bindings, context, name, inputs/outputs) is unchanged. *)
let%track3_l_sexp sync_routine r =
let s = r.context.stream in
let pre () =
(* Wait on the completion events of other streams that were scheduled to
   write to any cross-stream-shared node this routine reads. Events from
   this stream itself are skipped (FIFO order within a stream suffices). *)
Hashtbl.iteri s.device.shared_writer_streams ~f:(fun ~key ~data ->
if Set.mem r.inputs key then
List.iter data ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e));
(* If the routine consumes the merge buffer and the device-level shared
   merge node matches this stream's scheduled merge node, wait for the
   shared node's readiness event (when one was recorded). *)
if r.merge_buffer_input then
Option.iter s.device.scheduled_shared_merge_node ~f:(fun (shared_tn, e) ->
match (s.scheduled_merge_node, e) with
| Some merge_tn, Some e ->
if Tn.equal shared_tn merge_tn then Backend.will_wait_for r.context e
| _ -> ())
in
let post () =
(* Record a single completion event covering all work queued so far on this
   stream, and register it as the writer event for every output node. *)
let e = Backend.all_work s in
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
in
(* Run [pre] before and [post] after the original schedule. *)
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }
end

let lower_assignments ?name bindings asgns =
Expand Down Expand Up @@ -397,7 +410,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
Task.prepend schedule ~work:(fun () ->
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
in
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }
sync_routine
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }

let%debug3_sexp link_batch context code_batch =
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
Expand All @@ -423,7 +437,10 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
Task.prepend schedule ~work:(fun () ->
check_merge_buffer context.stream ~code_node:expected_merge_node)
in
(context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }))
let r =
sync_routine { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }
in
(context, Some r))
end

module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))
Expand Down
10 changes: 10 additions & 0 deletions arrayjit/lib/task.ml
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,16 @@ let prepend ~work (Task task) =
task.work ());
}

(* [append ~work t] returns a task that first performs [t]'s work and then
   performs [work], keeping all other task fields unchanged.
   Dual of {!prepend}, which runs the extra work before the task. *)
let append ~work (Task task) =
  let run_then_extra () =
    task.work ();
    work ()
  in
  Task { task with work = run_then_extra }

let%track3_l_sexp enschedule ~schedule_task ~get_stream_name stream
(Task { description; _ } as task) =
[%log_result "enschedule", description, "on", get_stream_name stream];
Expand Down

0 comments on commit aee62bc

Please sign in to comment.