(* zero2hero_1of7.ml *)
open Base
open Ocannl
module Tn = Arrayjit.Tnode
module IDX = Train.IDX
module CDSL = Train.CDSL
module TDSL = Operation.TDSL
module NTDSL = Operation.NTDSL
module Utils = Arrayjit.Utils
module Rand = Arrayjit.Rand.Lib
module Debug_runtime = Utils.Debug_runtime
module type Backend = Arrayjit.Backend_intf.Backend
let _get_local_debug_runtime = Arrayjit.Utils._get_local_debug_runtime
[%%global_debug_log_level 9]
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
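(* This file walks a handful of tiny scalar/vector expressions through OCANNL end to end:
   building tensors with the [%op] syntax, running forward and backward passes via
   [Train.grad_update], slicing a batch axis with a static index, plotting with
   [PrintBox_utils], and taking a single SGD step. The "zero2hero" name presumably refers to
   part 1 of the "Neural Networks: Zero to Hero" material this example mirrors. *)
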
let _suspended () =
  Rand.init 0;
  let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let ctx = Backend.make_context stream in
  let%op v = ("w" [ (-3, 1) ] * "x" [ 2; 0 ]) + "b" [ 6.7 ] in
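  (* A hand check, assuming [*] here contracts the shared axis (a dot product) while [+] adds
     the bias: v = w . x + b = (-3) * 2 + 1 * 0 + 6.7 = 0.7. *)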
  Train.every_non_literal_on_host v;
  let code = Train.grad_update v in
  let routine = Train.to_routine (module Backend) ctx IDX.empty code.fwd_bprop in
  Train.run routine;
  Stdio.printf "\n%!";
  Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
  Stdlib.Format.printf "\nHigh-level code:\n%!";
  Stdlib.Format.printf "%a\n%!" (Arrayjit.Assignments.fprint_hum ()) code.fwd_bprop.asgns

let _suspended () =
  Rand.init 0;
  CDSL.enable_all_debugs ();
  CDSL.virtualize_settings.enable_device_only <- false;
  let%op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
  let%op f5 = f 5 in
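  (* Worked out by hand: f(5) = 3 * 25 - 4 * 5 + 5 = 60, which is the value the printed tree
     below should report for [f5]. *)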
  let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
  Train.every_non_literal_on_host f5;
  Train.forward_and_forget
    (module Backend)
    Backend.(make_context @@ new_stream @@ get_device ~ordinal:0)
    f5;
  Stdio.printf "\n%!";
  Tensor.print_tree ~with_grad:false ~depth:9 f5;
  Stdio.printf "\n%!"

let _suspended () =
  (* FIXME: why is this toplevel example broken and the next one working? *)
  Utils.settings.output_debug_files_in_build_directory <- true;
  Rand.init 0;
  let%op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
  let size = 100 in
  let values = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in
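  (* [values] spans -5.0 to 4.9 in steps of 0.1: one forward/backward run per point below. *)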
  (* Test that the batch axis dimensions will be inferred. *)
  let x_flat =
    Tensor.term ~grad_spec:Tensor.Require_grad
      ~label:[ "x_flat" ] (* ~input_dims:[] ~output_dims:[ 1 ] *)
      ~init_op:(Constant_fill { values; strict = true })
      ()
  in
  let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
  (* The commented-out [let x =] line below is the same as this one, except that [let%op x =]
     uses [~grad_spec:If_needed]. *)
  let%op x = x_flat @| step_sym in
  (* let x = Operation.slice ~label:[ "x" ] ~grad_spec:Require_grad step_sym x_flat in *)
  Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad;
  let%op fx = f x in
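  (* For reference: f'(x) = 6 * x - 4, so the "*" scatter of gradients below should trace a
     straight line crossing zero at x = 2/3, while the "#" scatter traces the parabola f(x). *)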
  let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let ctx = Backend.make_context stream in
  let update = Train.grad_update fx in
  let routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
  let step_ref = IDX.find_exn routine.bindings step_sym in
  let ys = Array.create ~len:size 0. and dys = Array.create ~len:size 0. in
  let open Operation.At in
  let f () =
    Train.run routine;
    ys.(!step_ref) <- fx.@[0];
    dys.(!step_ref) <- x.@%[0]
  in
  Train.sequential_loop routine.bindings ~f;
  let plot_box =
    PrintBox_utils.plot ~x_label:"x" ~y_label:"f(x)"
      [
        Scatterplot { points = Array.zip_exn values ys; content = PrintBox.line "#" };
        Scatterplot { points = Array.zip_exn values dys; content = PrintBox.line "*" };
        Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" };
      ]
  in
  PrintBox_text.output Stdio.stdout plot_box;
  Stdio.print_endline ""

let _suspended () =
  (* Utils.set_log_level 2; *)
  Utils.settings.output_debug_files_in_build_directory <- true;
  (* Utils.settings.debug_log_from_routines <- true; *)
  Rand.init 0;
  let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let ctx = Backend.make_context stream in
  let open Operation.At in
  CDSL.virtualize_settings.enable_device_only <- false;
  let%op f x = (3 *. (x **. 2)) - (4 *. x) + 5 in
  let%op f5 = f 5 in
  Train.every_non_literal_on_host f5;
  Train.forward_and_forget (module Backend) ctx f5;
  Tensor.print_tree ~with_grad:false ~depth:9 f5;
  let size = 100 in
  let xs = Array.init size ~f:Float.(fun i -> (of_int i / 10.) - 5.) in
  (* Yay, the whole shape gets inferred! *)
  let x_flat =
    Tensor.term ~grad_spec:Require_grad ~label:[ "x_flat" ]
      ~init_op:(Constant_fill { values = xs; strict = true })
      ()
  in
  let step_sym, bindings = IDX.get_static_symbol ~static_range:size IDX.empty in
  let%op x = x_flat @| step_sym in
  let%op fx = f x in
  Train.set_hosted x.value;
  Train.set_hosted (Option.value_exn ~here:[%here] x.diff).grad;
  let update = Train.grad_update fx in
  let fx_routine = Train.to_routine (module Backend) ctx bindings update.fwd_bprop in
  let step_ref = IDX.find_exn fx_routine.bindings step_sym in
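  (* [step_ref] is the mutable cell behind the static symbol: setting it before each
     [Train.run] picks which element of the batch axis the slice [x] reads, so the loop below
     evaluates f and its gradient at every point of [xs]. *)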
  let%track_sexp () =
    let ys, dys =
      Array.unzip
      @@ Array.mapi xs ~f:(fun i _ ->
             step_ref := i;
             Train.run fx_routine;
             (fx.@[0], x.@%[0]))
    in
    (* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
    let plot_box =
      PrintBox_utils.plot ~size:(75, 35) ~x_label:"x" ~y_label:"f(x)"
        [
          Scatterplot { points = Array.zip_exn xs ys; content = PrintBox.line "#" };
          Scatterplot { points = Array.zip_exn xs dys; content = PrintBox.line "*" };
          Line_plot { points = Array.create ~len:20 0.; content = PrintBox.line "-" };
        ]
    in
    PrintBox_text.output Stdio.stdout plot_box
  in
  ()

let () =
  Rand.init 0;
  Utils.set_log_level 2;
  Utils.settings.output_debug_files_in_build_directory <- true;
  Utils.settings.debug_log_from_routines <- true;
  let%op e = "a" [ 2 ] *. "b" [ -3 ] in
  let%op d = e + "c" [ 10 ] in
  let%op l = d *. "f" [ -2 ] in
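  (* Working the graph out by hand: e = 2 * -3 = -6, d = e + 10 = 4, l = d * -2 = -8. By the
     chain rule: dl/dd = f = -2, dl/df = d = 4, dl/de = dl/dc = -2, dl/da = dl/de * b = 6,
     dl/db = dl/de * a = -4. These are the values the first printed tree should show. *)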
  Train.every_non_literal_on_host l;
  let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
  let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
  let update = Train.grad_update l in
  let routine =
    Train.to_routine (module Backend) (Backend.make_context stream) IDX.empty update.fwd_bprop
  in
  Train.run routine;
  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
     Backend.await stream; *)
  Stdio.print_endline
    {|
We did not update the params: all values and gradients will be at their initial points,
which are specified in the brackets of the tensor expressions.|};
  Tensor.print_tree ~with_grad:true ~depth:9 l;
  let%op learning_rate = 0.1 in
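  (* A rough sketch of the step below, assuming [Train.sgd_update] performs the plain update
     p <- p - learning_rate * dl/dp with no momentum or weight decay: a would move from 2 to
     1.4, b from -3 to -2.6, c from 10 to 10.2, and f from -2 to -2.4. *)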
  let routine =
    Train.to_routine (module Backend) routine.context IDX.empty
    @@ Train.sgd_update ~learning_rate update
  in
  (* learning_rate is virtual so this will not print anything. *)
  Stdio.print_endline
    {|
Due to how the gccjit backend works, since the parameters were constant in the grad_update
computation, they did not exist on the device before. Now they do. This would not be needed
on the cuda backend.|};
  Train.run routine;
  Stdio.print_endline
    {|
Now we updated the params, but without re-running the forward and backward passes:
compared to the above, only the param values will change.|};
  Tensor.print_tree ~with_grad:true ~depth:9 l;
  (* We could reuse the jitted code if we did not use `jit_and_run`. *)
  let update = Train.grad_update l in
  let routine = Train.to_routine (module Backend) routine.context IDX.empty update.fwd_bprop in
  Train.run routine;
  (* Tensor.iter_embedded l ~f:(fun a -> ignore (Backend.to_host routine.context a : bool));
     Backend.await stream; *)
  Stdio.print_endline
    {|
Now again we did not update the params, so they remain as above, but both the param
gradients and the values and gradients of the other nodes will change thanks to the forward
and backward passes.|};
  Tensor.print_tree ~with_grad:true ~depth:9 l