(1) Get rid of the option to share merge buffers, (2) refactor tracking merge buffer events

Formerly, `~into_merge_buffer:Streaming` would not generate an event, but it should, to prevent overriding the source. (2) will be continued: prohibiting overriding until the routine using the streamed merge buffer finishes.
lukstafi committed Nov 30, 2024
1 parent b9987fa commit d16b54f
Showing 8 changed files with 65 additions and 73 deletions.
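As a minimal sketch of the refactored event tracking (simplified stand-in types, not the actual interfaces changed below): both `Copy` and `Streaming` transfers into a merge buffer now record a readiness event on the destination stream.

```ocaml
(* Simplified stand-ins: [event] and [stream] are not the real backend types. *)
type merge_buffer_use = No | Streaming | Copy
type event = unit
type stream = { mutable updating_for_merge_buffer : (string * event) option }

(* Formerly the [Streaming] case skipped this step; now both merge-buffer cases
   record the event, so the source node is tracked as being read and is not
   overridden prematurely. *)
let note_merge_buffer_transfer (dst : stream) ~into_merge_buffer ~tn ~event =
  match into_merge_buffer with
  | No -> () (* plain destination-array copies are tracked per node elsewhere *)
  | Copy | Streaming -> dst.updating_for_merge_buffer <- Some (tn, event)
```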
7 changes: 3 additions & 4 deletions arrayjit/lib/anatomy_of_a_backend.md
@@ -144,9 +144,6 @@ When using the default stream, CUDA would predictably write to the standard outp
OCANNL expects backends to implement FIFO queue scheduling, and an event mechanism for synchronizing between streams (and ideally devices), matching the CUDA specification. On top of events, OCANNL implements per-tensor-node synchronization. 1/3rd of the `device` fields have to do with synchronization:

```ocaml
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer,
and its readiness event. *)
shared_writer_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams that most recently have been scheduled to update (write to) a
@@ -162,7 +159,7 @@ OCANNL expects backends to implement FIFO queue scheduling, and an event mechani
events are removed opportunistically. *)
```

and 1/3rd of the stream fields also:
and some stream fields also:

```ocaml
updating_for : 'event Hashtbl.M(Tnode).t;
@@ -175,6 +172,8 @@ and 1/3rd of the stream fields also:
removed opportunistically. *)
```

While we never share merge buffers across streams, there is always an event associated with an occupied merge buffer. Its primary use is for tracking the merge buffer's stream as a reader on the source stream.

Besides routines, calling `from_host`, `to_host`, `device_to_device` from a backend puts the corresponding tasks on the device's queue. Both invoking a routine and calling these copying functions will perform the necessary event creations and synchronizations to ensure that when scheduling writing into an array precedes scheduling reading from it, the actual writing also precedes the actual reading.
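A rough sketch of that ordering guarantee — not OCANNL's actual API; the primitives below are hypothetical parameters standing in for the backend's event operations:

```ocaml
(* Only the discipline matters: a scheduled write records an event, and a later
   scheduled read waits on that event before enqueuing the read. *)
let schedule_write ~enqueue_write ~record_event stream node =
  enqueue_write stream node;
  record_event stream node

let schedule_read ~lookup_write_event ~wait_for ~enqueue_read stream node =
  (match lookup_write_event node with
   | Some e -> wait_for stream e
   | None -> ());
  enqueue_read stream node
```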

### Data transfers
9 changes: 3 additions & 6 deletions arrayjit/lib/backend_impl.ml
@@ -101,8 +101,6 @@ struct
{
dev;
ordinal;
shared_merge_buffer = None;
scheduled_shared_merge_node = None;
latest_stream_id = -1;
released = Atomic.make false;
cross_stream_candidates = Hashtbl.create (module Tnode);
@@ -117,7 +115,6 @@ struct
device;
runner;
merge_buffer = ref None;
scheduled_merge_node = None;
stream_id;
allocated_buffer = None;
updating_for = Hashtbl.create (module Tnode);
@@ -202,7 +199,7 @@ module type No_buffer_retrieval_or_syncing = sig
(** Like {!Backend.from_host}, but without synchronization and buffer retrieval. *)

val to_host : src_ptr:buffer_ptr -> src:context -> Ndarray.t -> unit
(** Like {!Backend.to_host}, but without synchronization and buffer retrieval. *)
(** Like {!Backend.to_host}, but without synchronization events and buffer retrieval. *)

val device_to_device :
Tnode.t ->
@@ -212,8 +209,8 @@ module type No_buffer_retrieval_or_syncing = sig
src_ptr:buffer_ptr ->
src:context ->
unit
(** Like {!Backend.device_to_device}, but without synchronization and buffer retrieval. Raises
[Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
(** Like {!Backend.device_to_device}, but without synchronization events and buffer retrieval.
Raises [Invalid_argument] if [into_merge_buffer = No] and [dst_ptr = None]. *)
end

(** An intermediate stage for converting {!Lowered_no_device_backend} backends into
20 changes: 6 additions & 14 deletions arrayjit/lib/backend_intf.ml
@@ -57,7 +57,7 @@ type 'context routine = {
inputs : Set.M(Tnode).t;
(** The materialized read-only and read-before-write (within the routine) non-constant nodes.
They are inputs in a broad sense, as they could be recurrent nodes or parameters. *)
merge_buffer_input : bool; (** Similar to {!field-inputs}, for the merge buffer. *)
merge_buffer_input : Tnode.t option; (** Similar to {!field-inputs}, for the merge buffer. *)
outputs : Set.M(Tnode).t; (** All the materialized nodes written-to by the routine. *)
}
[@@deriving sexp_of]
@@ -82,8 +82,6 @@ end
type ('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable shared_merge_buffer : 'buffer_ptr buffer option;
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -100,7 +98,6 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
device : ('buffer_ptr, 'dev, 'runner, 'event) device_ref;
runner : 'runner;
merge_buffer : 'buffer_ptr buffer option ref;
mutable scheduled_merge_node : Tnode.t option;
stream_id : int;
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
@@ -117,11 +114,6 @@ type ('buffer_ptr, 'dev, 'runner, 'event) device =
('buffer_ptr, 'dev, 'runner, 'event) device_ref = {
dev : 'dev;
ordinal : int;
mutable shared_merge_buffer : 'buffer_ptr buffer option;
(** Depending on backend implementations, either the currently used cross-stream merge buffer,
or the one most recently scheduled. *)
mutable scheduled_shared_merge_node : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the cross-stream merge buffer. *)
mutable latest_stream_id : int;
released : Utils.atomic_bool;
cross_stream_candidates : 'buffer_ptr Hashtbl.M(Tnode).t;
@@ -153,15 +145,15 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
runner : 'runner;
merge_buffer : 'buffer_ptr buffer option ref;
(** Depending on backend implementations, either the currently used merge buffer, or the one
most recently scheduled. *)
mutable scheduled_merge_node : Tnode.t option;
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. *)
most recently scheduled. Note that the pointer can be reused for nodes that fit in an
already allocated buffer. *)
stream_id : int; (** An ID unique within the device. *)
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
(* The completion event for the most recent updating (writing to) a node via this stream. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer and
its updating completion event. See also {!field-updating_for}. *)
reader_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
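A sketch of how the `updating_for_merge_buffer` field can be consulted before reusing a stream's merge buffer. The record and helper below are simplified stand-ins, and waiting before overriding is the follow-up work mentioned in the commit message rather than something this commit already enforces:

```ocaml
(* Simplified stand-in for the stream record's merge-buffer synchronization state. *)
type ('tn, 'event) merge_sync = {
  mutable updating_for_merge_buffer : ('tn * 'event) option;
}

(* Before scheduling a new occupant for the merge buffer, wait on the current
   occupant's event so the buffer is not overridden while still in use. *)
let guard_merge_buffer ~is_done ~will_wait_for (s : (_, _) merge_sync) =
  match s.updating_for_merge_buffer with
  | None -> ()
  | Some (_occupant, e) -> if not (is_done e) then will_wait_for e
```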
47 changes: 23 additions & 24 deletions arrayjit/lib/backends.ml
@@ -11,14 +11,15 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime

let check_merge_buffer stream ~code_node =
let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
match (stream.scheduled_merge_node, code_node) with
match (stream.updating_for_merge_buffer, code_node) with
| _, None -> ()
| Some actual, Some expected when Tnode.equal actual expected -> ()
| Some (actual, _), Some expected when Tnode.equal actual expected -> ()
| _ ->
raise
@@ Utils.User_error
("Merge buffer mismatch, on stream: " ^ name stream.scheduled_merge_node
^ ", expected by code: " ^ name code_node)
("Merge buffer mismatch, on stream: "
^ name (Option.map ~f:fst stream.updating_for_merge_buffer)
^ ", expected by code: " ^ name code_node)

module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
let wait_for_all ctx streams tn =
@@ -54,8 +55,7 @@ module Add_buffer_retrieval_and_syncin
(s, e) :: Option.value ~default:[] l);
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
| Merge_buffer tn ->
Option.iter s.updating_for_merge_buffer ~f:(fun (_old_tn, old_e) ->
assert (Backend.is_done old_e));
(* Note: the previous event does not need to be done! *)
s.updating_for_merge_buffer <- Some (tn, e)

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
@@ -105,31 +105,32 @@ module Add_buffer_retrieval_and_syncin
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
[%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst];
true)
| Streaming when same_device ->
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
[%log "using merge buffer for", Tn.debug_name tn, "from", name_of src];
true
| Copy | Streaming ->
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "copying into merge buffer", Tn.debug_name tn, "from", name_of src];
let use =
match into_merge_buffer with
| Copy -> "copying"
| Streaming -> "streaming"
| No -> assert false
in
[%log use, "into merge buffer", Tn.debug_name tn, "from", name_of src];
true)

let%track3_l_sexp sync_routine r =
let s = r.context.stream in
let pre () =
Hashtbl.iteri s.device.shared_writer_streams ~f:(fun ~key ~data ->
if Set.mem r.inputs key then
List.iter data ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e));
if r.merge_buffer_input then
Option.iter s.device.scheduled_shared_merge_node ~f:(fun (shared_tn, e) ->
match (s.scheduled_merge_node, e) with
| Some merge_tn, Some e ->
if Tn.equal shared_tn merge_tn then Backend.will_wait_for r.context e
| _ -> ())
Hashtbl.filter_mapi_inplace s.device.shared_writer_streams ~f:(fun ~key ~data ->
if Tn.potentially_cross_stream key then
if Set.mem r.inputs key then (
let data = List.filter data ~f:(fun (_, e) -> Backend.is_done e) in
List.iter data ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for r.context e);
Some data)
else Some data
else None)
(* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
in
let post () =
let e = Backend.all_work s in
@@ -281,7 +282,6 @@ module Add_device
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let dev = dst.stream in
let size_in_bytes = Tnode.size_in_bytes tn in
(* FIXME(#290): handle shared_merge_node. *)
let work =
(* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
match (into_merge_buffer, dst_ptr) with
@@ -305,7 +305,6 @@ module Add_device
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
^ get_name src.stream
in
(match into_merge_buffer with No -> () | _ -> dev.scheduled_merge_node <- Some tn);
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
end

Expand Down
23 changes: 8 additions & 15 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -101,22 +101,20 @@ let get_used_memory (device : device) =
let free, total = Cudajit.Device.get_free_and_total_mem () in
total - free

let opt_alloc_merge_buffer ~size_in_bytes device : unit =
let opt_alloc_merge_buffer ~size_in_bytes dev stream : unit =
if
Option.value_map ~default:true device.shared_merge_buffer ~f:(fun buffer ->
Option.value_map ~default:true !(stream.merge_buffer) ~f:(fun buffer ->
buffer.size_in_bytes < size_in_bytes)
then (
set_ctx device.dev.primary_context;
Option.iter device.shared_merge_buffer ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr);
device.shared_merge_buffer <-
Some { ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes; size_in_bytes })
set_ctx dev.primary_context;
Option.iter !(stream.merge_buffer) ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr);
stream.merge_buffer := Some { ptr = Cu.Deviceptr.mem_alloc ~size_in_bytes; size_in_bytes })

let%track3_sexp cleanup_device (device : device) =
Cu.Context.set_current device.dev.primary_context;
Cu.Context.synchronize ();
Option.iter !Utils.advance_captured_logs ~f:(fun callback -> callback ());
(* Note: this is not necessary as releasing the primary context by GC will reset the context. *)
Option.iter ~f:(fun buffer -> Cu.Deviceptr.mem_free buffer.ptr) device.shared_merge_buffer;
Hashtbl.iter device.cross_stream_candidates ~f:(fun buffer_ptr ->
Cu.Deviceptr.mem_free buffer_ptr)

@@ -138,8 +136,6 @@ let%track3_sexp get_device ~(ordinal : int) : device =
if Utils.debug_log_from_routines () && not (Hash_set.mem initialized_devices ordinal) then
Option.iter Utils.settings.cuda_printf_fifo_size ~f:Cu.Context.(set_limit PRINTF_FIFO_SIZE);
Hash_set.add initialized_devices ordinal;
(* let size_in_bytes = 8 in let shared_merge_buffer = { ptr = Cu.Deviceptr.mem_alloc
~size_in_bytes; size_in_bytes } in *)
let result = make_device dev ~ordinal in
Stdlib.Gc.finalise finalize_device result;
Stdlib.Weak.set !devices ordinal (Some result);
@@ -209,16 +205,13 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
memcpy ~dst_ptr
| Streaming, _ ->
assert same_device;
dst.stream.scheduled_merge_node <- Some tn;
dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Copy, _ ->
set_ctx @@ ctx_of dst;
let size_in_bytes = Tn.size_in_bytes tn in
opt_alloc_merge_buffer ~size_in_bytes dev;
(* FIXME(#290): why use the shared buffer? This should depend on the memory mode! *)
Option.iter dev.shared_merge_buffer ~f:(fun buffer -> memcpy ~dst_ptr:buffer.ptr);
dst.stream.scheduled_merge_node <- Some tn;
dst.stream.merge_buffer := dev.shared_merge_buffer
opt_alloc_merge_buffer ~size_in_bytes dev.dev dst.stream;
let buffer = Option.value_exn ~here:[%here] !(dst.stream.merge_buffer) in
memcpy ~dst_ptr:buffer.ptr

type code = {
traced_store : Low_level.traced_store;
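With the shared merge buffer gone, the merge buffer is allocated per stream. A sketch of the grow-or-reuse policy behind `opt_alloc_merge_buffer`, with hypothetical `alloc`/`free` parameters standing in for the cudajit calls (`Cu.Deviceptr.mem_alloc`/`mem_free` under the device's primary context):

```ocaml
(* [ptr] stands in for a device pointer. *)
type 'ptr buffer = { ptr : 'ptr; size_in_bytes : int }

let opt_alloc ~alloc ~free ~size_in_bytes merge_buffer =
  let needs_realloc =
    match !merge_buffer with
    | None -> true
    | Some b -> b.size_in_bytes < size_in_bytes
  in
  if needs_realloc then begin
    (* Free the previous, too-small buffer (if any), then allocate a larger one. *)
    Option.iter (fun b -> free b.ptr) !merge_buffer;
    merge_buffer := Some { ptr = alloc size_in_bytes; size_in_bytes }
  end
```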
12 changes: 10 additions & 2 deletions arrayjit/lib/low_level.ml
@@ -199,7 +199,15 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
~f:(visit ~is_assigned:(traced.zeroed_out || Hash_set.mem traced.assignments at_pos))
| Local_scope { body; _ } -> loop_proc env body
| Get_local _ -> ()
| Get_global (Ops.Merge_buffer { source_node_id }, _) -> merge_node_id := Some source_node_id
| Get_global (Ops.Merge_buffer { source_node_id }, _) ->
Option.iter !merge_node_id ~f:(fun merge_node_id ->
if merge_node_id <> source_node_id then
raise
@@ Utils.User_error
[%string
"Low_evel.optimize_proc: currently only one merge buffer per routine is \
allowed, found node ids %{source_node_id#Int} and %{merge_node_id#Int}"]);
merge_node_id := Some source_node_id
| Get_global _ -> ()
| Embed_index _ -> ()
| Binop (Arg1, llv1, _llv2) -> loop llv1
@@ -752,7 +760,7 @@ let input_and_output_nodes optimized =
else outputs
in
(inputs, outputs)),
Option.is_some optimized.merge_node )
optimized.merge_node )

let%diagn2_sexp optimize_proc static_indices llc =
let traced_store = Hashtbl.create (module Tnode) in
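The `visit_llc` change above enforces that a routine reads from at most one merge-buffer node. A compact sketch of that check, with a plain exception standing in for `Utils.User_error`:

```ocaml
(* [merge_node_id] accumulates the single allowed merge-buffer source node. *)
let note_merge_node (merge_node_id : int option ref) ~source_node_id =
  (match !merge_node_id with
   | Some existing when existing <> source_node_id ->
       failwith
         (Printf.sprintf
            "only one merge buffer per routine is allowed, found node ids %d and %d"
            source_node_id existing)
   | _ -> ());
  merge_node_id := Some source_node_id
```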
8 changes: 4 additions & 4 deletions arrayjit/lib/low_level.mli
@@ -99,11 +99,11 @@ val optimize :
t ->
optimized

val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * bool
val input_and_output_nodes : optimized -> (Set.M(Tnode).t * Set.M(Tnode).t) * Tnode.t option
(** Inputs are the materialized read-only and read-before-write (within the code) non-constant
nodes. They are inputs in a broad sense, as they could be recurrent nodes or parameters.
Outputs are all the materialized nodes written-to by the code. *)
non-merge nodes. They are inputs in a broad sense, as they could be recurrent nodes or
parameters. Outputs are all the materialized nodes written-to by the code. The last returned
component is the input merge node, if used in the code. *)

(** {2 Printing} *)

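A hypothetical caller illustrating the new result shape, where the last component is the merge node itself (if any) rather than a bool; it assumes the project's Base-style `Set`, `Tnode`, and `Low_level` modules are in scope and is not part of the commit:

```ocaml
let summarize optimized =
  let (inputs, outputs), merge_node = Low_level.input_and_output_nodes optimized in
  Stdlib.Printf.printf "inputs: %d, outputs: %d, merge node: %s\n"
    (Set.length inputs) (Set.length outputs)
    (Option.value_map merge_node ~default:"none" ~f:Tnode.debug_name)
```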
12 changes: 8 additions & 4 deletions lib/train.ml
@@ -512,6 +512,8 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
assert (Backend.to_host sgd_update.context learning_rate.value);
(* scalar_loss is not in the sgd_update context. *)
assert (Backend.to_host grad_updates.(0).context scalar_loss.value);
(* TODO: syncing callbacks should be integrated into Tensor. *)
Backend.await grad_updates.(0).context.stream;
let batch_loss = scalar_loss.@[0] in
epoch_loss := !epoch_loss +. batch_loss;
batch_losses := batch_loss :: !batch_losses;
@@ -530,14 +532,16 @@
Option.iter per_epoch_callback ~f:(fun f ->
f ~at_step:!step_ref ~at_epoch:epoch ~learning_rate:learning_rate.@[0]
~epoch_loss:!epoch_loss);
let debug_at pos =
let _debug_at pos =
Array.iter streams ~f:(fun s ->
Stdlib.Format.printf "Stream %d debug %s:@ %a\n%!" s.stream_id pos Sexp.pp_hum
@@ Backend.get_debug_info s)
in
if per_epoch_debug_streams then debug_at "before sync";
Array.iter streams ~f:Backend.await;
if per_epoch_debug_streams then debug_at "after sync"
if per_epoch_debug_streams then _debug_at "before sync";
(* TODO: there should be nothing pending left to sync. *)
Array.iter streams ~f:Backend.await
(* This is now cleaned up by await. *)
(* if per_epoch_debug_streams then _debug_at "after sync" *)
done;
let%op model_result = model "infer" in
let infer_fwd =
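The added `Backend.await` reflects a general pattern: schedule the device-to-host copy, then synchronize the stream before touching the host array (the TODO suggests this should eventually live inside `Tensor`). A minimal sketch with the backend calls abstracted as hypothetical parameters:

```ocaml
(* [to_host], [await] and [read_host] stand in for the backend and tensor calls
   used in example_train_loop; only the ordering matters. *)
let read_after_sync ~to_host ~await ~read_host ~context ~stream node =
  assert (to_host context node);
  (* Without this await, reading the host array could race with the queued copy. *)
  await stream;
  read_host node
```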
