
Commit

Get rid of Postponed
In the future, device-config-specific compilation will be handled by laziness and caching.
lukstafi committed Dec 8, 2024
1 parent 93b427d commit b5d6104
Showing 9 changed files with 81 additions and 111 deletions.
8 changes: 4 additions & 4 deletions arrayjit/lib/anatomy_of_a_backend.md
Original file line number Diff line number Diff line change
@@ -46,11 +46,11 @@ The modules and files of `arrayjit` can loosely be divided into three parts.
- `reinitialize` a backend,
- `finalize` a context (freeing all of its arrays that don't come from its parent context).

### Shared (relocatable) compilation, batch compilation
### Batch compilation; in the future: lazy and cached compilation artifacts

Shared (relocatable) compilation, with `~shared:true`, improves compilation efficiency, because code can be compiled once for use on multiple devices (in multiple contexts). It also improves debugging convenience, by generating fewer debugging artifacts. A potential downside is slightly less efficient computations.
Batched compilation produces fewer debugging artifacts, and can be slightly more efficient since the compiler is invoked fewer times. Batched compilation and linking process _many routines for one device/stream_ at once.

Batched compilation has similar benefits, especially in producing fewer debugging artifacts. The compilation might also be slightly more efficient since the compiler needs to be invoked fewer times. While `~shared:true` compiles _one routine for many devices_, batched compilation and linking process _many routines for one device_ at once.
In the future, when we introduce program search, `compile` functions will return compilation artifact objects. They will manage compilation lazily, caching compiled code keyed by (a configuration of) the device.
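A rough sketch of how such lazy, per-device-configuration caching could look (all names here are hypothetical illustrations, not part of the current arrayjit API):

```ocaml
(* Hypothetical sketch of compilation artifacts that compile lazily and
   cache the result per device configuration; [device_config], [artifact]
   and [codegen] are illustrative names only. *)
type device_config = { arch : string; opt_level : int }

type artifact = {
  name : string;
  (* Compiled code per device configuration, generated on first use. *)
  cache : (device_config, string) Hashtbl.t;
  codegen : device_config -> string;
}

(* [compile] returns immediately; no code generation happens yet. *)
let compile ~name codegen = { name; cache = Hashtbl.create 4; codegen }

(* Linking forces compilation at most once per device configuration. *)
let link artifact config =
  match Hashtbl.find_opt artifact.cache config with
  | Some binary -> binary
  | None ->
      let binary = artifact.codegen config in
      Hashtbl.add artifact.cache config binary;
      binary
```

Under this scheme, linking the same artifact on many streams of one device would compile once, while a device with a different configuration would trigger a fresh compilation.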

## Tensor nodes, arrays, memory properties

@@ -112,7 +112,7 @@ Contexts track (or store) the on-device arrays corresponding to tensor nodes. Co

## Typical details of a backend implementation

During the compilation process, the old context cannot be available if the backend supports shared compilation. A backend may for simplicity not support shared compilation, i.e. ignore `~shared:true` and postpone compilation to the linking phase. Currently, the CUDA backend ignores `~shared:false` and always generates context-and-device-independent kernels that refer to context (i.e. global) arrays via parameters.
During the compilation process, the old context is not available when `compile` is handled. Currently, all backends generate context-and-device-independent kernels that refer to context arrays via parameters.

We use keys of the `Low_level.traced_store` containers assuming that they are precisely the tensor nodes used in the compiled code -- and the `Virtual` nodes are the ones optimized away. The context can contain nodes from the parent context corresponding to tensors only needed by parent or ancestor contexts' computations. The `get_ident` function (e.g. provided by `C_syntax`) returns a human-readable identifier that's unambiguous in the context of the compiled code (shared within `compile_batch`).

3 changes: 1 addition & 2 deletions arrayjit/lib/backend_impl.ml
@@ -233,10 +233,9 @@ module type Lowered_backend = sig
type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]

val compile : ?shared:bool -> name:string -> Indexing.unit_bindings -> Low_level.optimized -> code
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> code

val compile_batch :
?shared:bool ->
names:string option array ->
Indexing.unit_bindings ->
Low_level.optimized option array ->
17 changes: 8 additions & 9 deletions arrayjit/lib/backend_intf.ml
@@ -220,22 +220,21 @@ module type Backend_common = sig
type code [@@deriving sexp_of]
type code_batch [@@deriving sexp_of]

val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
device-and-stream-agnostic way. If [~shared:false], the backend can opt to postpone compiling
altogether until [link] is called, to benefit from more optimizations. *)
val compile : ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
(** [name] is used to derive names for compilation artifacts. If omitted, it's derived via
{!Assignments.get_name_exn}. *)

val compile_batch :
?shared:bool ->
?names:string array ->
?occupancy:(name:string -> src_n:int -> bool) ->
Indexing.unit_bindings ->
Assignments.comp array ->
code_batch
(** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the
compile time and debugging convenience by generating fewer files -- ideally does not affect
execution, but there can be backend-specific differences. Only array entries for which
[occupancy] returns true are included. *)
(** [compile_batch] vs. [compile] is mostly about improving the compile time and debugging
convenience by generating fewer files -- ideally does not affect execution, but there can be
backend-specific differences. Only array entries for which [occupancy] returns true are
included. [names] are used to derive names for compilation artifacts. If omitted, they're
derived via {!Assignments.get_name_exn}. *)
end
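For illustration, a call that compiles only some entries of a batch might look as follows (a sketch; `Backend`, `bindings`, and `comps` are assumed to be in scope):

```ocaml
(* Sketch of using [occupancy] to skip batch entries: only even-numbered
   sources are compiled, and the resulting [code_batch] leaves the skipped
   entries unpopulated. [Backend], [bindings] and [comps] are assumed. *)
let code_batch =
  Backend.compile_batch
    ~occupancy:(fun ~name:_ ~src_n -> src_n mod 2 = 0)
    bindings comps
```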

(** Parts shared by both assignments-level and lowered-level backend interfaces providing streams
65 changes: 18 additions & 47 deletions arrayjit/lib/backends.ml
@@ -209,51 +209,28 @@ module Add_device
(Backend : Lowered_no_device_backend) : Lowered_backend = struct
include Backend

type code =
| Postponed of {
lowered : Low_level.optimized;
bindings : Indexing.unit_bindings;
name : string;
}
| Compiled of { lowered : Low_level.optimized; proc : Backend.procedure }
[@@deriving sexp_of]
type code = { lowered : Low_level.optimized; proc : Backend.procedure } [@@deriving sexp_of]

type code_batch =
| Postponed of {
lowereds : Low_level.optimized option array;
bindings : Indexing.unit_bindings;
names : string option array;
}
| Compiled of {
lowereds : Low_level.optimized option array;
procs : Backend.procedure option array;
}
type code_batch = {
lowereds : Low_level.optimized option array;
procs : Backend.procedure option array;
}
[@@deriving sexp_of]

let compile ?(shared = false) ~name bindings lowered : code =
if shared then
let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
Compiled { lowered; proc }
else Postponed { lowered; bindings; name }
let compile ~name bindings lowered : code =
let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
{ lowered; proc }

let compile_batch ?(shared = false) ~names bindings lowereds : code_batch =
if shared then
let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
Compiled { lowereds; procs }
else Postponed { lowereds; bindings; names }
let compile_batch ~names bindings lowereds : code_batch =
let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
{ lowereds; procs }

include Add_scheduler (Backend)

let link context (code : code) ctx_arrays =
let runner_label = get_name context.stream in
let merge_buffer = context.stream.merge_buffer in
let bindings, to_schedule =
match code with
| Postponed { lowered; bindings; name } ->
let proc = Backend.compile ~name ~opt_ctx_arrays:(Some ctx_arrays) bindings lowered in
link_compiled ~merge_buffer ~runner_label ctx_arrays proc
| Compiled { proc; _ } -> link_compiled ~merge_buffer ~runner_label ctx_arrays proc
in
let bindings, to_schedule = link_compiled ~merge_buffer ~runner_label ctx_arrays code.proc in
let schedule =
Task.enschedule ~schedule_task ~get_stream_name:get_name context.stream to_schedule
in
@@ -262,14 +239,8 @@ module Add_device
let link_batch context (code_batch : code_batch) ctx_arrays =
let runner_label = get_name context.stream in
let merge_buffer = context.stream.merge_buffer in
let procs =
match code_batch with
| Postponed { lowereds; bindings; names } ->
Backend.compile_batch ~names ~opt_ctx_arrays:(Some ctx_arrays) bindings lowereds
| Compiled { procs; _ } -> procs
in
let bindings, schedules =
Array.fold_mapi procs ~init:None ~f:(fun i bindings -> function
Array.fold_mapi code_batch.procs ~init:None ~f:(fun i bindings -> function
| Some proc ->
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
let bindings', to_schedule =
@@ -348,21 +319,21 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
}
[@@deriving sexp_of]

let%debug3_sexp compile ?shared ?name bindings (comp : Assignments.comp) : code =
let%debug3_sexp compile ?name bindings (comp : Assignments.comp) : code =
let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
let code = compile ?shared ~name bindings lowered in
let code = compile ~name bindings lowered in
let from_prior_context =
Set.diff (Assignments.context_nodes ~use_host_memory comp.asgns) comp.embedded_nodes
in
{ from_prior_context; name; lowered; code; expected_merge_node = lowered.Low_level.merge_node }

let%debug3_sexp compile_batch ?shared ?names ?occupancy bindings (comps : Assignments.comp array)
: code_batch =
let%debug3_sexp compile_batch ?names ?occupancy bindings (comps : Assignments.comp array) :
code_batch =
let names, lowereds =
lower_batch_assignments ?names ?occupancy bindings
@@ Array.map comps ~f:(fun c -> c.Assignments.asgns)
in
let code_batch = compile_batch ?shared ~names bindings lowereds in
let code_batch = compile_batch ~names bindings lowereds in
let from_prior_context =
from_prior_context_batch ~use_host_memory
@@ Array.mapi lowereds ~f:(fun i -> Option.map ~f:(fun _ -> comps.(i)))
4 changes: 2 additions & 2 deletions arrayjit/lib/cuda_backend.cudajit.ml
@@ -337,7 +337,7 @@ struct
| _ -> ("(" ^ typ_of_prec to_ ^ ")(", ")")
end

let compile ?shared:_ ~name bindings ({ Low_level.traced_store; _ } as lowered) =
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
@@ -353,7 +353,7 @@ let compile ?shared:_ ~name bindings ({ Low_level.traced_store; _ } as lowered)
let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
{ traced_store; ptx; params; bindings; name }

let compile_batch ?shared:_ ~names bindings lowereds =
let compile_batch ~names bindings lowereds =
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
let procs = Array.filter_map lowereds ~f:(Option.map ~f:(fun lowereds -> (lowereds, None)))
end)) in
4 changes: 2 additions & 2 deletions arrayjit/lib/lowered_backend_missing.ml
@@ -78,10 +78,10 @@ type code_batch

let sexp_of_code_batch _code_batch = failwith "Backend missing -- install the corresponding library"

let compile ?shared:_ ~name:_ _unit_bindings _optimized =
let compile ~name:_ _unit_bindings _optimized =
failwith "Backend missing -- install the corresponding library"

let compile_batch ?shared:_ ~names:_ _unit_bindings _optimizeds =
let compile_batch ~names:_ _unit_bindings _optimizeds =
failwith "Backend missing -- install the corresponding library"

let link _context _code = failwith "Backend missing -- install the corresponding library"
63 changes: 42 additions & 21 deletions bin/moons_benchmark.ml
@@ -34,22 +34,25 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
Tensor.default_grad_prec := grad_prec;
Utils.settings.output_debug_files_in_build_directory <- true;
(* This will only log from routines if log-level is high enough. *)
Utils.settings.debug_log_from_routines <- true;
(* Utils.settings.debug_log_from_routines <- true; *)
Rand.init (* seed *) 0;
let hid_dim_1 = 16 in
let hid_dim_2 = 8 in
let hid_dim_3 = 4 in
(* TINY for debugging: *)
(* let hid_dim = 2 in *)
(* let hid_dim = 4 in *)
let data_len = 3 * 5 * 1024 in
(* TINY for debugging: *)
(* let data_len = 3 * 4 in *)
(* let data_len = 3 * 16 in *)
let flat_len = data_len / 2 in
(* Note: [minibatch_size = batch_size / num_streams] is the actual per-device batch used. *)
let epochs = 200 in
(* let epochs = 400 in *)
(* let epochs = 100 in *)
(* let epochs = 50 in *)
(* TINY for debugging: *)
let epochs = 3 in
(* let epochs = 2 in *)
(* let epochs = 1 in *)
(* let init_lr = 0.1 in *)
@@ -84,12 +87,15 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
Stdlib.Format.printf "Initial backend global debug info: %a\n%!" Sexp.pp_hum
@@ Backend.get_global_debug_info ();
let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ =
(* Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step
learning_rate batch_loss epoch_loss; *)
if Option.is_none !start_time then start_time := Some (Time_now.nanoseconds_since_unix_epoch ())
in
(* Tn.print_accessible_headers (); *)
let per_epoch_callback ~at_step ~at_epoch ~learning_rate ~epoch_loss =
Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
epoch_loss
if at_epoch % 10 = 9 then
Stdio.printf "Epoch=%d, step=%d, lr=%f, epoch loss=%f\n%!" at_epoch at_step learning_rate
epoch_loss
in

Backend.initialize Train.BT.Most_parallel_streams;
@@ -115,22 +121,25 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
Stdio.print_endline "\n******** mlp_result **********";
Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 model_result;
Stdio.printf "\n********\n%!";
Arrayjit.Tnode.print_accessible_headers ();
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
let plot_moons =
let open PrintBox_utils in
plot
~size:(120, 40)
(* TINY for debugging: *)
(* ~size:(20, 10) *)
~x_label:"ixes" ~y_label:"ygreks"
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "%" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
]
let%track3_sexp plot_moons () =
[%log_level
0;
let open PrintBox_utils in
plot
~size:(120, 40)
(* TINY for debugging: *)
(* ~size:(20, 10) *)
~x_label:"ixes" ~y_label:"ygreks"
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "%" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
]]
in
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
PrintBox_text.output Stdio.stdout plot_moons;
PrintBox_text.output Stdio.stdout @@ plot_moons ();
Stdio.printf "\nBatch Log-loss:\n%!";
let plot_loss =
let open PrintBox_utils in
@@ -181,6 +190,7 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
}
in
Stdio.printf "\n\n%!";
Arrayjit.Tnode.print_accessible_headers ();
Stdlib.Format.printf "Final backend global debug info: %a\n%!" Sexp.pp_hum
@@ Backend.get_global_debug_info ();
result
@@ -211,13 +221,24 @@ let _cuda_benchmarks =
]))))))

let _cuda_parallel_benchmarks =
List.concat_map [ (* 1; 2; *) (* 3; 4; 5; 6; 8; 10; 12; 16; *) 20 (* 32; 64 *) ] ~f:(fun num_streams ->
List.concat_map
[
(* 1; 2; *)
3;
(* 4; 5; 6; 8; 10; 12; 16; 20 *)
(* 32; 64 *)
] ~f:(fun num_streams ->
List.concat_map
[ 3 * 5 * 16 (* ; 3 * 5 * 32 *) ]
[
(* TINY for debugging: *)
(* 3 * 4 *)
3 * 5 * 16 (* ; 3 * 5 * 32 *);
]
~f:(fun batch_size ->
List.concat_map [ (* 1; *) (* 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
List.concat_map [ (* "gccjit"; "cuda" ;"cc"; *) "sync_cc" ]
~f:(fun backend_name ->
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
~f:(fun value_prec ->
[
20 changes: 0 additions & 20 deletions lib/attic.mld
@@ -476,26 +476,6 @@ Old post-launch code in Cuda_backend.link_proc:
data.tracking <- Some (Cu.Delimited_event.record context.stream.cu_stream));
]}

Obsoleted part of interfaces in backend_impl.ml:
{[
let expected_merge_node : code -> _ = function
| Postponed { lowered = Low_level.{ merge_node; _ }; _ }
| Compiled { lowered = Low_level.{ merge_node; _ }; _ } ->
merge_node

let expected_merge_nodes : code_batch -> _ = function
| Postponed { lowereds; _ } | Compiled { lowereds; _ } ->
Array.map lowereds ~f:(fun lowered ->
Option.(join @@ map lowered ~f:(fun optim -> optim.merge_node)))

let get_lowered : code -> _ = function
| Postponed { lowered; _ } | Compiled { lowered; _ } -> lowered

let get_lowereds : code_batch -> _ = function
| Postponed { lowereds; _ } -> lowereds
| Compiled { lowereds; _ } -> lowereds
]}

Old context finalizer from the cuda backend:
{[
let%track3_sexp finalize (ctx : context) : unit =
8 changes: 4 additions & 4 deletions lib/train.ml
@@ -363,7 +363,7 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
Array.mapi ctxs ~f:(fun dst_n ctx ->
if occupancy_dst ~dst_n then
snd
@@ Backend.(link_batch ctx @@ compile_batch ~shared:true ~occupancy Idx.Empty grad_merges)
@@ Backend.(link_batch ctx @@ compile_batch ~occupancy Idx.Empty grad_merges)
else [||])
in
(* We can cache scheduling, because merging and copying does not depend on static indexing. *)
@@ -441,8 +441,8 @@ let to_routine (type buffer_ptr dev runner event)
with type buffer_ptr = buffer_ptr
and type dev = dev
and type runner = runner
and type event = event) (context : Backend.context) ?shared ?name bindings comp =
Backend.link context @@ Backend.compile ?shared ?name bindings comp
and type event = event) (context : Backend.context) ?name bindings comp =
Backend.link context @@ Backend.compile ?name bindings comp

type example_train_result = {
inputs : Tensor.t;
@@ -500,7 +500,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
Utils.settings.check_half_prec_constants_cutoff, no need to upcast learning_rate.value. *)
set_hosted learning_rate.value;
let sgd = sgd_update ~learning_rate ~weight_decay update in
let grad_update = Backend.compile ~shared:true bindings update.fwd_bprop in
let grad_update = Backend.compile bindings update.fwd_bprop in
let grad_updates = Array.map contexts ~f:(fun ctx -> Backend.link ctx grad_update) in
let sgd_update = to_routine (module Backend) grad_updates.(0).context bindings sgd in
Tensor.log_debug_info ~from_log_level:2 inputs;
