Automated to_host transfers
lukstafi committed Dec 30, 2024
1 parent f48985d commit 1a33588
Showing 17 changed files with 176 additions and 261 deletions.
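The gist of the commit: device-to-host copies for hosted tensor nodes now happen lazily, triggered by the host-side read accessors themselves (and flushed before host writes), so call sites no longer pair Train.run with explicit Backend.to_host and Backend.await calls. Below is a hedged sketch of the old versus new call-site shape; the routine, scalar_loss and stream names are taken from the demos further down, and the exact accessor plumbing is assumed rather than quoted:

let loss_old_style () =
  (* Old pattern: explicit transfer and stream synchronization after the run. *)
  let open Operation.At in
  Train.run routine;
  assert (Backend.to_host routine.context scalar_loss.value);
  Backend.await stream;
  scalar_loss.@[0]

let loss_new_style () =
  (* New pattern: the .@[0] read itself flushes the pending device-to-host copy. *)
  let open Operation.At in
  Train.run routine;
  scalar_loss.@[0]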
2 changes: 1 addition & 1 deletion arrayjit/lib/backends.ml
@@ -70,7 +70,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
Tn.prepare_read
~is_done:(fun () -> Backend.is_done e)
~sync:(fun () -> Backend.sync e)
- ~transfer:(fun () -> assert (to_host ctx tn))
+ ~transfer:(fun () -> assert (to_host ctx tn); Backend.await s)
tn);
(* To be on the safe side, record events for potentially cross-stream nodes. *)
match tn with
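For context on what the transfer callback feeds into: Tn.prepare_read records these callbacks on the tensor node, and the host-side accessors in tnode.ml (next file) run them on demand. A minimal sketch of the record being populated; the field names are assumed to mirror the labels used above:

type prepare = {
  is_done : unit -> bool;  (* has the producing event completed? *)
  sync : unit -> unit;     (* block until the producing event completes *)
  transfer : unit -> unit; (* copy the device buffer to host, then await the stream *)
}

With this change, transfer now also awaits the stream it scheduled the copy on, so a host read that follows it observes the completed copy.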
24 changes: 18 additions & 6 deletions arrayjit/lib/tnode.ml
@@ -582,33 +582,45 @@

(** {2 Accessors} *)

+ let do_read tn =
+   Option.iter
+     ~f:(fun p ->
+       p.sync ();
+       p.transfer ())
+     tn.prepare_read;
+   tn.prepare_read <- None
+
+ let do_write tn =
+   Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+   tn.prepare_write <- None
+
let points_1d ?from_axis ~xdim tn =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+ do_read tn;
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_1d_points ?from_axis ~xdim arr)
@@ Lazy.force tn.array

let points_2d ?from_axis ~xdim ~ydim tn =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+ do_read tn;
Option.value_map ~default:[||] ~f:(fun arr -> Nd.retrieve_2d_points ?from_axis ~xdim ~ydim arr)
@@ Lazy.force tn.array

let set_value tn =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+ do_write tn;
Nd.set_from_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array

let get_value tn =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+ do_read tn;
Nd.get_as_float @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array

let set_values tn values =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_write;
+ do_write tn;
Nd.(
reset (Constant_fill { values; strict = false })
@@ Option.value_exn ~here:[%here]
@@ Lazy.force tn.array)

let get_values tn =
- Option.iter ~f:(fun p -> p.sync ()) tn.prepare_read;
+ do_read tn;
Nd.(retrieve_flat_values @@ Option.value_exn ~here:[%here] @@ Lazy.force tn.array)

let print_accessible_headers () =
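All of the accessors above funnel through do_read or do_write, so a pending device-to-host transfer is executed at most once and then cleared before the host ndarray is touched. A hedged usage sketch, assuming Base is open as in this file and that get_values returns the flattened host values as shown:

(* Illustrative only: read a hosted node after a routine has run, with no
   explicit Backend.to_host or Backend.await at the call site. *)
let dump_values tn =
  (* get_values calls do_read, which runs the recorded sync and transfer
     callbacks (if the backend registered any) and then clears them. *)
  let values = get_values tn in
  Array.iteri values ~f:(fun i v -> Stdlib.Printf.printf "%d: %g\n" i v)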
3 changes: 0 additions & 3 deletions bin/compilation_speed.ml
@@ -51,10 +51,7 @@ let benchmark_overhead backend () =
Train.to_routine (module Backend) f_routine.context ~name:"assign_x" IDX.empty update_x
in
Train.run assign_x;
- (* await device; *)
Train.run f_routine;
- assert (Backend.to_host f_routine.context f.value);
- Backend.await stream;
f.@[0])
in
let plot_box =
12 changes: 6 additions & 6 deletions bin/einsum_trivia.ml
@@ -47,7 +47,7 @@ let _suspended () =
let%op ho2 = hey2 ++ "ab|cd->ef => cf|ae->db" in
Utils.capture_stdout_logs @@ fun () ->
Train.forward_and_forget backend ctx ho2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho2
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ ho2

let () =
Utils.set_log_level 2;
@@ -67,16 +67,16 @@ let () =
let a = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
let b = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 4 ] ~output_dims:[ 5 ] () in
let%op a2 = a *+ "b|i->o; b|i->o => b|i->o" a in
- Tensor.print ~force:false ~with_code:false ~with_grad:false `Default @@ a;
+ Tensor.print ~spy:true ~with_code:false ~with_grad:false `Default @@ a;
let ctx = Utils.capture_stdout_logs (fun () -> Train.forward_and_ctx backend ctx a2) in
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ a;
let%op c = b *+ "b|h->o; b|i->h => b|i->o" a in
Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx c);
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ a;
(* let%op d = a *+ "a|i->h; b|h->o => ab|i->o" b in Utils.capture_stdout_logs (fun () ->
Train.forward_and_forget backend ctx d); let%op e = a *+ "b|i->h; b|h->o => i->o" b in
Utils.capture_stdout_logs (fun () -> Train.forward_and_forget backend ctx e); let%op f = a *+
"a|i->h; b|h->o => i->o" b in Utils.capture_stdout_logs (fun () -> Train.forward_and_forget
backend ctx f); *)
- (* Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ a2; *)
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ c
+ (* Tensor.print ~with_code:false ~with_grad:false `Default @@ a2; *)
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ c
12 changes: 5 additions & 7 deletions bin/hello_world.ml
@@ -63,8 +63,6 @@ let hello3 () =
Tensor.print_tree ~with_grad:false ~depth:9 zero_to_twenty;
Stdlib.Format.print_newline ();
Train.run routine;
- assert (Backend.to_host routine.context y.value);
- Backend.await stream;
Tensor.print ~with_code:true ~with_grad:false `Default y;
Stdlib.Format.force_newline ();
Tensor.print_tree ~with_grad:false ~depth:9 y;
@@ -95,11 +93,11 @@ let hello4 () =
Train.set_hosted tk.value;
Train.forward_and_forget backend ctx positions;
Stdio.print_endline "positions:";
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ positions;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ positions;
Stdio.print_endline "tk:";
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ tk;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ tk;
Stdio.print_endline "ti:";
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ti;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ ti;
Stdio.printf "\n%!"

let hello5 () =
@@ -120,8 +118,8 @@ let hello5 () =
let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
let%op ho = hey ++ "...|1->... => ...|..." in
Train.forward_and_forget backend ctx ho;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ hey;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ ho
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ ho

let hello6 () =
Utils.set_log_level 2;
26 changes: 13 additions & 13 deletions bin/hello_world_op.ml
@@ -72,11 +72,11 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
Rand.init 0;
let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
Train.forward_and_forget backend ctx hey;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hey;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ hey;
Tensor.print ~with_code:false ~with_grad:false `Default @@ hey;
let%op hoo = [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |] in
Train.forward_and_forget backend ctx hoo;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hoo;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ hoo;
Tensor.print ~with_code:false ~with_grad:false `Default @@ hoo;
let%op hey2 =
[
@@ -87,7 +87,7 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
]
in
Train.forward_and_forget backend ctx hey2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hey2;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ hey2;
Tensor.print ~with_code:false ~with_grad:false `Default @@ hey2;
let%op hoo2 =
[|
@@ -98,8 +98,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
|]
in
Train.forward_and_forget backend ctx hoo2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ hoo2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ hoo2;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ hoo2;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ hoo2;
let%op heyhoo =
[|
[| [ 1; 2; 3 ]; [ 4; 5; 6 ] |];
@@ -109,8 +109,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
|]
in
Train.forward_and_forget backend ctx heyhoo;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo;
let%op heyhoo2 =
[|
[| [ [ 1; 31 ]; [ 2; 32 ]; [ 3; 33 ] ]; [ [ 4; 34 ]; [ 5; 35 ]; [ 6; 36 ] ] |];
@@ -120,8 +120,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
|]
in
Train.forward_and_forget backend ctx heyhoo2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo2;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo2;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo2;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo2;
let%op heyhoo3 =
[|
[|
@@ -135,8 +135,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
|]
in
Train.forward_and_forget backend ctx heyhoo3;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo3;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo3;
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo3;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo3;
let%op heyhoo4 =
[|
[
@@ -150,8 +150,8 @@ let%track2_sexp _Print_constant_tensor (() : unit) : unit =
|]
in
Train.forward_and_forget backend ctx heyhoo4;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
- Tensor.print ~force:true ~with_code:false ~with_grad:false `Default @@ heyhoo4
+ Tensor.print ~with_code:false ~with_grad:false `Inline @@ heyhoo4;
+ Tensor.print ~with_code:false ~with_grad:false `Default @@ heyhoo4

let%track2_sexp _Matrix_multiplication_dims_2x3 (() : unit) : unit =
Tensor.unsafe_reinitialize ();
4 changes: 2 additions & 2 deletions bin/micrograd_basic.ml
@@ -26,7 +26,7 @@ let%diagn_sexp () =
]; *)
let update = Train.grad_update d in
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
- Train.sync_run (module Backend) routine d;
+ Train.run routine;
Tensor.print_tree ~with_grad:true ~depth:9 d;
Stdio.print_endline "\n";
Tensor.print ~with_code:false ~with_grad:false `Default @@ d;
@@ -53,7 +53,7 @@ let%diagn_sexp _suspended () : unit =
(* Train.every_non_literal_on_host g; *)
let update = Train.grad_update g in
let routine = Train.to_routine (module Backend) ctx IDX.empty update.fwd_bprop in
- Train.sync_run (module Backend) routine g;
+ Train.run routine;
(* Tensor.print_tree ~with_grad:true ~depth:9 g; *)
Tensor.print ~with_code:false ~with_grad:false `Default @@ g;
Tensor.print ~with_code:false ~with_grad:true `Default @@ a;
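Train.sync_run is the other wrapper these call sites drop; judging only from how the call sites change in this commit, it combined the run with host transfers for the tensor's embedded nodes and a stream await. A hedged sketch of the replacement call, with the old behavior summarized in a comment (reconstructed from call sites, not from the wrapper's definition):

(* Previously: Train.sync_run (module Backend) routine d
   which, as used here, behaved roughly like running the routine and then
   copying d's embedded hosted nodes back before awaiting the stream. *)
Train.run routine;
(* The prints that follow read d's value and gradients on demand instead. *)
Tensor.print_tree ~with_grad:true ~depth:9 d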
5 changes: 0 additions & 5 deletions bin/micrograd_demo.ml
@@ -102,9 +102,6 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
for batch = 0 to n_batches - 1 do
batch_ref := batch;
Train.run routine;
- assert (Backend.to_host routine.context learning_rate.value);
- assert (Backend.to_host routine.context scalar_loss.value);
- Backend.await stream;
(* Stdio.printf "Data batch=%d, step=%d, lr=%f, batch loss=%f\n%!" !batch_ref !step_ref
learning_rate.@[0] scalar_loss.@[0]; *)
learning_rates := learning_rate.@[0] :: !learning_rates;
@@ -141,8 +138,6 @@
needed. *)
assert (Backend.from_host result_routine.context point.value);
Train.run result_routine;
- assert (Backend.to_host result_routine.context mlp_result.value);
- Backend.await stream;
Float.(mlp_result.@[0] >= 0.)
in
let%track3_sexp _plotting : unit =
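In the training loop above, the hosted nodes that are actually inspected are read through the .@[0] accessor, and each read flushes that node's own pending transfer as registered in backends.ml, so nodes that are never read need not be copied. A hedged sketch of the simplified loop body; the epoch-loss accumulator is illustrative:

for batch = 0 to n_batches - 1 do
  batch_ref := batch;
  Train.run routine;
  (* Each read below triggers the per-node device-to-host copy on demand. *)
  learning_rates := learning_rate.@[0] :: !learning_rates;
  epoch_loss := !epoch_loss +. scalar_loss.@[0]
done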
5 changes: 0 additions & 5 deletions bin/moons_demo.ml
@@ -91,9 +91,6 @@ let demo () =
batch_ref := batch;
Utils.capture_stdout_logs @@ fun () ->
Train.run routine;
- assert (Backend.to_host routine.context learning_rate.value);
- assert (Backend.to_host routine.context scalar_loss.value);
- Backend.await stream;
epoch_loss := !epoch_loss +. scalar_loss.@[0];
Int.incr step_ref
done;
@@ -117,8 +114,6 @@ let demo () =
Utils.capture_stdout_logs @@ fun () ->
assert (Backend.from_host result_routine.context point.value);
Train.run result_routine;
- assert (Backend.to_host result_routine.context mlp_result.value);
- Backend.await stream;
Float.(mlp_result.@[0] >= 0.)
in

35 changes: 13 additions & 22 deletions bin/zero2hero_1of7.ml
@@ -25,7 +25,7 @@ let _suspended () =
Train.every_non_literal_on_host v;
let code = Train.grad_update v in
let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
- Train.sync_run (module Backend) routine v;
+ Train.run routine;
Stdio.printf "\n%!";
Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
Stdlib.Format.printf "\nHigh-level code:\n%!";
@@ -47,7 +47,7 @@ let _suspended () =
Tensor.print_tree ~with_grad:false ~depth:9 f5;
Stdio.printf "\n%!"

- let () =
+ let _suspended () =
(* FIXME: why is this toplevel example broken and the next one working? *)
Utils.settings.output_debug_files_in_build_directory <- true;
Rand.init 0;
@@ -75,14 +75,12 @@ let () =
let step_ref = IDX.find_exn routine.bindings step_sym in
let ys = Array.create ~len:size 0. and dys = Array.create ~len:size 0. in
let open Operation.At in
- let looping () =
- assert (Backend.to_host routine.context fx.value);
- assert (Backend.to_host routine.context (Option.value_exn ~here:[%here] x.diff).grad);
- Backend.await stream;
+ let f () =
+ Train.run routine;
ys.(!step_ref) <- fx.@[0];
dys.(!step_ref) <- x.@%[0]
in
- Train.sync_run ~looping (module Backend) routine fx;
+ Train.sequential_loop routine.bindings ~f;
let plot_box =
let open PrintBox_utils in
plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
@@ -101,13 +99,6 @@ let _suspended () =
(* Utils.settings.debug_log_from_routines <- true; *)
Rand.init 0;
let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
- let backend =
- (module Backend : Backend
- with type buffer_ptr = Backend.buffer_ptr
- and type dev = Backend.dev
- and type runner = Backend.runner
- and type event = Backend.event)
- in
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
let ctx = Backend.make_context stream in
let open Operation.At in
@@ -138,7 +129,7 @@ let _suspended () =
Array.unzip
@@ Array.mapi xs ~f:(fun i _ ->
step_ref := i;
- Train.sync_run backend fx_routine fx;
+ Train.run fx_routine;
(fx.@[0], x.@%[0]))
in
(* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
@@ -155,7 +146,7 @@
in
()

- let _suspended () =
+ let () =
Rand.init 0;
Utils.set_log_level 2;
Utils.settings.output_debug_files_in_build_directory <- true;
@@ -172,8 +163,8 @@ let _suspended () =
in
Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.from_host routine.context a : bool));
Train.run routine;
- Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
- Backend.await stream;
+ (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+ Backend.await stream; *)
Stdio.print_endline
{|
We did not update the params: all values and gradients will be at initial points,
@@ -195,8 +186,8 @@ let _suspended () =
List.iter [ a.value; b.value; c.value; f.value ] ~f:(fun a ->
assert (Backend.from_host routine.context a));
Train.run routine;
- Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
- Backend.await stream;
+ (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+ Backend.await stream; *)
Stdio.print_endline
{|
Now we updated the params, but after the forward and backward passes:
@@ -206,8 +197,8 @@ let _suspended () =
let update = Train.grad_update l in
let routine = Train.to_routine (module Backend) routine.context IDX.empty update.fwd_bprop in
Train.run routine;
- Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
- Backend.await stream;
+ (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
+ Backend.await stream; *)
Stdio.print_endline
{|
Now again we did not update the params, they will remain as above, but both param
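Note the asymmetry that remains after this commit: host-to-device copies (Backend.from_host) are still issued explicitly before running, while the device-to-host direction is handled lazily by the readers, which is why the to_host loops above are commented out rather than replaced by anything. A hedged end-to-end sketch with names from this file and exact signatures assumed:

(* Illustrative only: push inputs to the device, run, then let printing pull
   any hosted results back on demand. *)
Tensor.iter_embedded l ~f:(fun a ->
    ignore (Backend.from_host routine.context a : bool));
Train.run routine;
(* Printing the tree reads hosted nodes through the accessors, which flush
   pending device-to-host copies as needed. *)
Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 l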
(Diffs for the remaining changed files are not shown in this view.)