Skip to content

Commit

Permalink
Debugging tweaks
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Dec 8, 2024
1 parent 9c29e0c commit e85a09d
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 47 deletions.
26 changes: 13 additions & 13 deletions arrayjit/lib/assignments.ml
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ let is_total ~initialize_neutral ~projections =
(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called
after compilation and ideally after linking with the relevant contexts; otherwise, it is an
under-estimate. *)
let context_nodes ~use_host_memory asgns =
let%debug3_sexp context_nodes ~(use_host_memory : bool) (asgns : t) : Tn.t_set =
let open Utils.Set_O in
let empty = Set.empty (module Tn) in
let one tn =
Expand Down Expand Up @@ -117,12 +117,12 @@ let%diagn1_sexp to_low_level code =
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
[%log
"get",
"a=",
(tn : Tn.t),
":",
Tn.label tn,
(idcs : Indexing.axis_index array),
(Lazy.force tn.dims : int array)];
"a=",
(tn : Tn.t),
":",
Tn.label tn,
(idcs : Indexing.axis_index array),
(Lazy.force tn.dims : int array)];
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
match buffer with
| Node tn -> Low_level.Get (tn, idcs)
Expand All @@ -133,12 +133,12 @@ let%diagn1_sexp to_low_level code =
if not (Array.length idcs = Array.length (Lazy.force tn.Tn.dims)) then
[%log
"set",
"a=",
(tn : Tn.t),
":",
Tn.label tn,
(idcs : Indexing.axis_index array),
(Lazy.force tn.dims : int array)];
"a=",
(tn : Tn.t),
":",
Tn.label tn,
(idcs : Indexing.axis_index array),
(Lazy.force tn.dims : int array)];
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
Low_level.Set { tn; idcs; llv; debug = "" }
in
Expand Down
46 changes: 26 additions & 20 deletions arrayjit/lib/backends.ml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ let check_merge_buffer stream ~code_node =
^ ", expected by code: " ^ name code_node)

module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncing) = struct
let[@landmark] wait_for_all ctx streams tn =
let wait_for_all ctx streams tn =
let s = ctx.stream in
Hashtbl.update_and_return streams tn
~f:
Expand All @@ -31,15 +31,15 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
|> List.iter ~f:(fun (work_stream, e) ->
if not (equal_stream work_stream s) then Backend.will_wait_for ctx e)

let[@landmark] wait_for_ready ~dst ~src tn =
let wait_for_ready ~dst ~src tn =
  (* Makes [dst] wait for the event, if any, recorded for [tn] in the [updating_for] table of
     [src]'s stream. Nothing is requested when both contexts use the same stream, or when the
     recorded event is already done. *)
  let src_stream = src.stream in
  let dst_stream = dst.stream in
  (* TODO: maybe it's worthwhile to clean up [updating_for] every now and then. *)
  match Hashtbl.find src_stream.updating_for tn with
  | None -> ()
  | Some pending ->
      if equal_stream src_stream dst_stream then ()
      else if Backend.is_done pending then ()
      else Backend.will_wait_for dst pending

let[@landmark] update_writer_event ?e ?from s tn =
let update_writer_event ?e ?from s tn =
let e = Option.value_or_thunk e ~default:(fun () -> Backend.all_work s) in
let f l = (s, e) :: Option.value ~default:[] l in
(match (from, tn) with
Expand All @@ -59,22 +59,24 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
(* Note: the previous event does not need to be done! *)
s.updating_for_merge_buffer <- Some (tn, Some e)

let%track2_l_sexp[@landmark] from_host (ctx : Backend.context) tn =
let%track2_l_sexp from_host (ctx : Backend.context) tn =
(* Copies the host array of [tn] into its buffer in [ctx] via [Backend.from_host]. Returns [true]
   when the copy was requested, and [false] when [tn] has no host array or no entry in
   [ctx.ctx_arrays]. Note that the [lazy] pattern below forces [tn.array]. *)
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some dst ->
(* Wait on the events recorded for [tn] in this stream's [reader_streams] before the buffer is
   overwritten. *)
wait_for_all ctx ctx.stream.reader_streams tn;
[%log "copying", Tn.debug_name tn, "to", (dst : Backend.buffer_ptr), "from host"];
(* Stdio.printf "copying: %s from_host\n" (Tn.debug_name tn); *)
Backend.from_host ~dst_ptr:dst ~dst:ctx hosted;
(* Register this stream as a writer of [tn], tagged as coming from the host. *)
update_writer_event ~from:`Host ctx.stream @@ Node tn;
true
| _ -> false

let%track2_l_sexp[@landmark] to_host (ctx : Backend.context) (tn : Tn.t) =
let%track2_l_sexp to_host (ctx : Backend.context) (tn : Tn.t) =
match (tn, Map.find ctx.ctx_arrays tn) with
| { Tn.array = (lazy (Some hosted)); _ }, Some src ->
if Tn.potentially_cross_stream tn then
wait_for_all ctx ctx.stream.device.shared_writer_streams tn;
[%log "copying", Tn.debug_name tn, "at", (src : Backend.buffer_ptr), "to host"];
(* Stdio.printf "copying: %s to_host\n" (Tn.debug_name tn); *)
Backend.to_host ~src_ptr:src ~src:ctx hosted;
let s = ctx.stream in
let e = Backend.all_work s in
Expand All @@ -83,8 +85,8 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
true
| _ -> false

let%diagn2_l_sexp[@landmark] device_to_device (tn : Tn.t) ~into_merge_buffer
~(dst : Backend.context) ~(src : Backend.context) =
let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
~(src : Backend.context) =
let ordinal_of ctx = ctx.stream.device.ordinal in
let name_of ctx = Backend.(get_name ctx.stream) in
let same_device = ordinal_of dst = ordinal_of src in
Expand Down Expand Up @@ -116,15 +118,17 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
Backend.(
device_to_device tn ~into_merge_buffer ~dst_ptr:None ~dst ~src_ptr:s_arr ~src);
dst.stream.updating_for_merge_buffer <- Some (tn, None);
let[@landmark] merge_task () = Task.run task in
let merge_task () = Task.run task in
merge_task ();
update_writer_event ~from:(`Src src.stream) dst.stream @@ Merge_buffer tn;
[%log "streaming into merge buffer", Tn.debug_name tn, "from", name_of src];
true)

let%track2_l_sexp sync_routine r =
type r = Backend.context routine [@@deriving sexp_of]

let%track2_l_sexp sync_routine (r : r) : r =
let s = r.context.stream in
let[@landmark] pre () =
let pre () =
Set.iter r.inputs ~f:(fun tn ->
if Tn.potentially_cross_stream tn then
Option.iter (Hashtbl.find s.device.shared_writer_streams tn) ~f:(fun data ->
Expand All @@ -135,13 +139,13 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
else Hashtbl.remove s.device.shared_writer_streams tn)
(* Since merge buffers are always per-stream, no need to check r.merge_buffer_input. *)
in
let[@landmark] post () =
let post () =
let e = Backend.all_work s in
Set.iter r.outputs ~f:(fun tn -> update_writer_event ~e s @@ Node tn)
in
{ r with schedule = Task.(prepend ~work:pre @@ append ~work:post r.schedule) }

let[@landmark] sync_device device =
let sync_device device =
Utils.weak_iter device.streams ~f:Backend.await;
Hashtbl.clear device.host_writing_streams;
Hashtbl.clear device.host_reading_streams;
Expand Down Expand Up @@ -180,15 +184,16 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
Some (Assignments.lower ~unoptim_ll_source ~ll_source ~cd_source ~name bound asgns) )
else (None, None))

let verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context =
let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit =
(* Checks that every node of [from_prior_context] that is known to be in context has an entry in
   [ctx_arrays]. Raises [Utils.User_error] naming the first offending node encountered. *)
Set.iter from_prior_context ~f:(fun tn ->
if
(* Err on the safe side. *)
(* [Tn.is_in_context] returns an option; [None] is treated as "not in context", so such nodes are
   not required to be present. *)
Option.value ~default:false (Tn.is_in_context ~use_host_memory tn)
&& not (Option.is_some @@ Map.find ctx_arrays tn)
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))

let from_prior_context_batch ~use_host_memory comps =
let%debug3_sexp from_prior_context_batch ~use_host_memory (comps : Assignments.comp option array) :
Tn.t_set =
Array.filter_map comps ~f:(fun comp ->
Option.map comp ~f:(fun comp ->
Set.diff
Expand Down Expand Up @@ -279,20 +284,20 @@ module Add_device
in
(Option.value_exn ~here:[%here] bindings, schedules)

let[@landmark] from_host ~dst_ptr ~dst hosted =
let from_host ~dst_ptr ~dst hosted =
  (* Packages the copy of [hosted] into [dst_ptr] as a task and hands it to [schedule_task] on
     the stream of [dst]; [dst] is stored as the task's [context_lifetime]. *)
  (* TODO: pass description to from_host. *)
  let description = "from_host on " ^ get_name dst.stream in
  let copy () = host_to_buffer hosted ~dst:dst_ptr in
  schedule_task dst.stream (Task.Task { context_lifetime = dst; description; work = copy })

let[@landmark] to_host ~src_ptr ~src hosted =
let to_host ~src_ptr ~src hosted =
  (* Packages the copy from [src_ptr] into [hosted] as a task and hands it to [schedule_task] on
     the stream of [src]; [src] is stored as the task's [context_lifetime]. *)
  (* TODO: pass description to to_host. *)
  let description = "to_host on " ^ get_name src.stream in
  let copy () = buffer_to_host hosted ~src:src_ptr in
  schedule_task src.stream (Task.Task { context_lifetime = src; description; work = copy })

let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let s = dst.stream in
let size_in_bytes = Tnode.size_in_bytes tn in
let work =
Expand Down Expand Up @@ -343,15 +348,16 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
}
[@@deriving sexp_of]

let compile ?shared ?name bindings comp : code =
let%debug3_sexp compile ?shared ?name bindings (comp : Assignments.comp) : code =
(* Lowers the assignments of [comp], compiles the lowered code, and bundles the result with the
   nodes that must come from a prior context. *)
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
(* This binding is not recursive: [compile] here refers to the definition already in scope
   (presumably the underlying device's compile -- confirm against the enclosing functor). *)
let code = compile ?shared ~name bindings lowered in
(* Context nodes of the assignments, minus the nodes embedded by the computation itself. *)
let from_prior_context =
Set.diff (Assignments.context_nodes ~use_host_memory comp.asgns) comp.embedded_nodes
in
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }

let compile_batch ?shared ?names ?occupancy bindings comps =
let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array) :
code_batch =
let names, lowereds =
lower_batch_assignments ?names ?occupancy bindings
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
Expand Down Expand Up @@ -479,7 +485,7 @@ let reinitialize (module Backend : Backend) config =
Stdlib.Gc.full_major ();
Backend.initialize config)

let[@landmark] finalize (type buffer_ptr dev runner event)
let finalize (type buffer_ptr dev runner event)
(module Backend : Backend
with type buffer_ptr = buffer_ptr
and type dev = dev
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/c_syntax.ml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct

(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
-> idx + (offset * dim)) *)
let%diagn_sexp compile_globals ppf =
let%debug3_sexp compile_globals ppf : Tn.t Hash_set.t =
let open Stdlib.Format in
let is_global = Hash_set.create (module Tn) in
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
Expand Down
12 changes: 8 additions & 4 deletions arrayjit/lib/cuda_backend.cudajit.ml
Original file line number Diff line number Diff line change
Expand Up @@ -172,24 +172,28 @@ let suggested_num_streams device =
| For_parallel_copying -> 1 + (cuda_properties device).async_engine_count
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count

let[@landmark] await stream : unit =
let await stream : unit =
  (* Activates the primary context of the stream's device, synchronizes the stream's runner, and
     finally invokes the [Utils.advance_captured_logs] callback when one is installed. *)
  let primary = stream.device.dev.primary_context in
  set_ctx primary;
  Cu.Stream.synchronize stream.runner;
  match !Utils.advance_captured_logs with
  | None -> ()
  | Some advance -> advance ()

(* Reports whether the stream's runner is ready, as answered by [Cu.Stream.is_ready]. *)
let is_idle stream = stream.runner |> Cu.Stream.is_ready

let[@landmark] from_host ~dst_ptr ~dst hosted =
let from_host ~dst_ptr ~dst hosted =
  (* Stdio.printf "run: from_host on backend:0:%d\n" dst.stream.stream_id; *)
  (* Activates the CUDA context of [dst], then issues a host-to-device copy of [hosted] into
     [dst_ptr] on the runner of [dst]'s stream. *)
  set_ctx (ctx_of dst);
  let copy_to_device src = Cu.Stream.memcpy_H_to_D ~dst:dst_ptr ~src dst.stream.runner in
  Ndarray.map { f = copy_to_device } hosted

let[@landmark] to_host ~src_ptr ~src hosted =
let to_host ~src_ptr ~src hosted =
  (* Stdio.printf "run: to_host on backend:0:%d\n" src.stream.stream_id; *)
  (* Activates the CUDA context of [src], then issues a device-to-host copy from [src_ptr] into
     [hosted] on the runner of [src]'s stream. *)
  set_ctx (ctx_of src);
  let copy_to_host dst = Cu.Stream.memcpy_D_to_H ~dst ~src:src_ptr src.stream.runner in
  Ndarray.map { f = copy_to_host } hosted

let[@landmark] device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
let device_to_device tn ~into_merge_buffer ~dst_ptr ~dst ~src_ptr ~src =
(* Stdio.printf "run: device_to_device %s dst backend:0:%d src backend:0:%d\n" (Tn.debug_name tn)
dst.stream.stream_id src.stream.stream_id; *)
let dev = dst.stream.device in
let same_device = dev.ordinal = src.stream.device.ordinal in
let size_in_bytes = Tn.size_in_bytes tn in
Expand Down
2 changes: 1 addition & 1 deletion arrayjit/lib/task.ml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ type t =

(* Returns the human-readable description stored in the task. *)
let describe (Task { description; _ }) = description

let%diagn_l_sexp run (Task task) =
let%debug3_l_sexp run (Task task) : unit =
(* Logs the task's description, then executes its [work] thunk. *)
[%log_result "run", task.description];
task.work ()

Expand Down
25 changes: 17 additions & 8 deletions arrayjit/lib/tnode.ml
Original file line number Diff line number Diff line change
Expand Up @@ -145,13 +145,13 @@ let log_debug_info ~from_log_level tn =
from_log_level (debug_name tn);
[%log
"id:",
(tn.id : int),
"label:",
(tn.label : string list),
"mem:",
debug_memory_mode tn.memory_mode,
"backends:",
(tn.backend_info : Sexp.t)];
(tn.id : int),
"label:",
(tn.label : string list),
"mem:",
debug_memory_mode tn.memory_mode,
"backends:",
(tn.backend_info : Sexp.t)];
if Lazy.is_val tn.array then
match tn.array with
| (lazy None) -> [%log "<not-on-host>"]
Expand Down Expand Up @@ -190,7 +190,7 @@ let is_materialized_force tn provenance =

(* Unlike the [known_] functions which can only change from [false] to [true], [is_in_context
~use_host_memory tn] is more precise. Generally, it can only change away from [None]. *)
let is_in_context ~use_host_memory tn =
let%debug3_sexp is_in_context ~(use_host_memory : bool) (tn : t) : bool option =
match tn.memory_mode with
| Some (Hosted (Changed_on_devices Per_stream), _) -> Some true
| Some ((Materialized | Hosted Nonconstant), _) when not use_host_memory -> Some true
Expand Down Expand Up @@ -404,6 +404,15 @@ let hash nd = Int.hash nd.id
(* Folds the node's [id] into the hash state [acc]; only the [id] field takes part in hashing. *)
let hash_fold_t acc nd = nd.id |> hash_fold_int acc

(* Same function as [hash], exposed under the name [hash_t]. *)
let hash_t nd = hash nd

(* Re-exports [t] and [comparator_witness] as a module, so that it can be passed to [Set.M] in the
   definition of [t_set]. *)
module Comp = struct
type nonrec t = t
type nonrec comparator_witness = comparator_witness
end

(** Sets of nodes, ordered by the comparator re-exported through [Comp]. *)
type t_set = Set.M(Comp).t

(* Serializes a node set by converting it to a sequence of its elements. *)
let sexp_of_t_set s = Set.to_sequence s |> [%sexp_of: t Sequence.t]

let get_exn a =
match a.array with
| (lazy (Some nd)) -> nd
Expand Down

0 comments on commit e85a09d

Please sign in to comment.