Proper syncing for from_host, to_host and device_to_device
lukstafi committed Nov 18, 2024
1 parent b116838 commit 164e9ef
Showing 5 changed files with 63 additions and 41 deletions.
1 change: 1 addition & 0 deletions arrayjit/lib/backend_impl.ml
@@ -121,6 +121,7 @@ struct
stream_id;
allocated_buffer = None;
updating_for = Hashtbl.create (module Tnode);
updating_for_merge_buffer = None;
reader_streams = Hashtbl.create (module Tnode);
}

5 changes: 4 additions & 1 deletion arrayjit/lib/backend_intf.ml
Expand Up @@ -57,6 +57,7 @@ type 'context routine = {
inputs : Set.M(Tnode).t;
(** The materialized read-only and read-before-write (within the routine) non-constant nodes.
They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
merge_buffer_input : bool; (** Similar to {!field-inputs}, for the merge buffer. *)
outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
}
[@@deriving sexp_of]
@@ -123,6 +124,8 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream = {
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
this stream's context, and the associated use completion events. The completed events are
@@ -274,7 +277,7 @@ module type With_buffer_retrieval_and_syncing = sig
- If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for
streaming.
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
stream, and registers [dst.stream] with a reader event for the node.
stream, and updates the writer event for the merge buffer.
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
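
A hedged usage sketch of the [Copy] path described above: it assumes a module [B] satisfying this interface, contexts [dst] and [src] that both materialize [tn], and simplified constructor paths. The helper [merge_from] is hypothetical, not part of the interface.

let merge_from ~(dst : B.context) ~(src : B.context) tn =
  (* Schedules the copy on [dst]'s stream; per the semantics above, this also
     updates the merge buffer's writer event, so work scheduled later on that
     stream is ordered after the transfer. *)
  let copied = B.device_to_device tn ~into_merge_buffer:Copy ~dst ~src in
  if not copied then invalid_arg ("merge_from: no " ^ Tn.debug_name tn ^ " in src")
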
61 changes: 38 additions & 23 deletions arrayjit/lib/backends.ml
@@ -30,32 +30,40 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing)
|> List.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)

let update_writer_event ?(from_host = false) s tn =
let e = Backend.all_work s in
if from_host then
Hashtbl.update s.device.host_writing_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
(* To be on the safe side, record events for potentially cross-stream nodes. *)
if Tn.potentially_cross_stream tn then
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)

let add_reader s tn from =
let wait_for_ready ~dst ~src tn =
let s = src.stream in
let d = dst.stream in
Hashtbl.find s.updating_for tn
|> Option.iter ~f:(fun upd_e ->
if not (equal_stream s d || Backend.is_done upd_e) then Backend.will_wait_for dst upd_e)

let update_writer_event ?from s tn =
let e = Backend.all_work s in
let f l = (s, e) :: Option.value ~default:[] l in
match from with
| `Host -> Hashtbl.update s.device.host_reading_streams tn ~f
| `Src src -> Hashtbl.update src.reader_streams tn ~f
(match (from, tn) with
| None, _ -> ()
| Some `Host, Assignments.(Node tn | Merge_buffer tn) ->
Hashtbl.update s.device.host_reading_streams tn ~f
| Some (`Src src), (Node tn | Merge_buffer tn) -> Hashtbl.update src.reader_streams tn ~f);
(* To be on the safe side, record events for potentially cross-stream nodes. *)
match tn with
| Node tn ->
if Tn.potentially_cross_stream tn then
Hashtbl.update s.device.shared_writer_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
| Merge_buffer tn ->
Option.iter s.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) ->
assert (Backend.is_done old_e));
s.updating_for_merge_buffer <- Some (tn, e)
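
The [Merge_buffer] branch above enforces a single-writer invariant: a stream tracks at most one pending merge-buffer write, and recording a new one asserts that the previous transfer's event has completed. A standalone sketch of just that invariant, with a hypothetical mutable-boolean event standing in for the backend's ['event]:

open Base

(* Hypothetical stand-ins for the backend's event and per-stream state. *)
type event = { mutable is_done : bool }
type 'tn merge_state = { mutable updating_for_merge_buffer : ('tn * event) option }

let record_merge_write st tn e =
  (* A node may claim the merge buffer only once the previous transfer into it
     has completed. *)
  Option.iter st.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) ->
      assert old_e.is_done);
  st.updating_for_merge_buffer <- Some (tn, e)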

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
wait_for_all ctx ctx.stream.reader_streams tn;
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
update_writer_event ~from_host:true ctx.stream tn;
add_reader ctx.stream tn @@ `Host;
update_writer_event ~from:`Host ctx.stream @@ Node tn;
true
| _ -> false

@@ -66,6 +74,10 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing)
wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
Backend.to_host ~src_ptr:src ~src:ctx hosted;
let s = ctx.stream in
let e = Backend.all_work s in
Hashtbl.update s.device.host_writing_streams tn ~f:(fun l ->
(s, e) :: Option.value ~default:[] l);
true
| _ -> false

@@ -80,6 +92,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing)
match Map.find src.ctx_arrays tn with
| None -> false
| Some s_arr -> (
wait_for_ready ~dst ~src tn;
match into_merge_buffer with
| No -> (
match Map.find dst.ctx_arrays tn with
@@ -88,8 +101,9 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing)
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:(Some d_arr) ~dst ~src_ptr:s_arr
~src);
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
[%log
"copied",
"copying",
Tn.debug_name tn,
"from",
name_of src,
@@ -106,7 +120,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing)
| Copy | Streaming ->
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
[%log "copied into merge buffer", Tn.debug_name tn, "from", name_of src];
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
true)
end

@@ -371,7 +386,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
let%debug3_sexp link context (code : code) =
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
~from_prior_context:code.from_prior_context;
let inputs, outputs = Low_level.input_and_output_nodes code.lowered in
let (inputs, outputs), merge_buffer_input = Low_level.input_and_output_nodes code.lowered in
let ctx_arrays =
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays
~f:(alloc_if_needed context.stream)
@@ -382,7 +397,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
Task.prepend schedule ~work:(fun () ->
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
in
{ context; schedule; bindings; name = code.name; inputs; outputs }
{ context; schedule; bindings; name = code.name; inputs; merge_buffer_input; outputs }

let%debug3_sexp link_batch context code_batch =
verify_prior_context ~use_host_memory ~ctx_arrays:context.ctx_arrays
@@ -401,14 +416,14 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
let context = make_child ~ctx_arrays context in
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
let inputs, outputs =
let (inputs, outputs), merge_buffer_input =
Low_level.input_and_output_nodes @@ Option.value_exn code_batch.lowereds.(i)
in
let schedule =
Task.prepend schedule ~work:(fun () ->
check_merge_buffer context.stream ~code_node:expected_merge_node)
in
(context, Some { context; schedule; bindings; name; inputs; outputs }))
(context, Some { context; schedule; bindings; name; inputs; merge_buffer_input; outputs }))
end

module Cuda_backend : Backend = Raise_backend ((Cuda_backend : Lowered_backend))
35 changes: 19 additions & 16 deletions arrayjit/lib/low_level.ml
@@ -738,22 +738,25 @@ type optimized = { traced_store : traced_store; llc : t; merge_node : Tn.t option }
[@@deriving sexp_of]

let input_and_output_nodes optimized =
Hashtbl.fold optimized.traced_store
~init:(Set.empty (module Tn), Set.empty (module Tn))
~f:(fun ~key ~data (inputs, outputs) ->
let materialized = Tn.is_materialized_force key 50 in
let inputs =
if
materialized && (not (Tn.known_constant key)) && (data.read_only || data.read_before_write)
then Set.add inputs key
else inputs
in
let outputs =
if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then
Set.add outputs key
else outputs
in
(inputs, outputs))
( Hashtbl.fold optimized.traced_store
~init:(Set.empty (module Tn), Set.empty (module Tn))
~f:(fun ~key ~data (inputs, outputs) ->
let materialized = Tn.is_materialized_force key 50 in
let inputs =
if
materialized
&& (not (Tn.known_constant key))
&& (data.read_only || data.read_before_write)
then Set.add inputs key
else inputs
in
let outputs =
if materialized && (data.zeroed_out || not (Hash_set.is_empty data.assignments)) then
Set.add outputs key
else outputs
in
(inputs, outputs)),
Option.is_some optimized.merge_node )

let%diagn2_sexp optimize_proc static_indices llc =
let traced_store = Hashtbl.create (module Tnode) in
2 changes: 1 addition & 1 deletion arrayjit/lib/low_level.mli
@@ -99,7 +99,7 @@ val optimize :
t ->
optimized

val input_and_output_nodes : optimized -> Set.M(Tnode).t * Set.M(Tnode).t
val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * bool
(** Inputs are the materialized read-only and read-before-write (within the code) non-constant
nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.
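
For reference, a minimal sketch of consuming the widened return type, mirroring [link] in backends.ml above; the helper [report] is hypothetical:

let report (lowered : Low_level.optimized) =
  let (inputs, outputs), merge_buffer_input =
    Low_level.input_and_output_nodes lowered
  in
  Stdlib.Printf.printf "inputs=%d outputs=%d merge_buffer_input=%b\n"
    (Base.Set.length inputs) (Base.Set.length outputs) merge_buffer_input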
