Skip to content

Commit

Permalink
Automatically synchronize potential overwriting of an array that is s…
Browse files Browse the repository at this point in the history
…treamed into a merge buffer
  • Loading branch information
lukstafi committed Nov 30, 2024
1 parent d16b54f commit c2c4bd4
Show file tree
Hide file tree
Showing 6 changed files with 74 additions and 47 deletions.
8 changes: 5 additions & 3 deletions arrayjit/lib/anatomy_of_a_backend.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ and some stream fields also:
```ocaml
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for updating (writing to) a node via this stream, if any. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** Like {!field-updating_for}, but for the merge buffer. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. The
event finishes after the [task] from a [Streaming_for task]. See also
{!field-updating_for}. *)
reader_streams : (('buffer_ptr, 'dev, 'runner, 'event) stream * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
this stream's context, and the associated use completion events. The completed events are
Expand All @@ -182,6 +184,6 @@ OCANNL supports asynchronous data transfers by embedding them in the scheduling

OCANNL provides explicit _merge buffers_ for performing those tensor node updates, where different versions of a tensor node from two streams feature in the same computation. The `%cd` syntax for using merge buffers is via the `.merge` pseudo-field. For example, the code for merging gradients might be: `[%cd p.grad =+ p.grad.merge]`. In the current design, there's at most one merge buffer per stream, and the memory is reused for merging different nodes. We keep track of the specific tensor node that was scheduled to occupy this buffer in the stream, and the merge node expected by the linked code, so that we can detect mismatches at scheduling time.

The interface exposes two modes of utilizing merge buffers. The `Streaming` mode relies in some way on the array from the source context. Currently, this simply means using the source array (buffer) pointer, and the CUDA backend falls back to using `~into_merge_buffer:Copy` when the source and destination contexts live on different devices. The `Copy` mode uses physical arrays to back merge buffers. The merge buffer array (one per stream) is resized (grown) if needed to fit a node's array.
The interface exposes two modes of utilizing merge buffers. The `Streaming_for` mode relies in some way on the array from the source context. Currently, this simply means using the source array (buffer) pointer, and the CUDA backend falls back to using `~into_merge_buffer:Copy` when the source and destination contexts live on different devices. The `Copy` mode uses physical arrays to back merge buffers. The merge buffer array (one per stream) is resized (grown) if needed to fit a node's array. To block the source stream from overwriting the array, `Streaming_for` is parameterized by the task (actually, routine) intended to make use of the merge buffer.

Currently, OCANNL does not support merge buffers for `from_host` transfers, although it might in the future. At present, combining `to_host` and `from_host` is the only way to make different backends cooperate, and that requires `from_host ~into_merge_buffer` to adapt single-backend design patterns.
31 changes: 16 additions & 15 deletions arrayjit/lib/backend_intf.ml
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ end
type config = Only_devices_parallel | For_parallel_copying | Most_parallel_streams
[@@deriving equal, sexp, variants]

type merge_buffer_use = No | Streaming | Copy [@@deriving equal, sexp]
type merge_buffer_use = No | Streaming_for of Task.t | Copy [@@deriving sexp_of]

type param_source =
| Log_file_name
Expand Down Expand Up @@ -101,7 +101,7 @@ and ('buffer_ptr, 'dev, 'runner, 'event) stream_ref = {
stream_id : int;
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
mutable updating_for_merge_buffer : (Tnode.t * 'event option) option;
reader_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
}
Expand Down Expand Up @@ -151,9 +151,10 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream =
mutable allocated_buffer : 'buffer_ptr buffer option;
updating_for : 'event Hashtbl.M(Tnode).t;
(* The completion event for the most recent updating (writing to) a node via this stream. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event) option;
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer and
its updating completion event. See also {!field-updating_for}. *)
mutable updating_for_merge_buffer : (Tnode.t * 'event option) option;
(** The tensor node that was most recently scheduled to be in the [stream]'s merge buffer. The
event finishes after the [task] from a [Streaming_for task]. See also
{!field-updating_for}. *)
reader_streams :
(('buffer_ptr, 'dev, 'runner, 'event) stream_ref * 'event) list Hashtbl.M(Tnode).t;
(** The streams, other than this stream, that most recently have been reading from a node in
Expand Down Expand Up @@ -247,16 +248,19 @@ module type Backend_device_common = sig
include Backend_any_common with type buffer_ptr := buffer_ptr

val sync : event -> unit
(** Blocks till the event completes, if it's not done already. *)
(** Blocks till the event completes, if it's not done already.
FIXME: it should rarely be needed to call [sync] explicitly, because it should always be
called internally when necessary, in particular before extracting values from host. *)

val is_done : event -> bool
(** Whether the event completed. *)

val will_wait_for : context -> event -> unit
(** Schedules waiting for the given event on the context's stream.
NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it is typically
called internally when necessary. *)
NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it should always
be called internally when necessary. *)

val get_used_memory : device -> int
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
Expand Down Expand Up @@ -309,14 +313,11 @@ module type With_buffer_retrieval_and_syncing = sig
updates the writer event for the node.
- If [into_merge_buffer] is different from [No]: sets on [dst] the merge buffer source to the
given node.
- If [into_merge_buffer=Streaming], remembers the buffer pointer of the source node to use for
streaming.
- If [into_merge_buffer=Streaming_for task], remembers the buffer pointer of the source node
to use for streaming, runs [task] -- intended to be the routine making use of the merge
buffer, and initializes the merge buffer's streaming event.
- If [into_merge_buffer=Copy], schedules copying from [src] to the merge buffer of [dst]'s
stream, and updates the writer event for the merge buffer.
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
buffer but before scheduling work on [src] that modifies [tn], execute
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
stream, and updates the writer event for the merge buffer. *)
end

module type Backend = sig
Expand Down
38 changes: 20 additions & 18 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
Hashtbl.update s.updating_for tn ~f:(fun _ -> e)
| Merge_buffer tn ->
(* Note: the previous event does not need to be done! *)
s.updating_for_merge_buffer <- Some (tn, e)
s.updating_for_merge_buffer <- Some (tn, Some e)

let%diagn2_l_sexp from_host (ctx : Backend.context) tn =
match (tn, Map.find ctx.ctx_arrays tn) with
Expand Down Expand Up @@ -105,17 +105,19 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
update_writer_event ~from:(`Src src.stream) dst.stream @@ Node tn;
[%log "copying", Tn.debug_name tn, "from", name_of src, "to", name_of dst];
true)
| Copy | Streaming ->
| Copy ->
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
let use =
match into_merge_buffer with
| Copy -> "copying"
| Streaming -> "streaming"
| No -> assert false
in
[%log use, "into merge buffer", Tn.debug_name tn, "from", name_of src];
[%log "copy into merge buffer", Tn.debug_name tn, "from", name_of src];
true
| Streaming_for task ->
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
dst.stream.updating_for_merge_buffer <- Some (tn, None);
Task.run task;
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
true)

let%track3_l_sexp sync_routine r =
Expand Down Expand Up @@ -280,32 +282,32 @@ module Add_device
(Task.Task { context_lifetime = src; description = "to_host on " ^ get_name src.stream; work })

let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let dev = dst.stream in
let s = dst.stream in
let size_in_bytes = Tnode.size_in_bytes tn in
let work =
(* TODO: log the operation if [Utils.settings.with_log_level > 1]. *)
match (into_merge_buffer, dst_ptr) with
| No, None -> invalid_arg "Multicore_scheduler.device_to_device: missing dst_ptr"
| No, Some dst_ptr -> fun () -> buffer_to_buffer ~dst:dst_ptr ~src:src_ptr ~size_in_bytes
| Streaming, _ -> fun () -> dev.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Streaming_for _, _ -> fun () -> s.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Copy, _ ->
fun () ->
let size_in_bytes = Tnode.size_in_bytes tn in
let allocated_capacity =
match dev.allocated_buffer with None -> 0 | Some buf -> buf.size_in_bytes
match s.allocated_buffer with None -> 0 | Some buf -> buf.size_in_bytes
in
if allocated_capacity < size_in_bytes then
dev.allocated_buffer <-
Some (alloc_buffer ?old_buffer:dev.allocated_buffer ~size_in_bytes dst.stream);
let merge_ptr = (Option.value_exn dev.allocated_buffer).ptr in
dev.merge_buffer := dev.allocated_buffer;
s.allocated_buffer <-
Some (alloc_buffer ?old_buffer:s.allocated_buffer ~size_in_bytes dst.stream);
let merge_ptr = (Option.value_exn s.allocated_buffer).ptr in
s.merge_buffer := s.allocated_buffer;
buffer_to_buffer ~dst:merge_ptr ~src:src_ptr ~size_in_bytes
in
let description =
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name dev ^ " src "
"device_to_device " ^ Tnode.debug_name tn ^ " dst " ^ get_name s ^ " src "
^ get_name src.stream
in
schedule_task dev (Task.Task { context_lifetime = (src, dst); description; work })
schedule_task s (Task.Task { context_lifetime = (src, dst); description; work })
end

module Raise_backend (Device : Lowered_backend) : Backend = struct
Expand Down
5 changes: 4 additions & 1 deletion arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ let () =
message;
if not @@ Cu.is_success status then [%log (status : Cu.result)]]])

(* Debugging aid, currently unused (hence the [_] prefix): installs a CUDA call hook
   that prints each call's [~message] to stdout, ignoring its [~status].
   NOTE(review): assumes [Cu.cuda_call_hook] is consulted on every CUDA API call —
   confirm against the [Cu] (cudajit) bindings. *)
let _suspended () =
Cu.cuda_call_hook := Some (fun ~message ~status:_ -> Stdlib.Printf.printf "CUDA %s\n" message)

module Backend_buffer = struct
type buffer_ptr = Cu.Deviceptr.t

Expand Down Expand Up @@ -203,7 +206,7 @@ let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
| No, Some dst_ptr ->
set_ctx @@ ctx_of dst;
memcpy ~dst_ptr
| Streaming, _ ->
| Streaming_for _, _ ->
assert same_device;
dst.stream.merge_buffer := Some { ptr = src_ptr; size_in_bytes }
| Copy, _ ->
Expand Down
31 changes: 24 additions & 7 deletions bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,9 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
(* let data_len = 3 * 4 in *)
let flat_len = data_len / 2 in
(* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
(* let epochs = 200 in *)
let epochs = 100 in
let epochs = 200 in
(* let epochs = 100 in *)
(* let epochs = 50 in *)
(* TINY for debugging: *)
(* let epochs = 2 in *)
(* let epochs = 1 in *)
Expand Down Expand Up @@ -88,9 +89,9 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
(* Tn.print_accessible_headers (); *)
let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
epoch_loss;

epoch_loss
in

Backend.initialize Train.BT.Most_parallel_streams;
let {
Train.inputs;
Expand All @@ -104,7 +105,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
} =
Train.example_train_loop ~seed ~batch_size ~init_lr ~max_num_streams:num_streams ~data_len
~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn ~weight_decay
~per_batch_callback ~per_epoch_callback ~per_epoch_debug_streams:true
~per_batch_callback ~per_epoch_callback ~per_epoch_debug_streams:false
(module Backend)
()
in
Expand Down Expand Up @@ -195,7 +196,8 @@ let _cuda_benchmarks =
[
(* TINY for debugging: *)
(* 3 * 2 *)
3 * 5 * 16 (* ; 3 * 5 * 32; 3 * 5 * 64 *);
3 * 5 * 16;
3 * 5 * 32 (*; 3 * 5 * 64 *);
]
~f:(fun batch_size ->
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
Expand All @@ -208,6 +210,21 @@ let _cuda_benchmarks =
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))

(* Benchmark grid for multi-stream (parallel) CUDA runs: builds the Cartesian product of
   the uncommented parameter lists below via nested [List.concat_map], yielding one
   [classify_moons] thunk per combination. Alternative parameter values are kept in
   comments for quick toggling between runs; gradient precision is tied to value
   precision ([~grad_prec:value_prec]). *)
let _cuda_parallel_benchmarks =
List.concat_map [ (* 1; 2; *) (* 3; 4; 5; 6; 8; 10; 12; 16; *) 20 (* 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[ 3 * 5 * 16 (* ; 3 * 5 * 32 *) ]
~f:(fun batch_size ->
List.concat_map [ (* 1; *) (* 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams
~batch_size ~backend_name ~value_prec ~grad_prec:value_prec;
]))))))

let _mem_benchmarks =
List.concat_map [ 1; 3; 6; 12; 16 (* ; 20; 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
Expand Down Expand Up @@ -253,4 +270,4 @@ let benchmark benchmarks =
List.map benchmarks ~f:(fun bench -> bench ())
|> PrintBox_utils.table |> PrintBox_text.output Stdio.stdout

let () = benchmark _cuda_benchmarks
let () = benchmark _cuda_parallel_benchmarks
8 changes: 5 additions & 3 deletions lib/train.ml
Original file line number Diff line number Diff line change
Expand Up @@ -375,22 +375,24 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
~~("merging" updaten.loss;
updaten.loss.value =+ updaten.loss.value.merge)])
in
let into_merge_buffer = if copy_to_merge then BT.Copy else BT.Streaming in
let mbuf_use sched = if copy_to_merge then (BT.Copy, false) else (BT.Streaming_for sched, true) in
(* Since each device has its own queue, we can iterate over devices in the outer loop. *)
let merge_grads ~(from : int) ~(to_ : int) : unit =
Array.iteri all_params ~f:(fun i p ->
let grad_merge =
Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i)
in
let into_merge_buffer, streaming = mbuf_use grad_merge.schedule in
assert (
Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer
~dst:ctxs.(to_) ~src:ctxs.(from));
(Task.run grad_merge.schedule : unit))
if not streaming then Task.run grad_merge.schedule)
in
let merge_loss ~src =
let into_merge_buffer, streaming = mbuf_use loss_merge.schedule in
assert (
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src);
Task.run loss_merge.schedule
if not streaming then Task.run loss_merge.schedule
in
(* FIXME: missing backcopy. *)
let needed_on_host = ref @@ Set.empty (module Tn) in
Expand Down

0 comments on commit c2c4bd4

Please sign in to comment.