Skip to content

Commit

Permalink
Upgrade to printbox 0.12 and migrate plotting to printbox-ext-plot
Browse files Browse the repository at this point in the history
  • Loading branch information
lukstafi committed Jan 3, 2025
1 parent 8689dec commit 53951bb
Show file tree
Hide file tree
Showing 14 changed files with 412 additions and 496 deletions.
5 changes: 3 additions & 2 deletions arrayjit/lib/ndarray.ml
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label
let nlines = if brief then size1 else size1 + 1 in
let ncols = if brief then size2 else size2 + 1 in
let outer_grid v =
(if brief then Fn.id else B.frame)
(if brief then Fn.id else B.frame ~stretch:false)
@@ B.init_grid ~bars:true ~line:nlines ~col:ncols (fun ~line ~col ->
if (not brief) && line = 0 && col = 0 then
B.lines @@ List.filter ~f:(Fn.non String.is_empty) @@ [ tag ~pos:v label0 ind0 ]
Expand All @@ -586,7 +586,8 @@ let render_array ?(brief = false) ?(prefix = "") ?(entries_per_axis = 4) ?(label
B.init_grid ~bars:true ~line:size0 ~col:1 (fun ~line ~col:_ ->
if elide_for line ~ind:ind0 then B.hpad 1 @@ B.line "..." else outer_grid line)
in
(if brief then Fn.id else B.frame) @@ B.vlist ~bars:false [ B.text header; screens ]
(if brief then Fn.id else B.frame ~stretch:false)
@@ B.vlist ~bars:false [ B.text header; screens ]

let pp_array fmt ?prefix ?entries_per_axis ?labels ~indices arr =
PrintBox_text.pp fmt @@ render_array ?prefix ?entries_per_axis ?labels ~indices arr
Expand Down
5 changes: 2 additions & 3 deletions bin/compilation_speed.ml
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,8 @@ let benchmark_overhead backend () =
f.@[0])
in
let plot_box =
let open PrintBox_utils in
plot ~size:(40, 25) ~x_label:"x" ~y_label:"f(x)"
[ Scatterplot { points = Array.zip_exn xs ys; pixel = "#" } ]
PrintBox_utils.plot ~small:true ~x_label:"x" ~y_label:"f(x)"
[ Scatterplot { points = Array.zip_exn xs ys; content = PrintBox.line "#" } ]
in
let final_time = Time_now.nanoseconds_since_unix_epoch () in
let time_in_sec = Int63.(to_float @@ (final_time - init_time)) /. 1000_000_000. in
Expand Down
25 changes: 11 additions & 14 deletions bin/micrograd_demo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -139,38 +139,35 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
in
let%track3_sexp _plotting : unit =
let plot_moons =
let open PrintBox_utils in
plot ~size:(120, 40) ~x_label:"ixes" ~y_label:"ygreks"
PrintBox_utils.plot ~as_canvas:true
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "%" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "%" };
Boundary_map
{ content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
]
in
Stdio.printf "Half-moons scatterplot and decision boundary:\n%!";
PrintBox_text.output Stdio.stdout plot_moons
in
Stdio.printf "Loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"loss"
[ Line_plot { points = Array.of_list_rev !losses; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"loss"
[ Line_plot { points = Array.of_list_rev !losses; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_loss;

Stdio.printf "Log-loss, for better visibility:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"log loss"
[ Line_plot { points = Array.of_list_rev !log_losses; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"log loss"
[ Line_plot { points = Array.of_list_rev !log_losses; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_loss;

Stdio.printf "\nLearning rate:\n%!";
let plot_lr =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev !learning_rates; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev !learning_rates; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_lr

Expand Down
38 changes: 20 additions & 18 deletions bin/moons_benchmark.ml
Original file line number Diff line number Diff line change
Expand Up @@ -126,47 +126,49 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_streams ~batch_size ~b
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
let%track3_sexp plot_moons () =
(* [%log_level 0; *)
let open PrintBox_utils in
plot
~size:(120, 40)
(* TINY for debugging: *)
(* ~size:(20, 10) *)
~x_label:"ixes" ~y_label:"ygreks"
PrintBox_utils.plot
(* TINY for debugging: *)
(* ~small:true *)
~as_canvas:true
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "%" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "%" };
Boundary_map
{ content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
]
(* ] *)
in
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
PrintBox_text.output Stdio.stdout @@ plot_moons ();
Stdio.printf "\nBatch Log-loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch log loss"
PrintBox_utils.plot ~x_label:"step" ~y_label:"batch log loss"
[
Line_plot
{
points =
Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
pixel = "-";
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nEpoch Log-loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch log loss"
[ Line_plot { points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch log loss"
[
Line_plot
{
points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log;
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nLearning rate:\n%!";
let plot_lr =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev learning_rates; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev learning_rates; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_lr;
let final_time = Time_now.nanoseconds_since_unix_epoch () in
Expand Down
16 changes: 8 additions & 8 deletions bin/moons_demo.ml
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,10 @@ let demo () =
let classes = Tn.points_1d ~xdim:0 moons_classes.value in
let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
let plot_moons =
let open PrintBox_utils in
plot ~size:(120, 40) ~x_label:"ixes" ~y_label:"ygreks"
PrintBox_utils.plot ~as_canvas:true
[
Scatterplot { points = points1; pixel = "#" }; Scatterplot { points = points2; pixel = "+" };
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "+" };
]
in
Stdio.printf "\nHalf-moons scatterplot:\n%!";
Expand Down Expand Up @@ -116,12 +116,12 @@ let demo () =

let%track_sexp _plotting : unit =
let plot_moons =
let open PrintBox_utils in
plot ~size:(120, 40) ~x_label:"ixes" ~y_label:"ygreks"
PrintBox_utils.plot ~as_canvas:true
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "+" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "+" };
Boundary_map
{ content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
]
in
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
Expand Down
41 changes: 21 additions & 20 deletions bin/moons_demo_parallel.ml
Original file line number Diff line number Diff line change
Expand Up @@ -78,56 +78,57 @@ let experiment ~seed ~backend_name ~config () =
Stdio.printf "\n********\nUsed memory: %d\n%!" used_memory;
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
let plot_moons =
let open PrintBox_utils in
plot ~size:(120, 40) ~x_label:"ixes" ~y_label:"ygreks"
PrintBox_utils.plot ~as_canvas:true
[
Scatterplot { points = points1; pixel = "#" };
Scatterplot { points = points2; pixel = "%" };
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
Scatterplot { points = points1; content = PrintBox.line "#" };
Scatterplot { points = points2; content = PrintBox.line "%" };
Boundary_map
{ content_false = PrintBox.line "."; content_true = PrintBox.line "*"; callback };
]
in
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
PrintBox_text.output Stdio.stdout plot_moons;
Stdio.printf "\nBatch Loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch loss"
[ Line_plot { points = Array.of_list_rev rev_batch_losses; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"batch loss"
[ Line_plot { points = Array.of_list_rev rev_batch_losses; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nEpoch Loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch loss"
[ Line_plot { points = Array.of_list_rev rev_epoch_losses; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch loss"
[ Line_plot { points = Array.of_list_rev rev_epoch_losses; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nBatch Log-loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch log loss"
PrintBox_utils.plot ~x_label:"step" ~y_label:"batch log loss"
[
Line_plot
{
points =
Array.of_list_rev_map rev_batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
pixel = "-";
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nEpoch Log-loss:\n%!";
let plot_loss =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch log loss"
[ Line_plot { points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch log loss"
[
Line_plot
{
points = Array.of_list_rev_map rev_epoch_losses ~f:Float.log;
content = PrintBox.line "-";
};
]
in
PrintBox_text.output Stdio.stdout plot_loss;
Stdio.printf "\nLearning rate:\n%!";
let plot_lr =
let open PrintBox_utils in
plot ~size:(120, 30) ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev learning_rates; pixel = "-" } ]
PrintBox_utils.plot ~x_label:"step" ~y_label:"learning rate"
[ Line_plot { points = Array.of_list_rev learning_rates; content = PrintBox.line "-" } ]
in
PrintBox_text.output Stdio.stdout plot_lr

Expand Down
18 changes: 8 additions & 10 deletions bin/zero2hero_1of7.ml
Original file line number Diff line number Diff line change
Expand Up @@ -82,12 +82,11 @@ let _suspended () =
in
Train.sequential_loop routine.bindings ~f;
let plot_box =
let open PrintBox_utils in
plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
PrintBox_utils.plot ~x_label:"x" ~y_label:"f(x)"
[
Scatterplot { points = Array.zip_exn values ys; pixel = "#" };
Scatterplot { points = Array.zip_exn values dys; pixel = "*" };
Line_plot { points = Array.create ~len:20 0.; pixel = "-" };
Scatterplot { points = Array.zip_exn values ys; content = PrintBox.line "#" };
Scatterplot { points = Array.zip_exn values dys; content = PrintBox.line "*" };
Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" };
]
in
PrintBox_text.output Stdio.stdout plot_box;
Expand Down Expand Up @@ -134,12 +133,11 @@ let _suspended () =
in
(* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
let plot_box =
let open PrintBox_utils in
plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
PrintBox_utils.plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
[
Scatterplot { points = Array.zip_exn xs ys; pixel = "#" };
Scatterplot { points = Array.zip_exn xs dys; pixel = "*" };
Line_plot { points = Array.create ~len:20 0.; pixel = "-" };
Scatterplot { points = Array.zip_exn xs ys; content = PrintBox.line "#" };
Scatterplot { points = Array.zip_exn xs dys; content = PrintBox.line "*" };
Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" };
]
in
PrintBox_text.output Stdio.stdout plot_box
Expand Down
Loading

0 comments on commit 53951bb

Please sign in to comment.