-
Notifications
You must be signed in to change notification settings - Fork 2
/
tensor.mli
281 lines (244 loc) · 10.2 KB
/
tensor.mli
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
(** {1 Construction of runtime-compiled code supporting backpropagation.} *)
open Base
type tn = Arrayjit.Tnode.t
(** A tensor node; shorthand for {!Arrayjit.Tnode.t}. *)

type tn_set = Set.M(Arrayjit.Tnode).t
(** A [Base] set of tensor nodes, ordered by the comparator of {!Arrayjit.Tnode}. *)

type asgns = Arrayjit.Assignments.t
(** Shorthand for {!Arrayjit.Assignments.t}. *)

type comp = Arrayjit.Assignments.comp
(** Shorthand for {!Arrayjit.Assignments.comp}. This is the type of the [forward] and [backprop]
    code of a tensor. *)

type init_op = Arrayjit.Ops.init_op
(** Shorthand for {!Arrayjit.Ops.init_op}. *)

type fetch_op = Arrayjit.Assignments.fetch_op
(** Shorthand for {!Arrayjit.Assignments.fetch_op}. *)

type projections = Arrayjit.Indexing.projections
(** Shorthand for {!Arrayjit.Indexing.projections}. *)
type diff = {
  grad : tn;  (** The gradient tensor node; for a tensor [t] it is referred to as [t.diff.grad]. *)
  zero_grads : asgns;
      (** Prepares for backpropagation. Always compile as: [Seq (zero_grads, backprop)]. *)
  backprop : comp;
      (** Backpropagates for the tensor and its descendants; which typically means adding partial
          gradients to the gradient tensor of the subtensors, then for sub-subtensors etc. *)
}
(** The gradient-related components of a tensor: the gradient node and the code that zeroes and
    propagates gradients. Stored in the optional [diff] field of a tensor. *)
type t = {
  forward : comp;
      (** The forward code of the tensor. See the [embedded] field of {!subtensor} for where this
          code ends up when the tensor is used inside a bigger computation. *)
  diff : diff option;
      (** The gradient node and backprop code, if any. NOTE(review): presumably [None] for tensors
          that do not take part in backpropagation -- confirm against the implementation. *)
  id : int;  (** Same as [value.id]. *)
  value : tn;  (** The tensor node holding the value of the tensor. *)
  shape : Shape.t;
      (** The eventual shape of [t.value] and [t.diff.grad], incorporating the current state of
          shape inference. *)
  children : subtensor list;
      (** The subtensors of this tensor, each tagged with whether it is embedded here. *)
}
[@@deriving sexp_of]
(** Information needed for compositional code generation. *)

and subtensor = {
  subtensor : t;  (** The child tensor. *)
  embedded : bool;
      (** A tensor can be an [embedded] child at most once -- that's where its [forward] computation
          ends up when used as part of a bigger computation. *)
}
type comparator_witness
(** Abstract witness type tying {!comparator} to the tensor type, as required by [Base]
    comparator-indexed containers (e.g. [Set], [Map]). *)

val comparator : (t, comparator_witness) Base.Comparator.t
(** A [Base] comparator for tensors. NOTE(review): presumably orders tensors by their [id] --
    confirm against the implementation. *)
val is_fwd_root : t -> bool
(** NOTE(review): presumably tests whether the tensor is currently a forward root, i.e. is not
    (currently) used to compute another tensor -- confirm against the implementation. *)

val remove_fwd_root : t -> unit
(** NOTE(review): presumably removes the tensor from the forward roots -- confirm. *)

val is_bprop_root : t -> bool
(** NOTE(review): presumably tests whether the tensor is currently a backprop root, i.e. has a
    gradient that is not (currently) receiving gradients from another tensor -- confirm. *)

val remove_bprop_root : t -> unit
(** NOTE(review): presumably removes the tensor from the backprop roots -- confirm. *)

val with_unchanged_roots : f:(unit -> 'a) -> 'a
(** [with_unchanged_roots ~f] calls [f ()] and returns its result. NOTE(review): judging by the
    name, the forward and backprop roots are restored to their prior state afterwards; whether
    this also happens when [f] raises is not visible from the interface -- confirm. *)
val default_value_prec : Arrayjit.Ops.prec ref
(** The default precision for the value node of terminal (i.e. non-composite) tensors.

    Note: the precision of a node can be set arbitrarily via {!Arrayjit.Tnode.update_prec}. The
    default precision for value nodes of composite tensors is the maximum of precisions of the
    value nodes of sub-tensors. *)

val default_grad_prec : Arrayjit.Ops.prec ref
(** The default precision for the gradient node of terminal (i.e. non-composite) tensors.

    Note: the precision of a node can be set arbitrarily via {!Arrayjit.Tnode.update_prec}. The
    default precision for gradient nodes of composite tensors is the maximum of precisions of the
    gradient nodes of sub-tensors. *)

exception Session_error of string * t option
(** Carries an error message and, optionally, a tensor. NOTE(review): presumably the tensor is the
    one the error concerns, and the exception signals misuse of the global tensor state (e.g. the
    roots) -- confirm against the raise sites. *)

val max_sublabel_length : int ref
(** A mutable, global setting. NOTE(review): presumably bounds the length of subtensor labels
    incorporated into generated labels or debug names -- confirm against the implementation. *)
val raw_binop :
  initialize_neutral:bool ->
  accum:Arrayjit.Ops.binop ->
  t:t ->
  lhs_is_grad:bool ->
  op:Arrayjit.Ops.binop ->
  t1:t ->
  rhs1_is_grad:bool ->
  rhs1_is_merge:bool ->
  t2:t ->
  rhs2_is_grad:bool ->
  rhs2_is_merge:bool ->
  logic:Shape.compose_type ->
  asgns
(** Builds assignments involving three tensors and two binary operations.

    NOTE(review): everything below is inferred from the argument names only -- confirm against
    the implementation. Presumably [t] is the left-hand side, [t1] and [t2] are the right-hand
    side arguments of [op], and the result of [op] is accumulated into the left-hand side using
    [accum]. The [*_is_grad] flags presumably select the gradient node rather than the value node
    of the corresponding tensor. [initialize_neutral] presumably requests initializing the
    left-hand side with the neutral element of [accum]. [logic] is the shape composition logic
    relating the three shapes. The meaning of the [*_is_merge] flags is not visible from here. *)
val raw_unop :
  initialize_neutral:bool ->
  accum:Arrayjit.Ops.binop ->
  t:t ->
  lhs_is_grad:bool ->
  op:Arrayjit.Ops.unop ->
  t1:t ->
  rhs_is_grad:bool ->
  rhs_is_merge:bool ->
  logic:Shape.transpose_type ->
  asgns
(** Builds assignments involving two tensors, a unary operation and an accumulation operation.

    NOTE(review): everything below is inferred from the argument names only -- confirm against
    the implementation. Presumably [t] is the left-hand side, [t1] is the argument of [op], and
    the result of [op] is accumulated into the left-hand side using [accum]. [lhs_is_grad] and
    [rhs_is_grad] presumably select the gradient node rather than the value node of [t] resp.
    [t1]. [initialize_neutral] presumably requests initializing the left-hand side with the
    neutral element of [accum]. [logic] is the shape transposition logic relating the two shapes.
    The meaning of [rhs_is_merge] is not visible from here. *)
type grad_spec = Require_grad | Prohibit_grad | If_needed
(** Specifies whether a tensor should have a gradient. {!number} and {!ndarray} default to
    [Prohibit_grad], while {!param} uses [Require_grad]. NOTE(review): [If_needed] presumably
    means that a gradient is created only when some subtensor has one -- confirm. *)

val is_prohibit_grad : grad_spec -> bool
(** NOTE(review): presumably [true] exactly for [Prohibit_grad] -- confirm. *)
val op :
  label:string list ->
  ?compose_op:Shape.compose_type ->
  ?transpose_op:Shape.transpose_type ->
  ?init_op:init_op ->
  op_asn:(v:tn -> projections:projections Lazy.t -> comp) ->
  grad_asn:(v:tn -> g:tn -> projections:projections Lazy.t -> comp) ->
  ?grad_spec:grad_spec ->
  (debug_name:string -> id:int -> Shape.t) ->
  t list ->
  t
(** The general tensor constructor, taking a list of tensors and two code-building callbacks.

    - [op_asn] receives a tensor node [v] and lazily computed projections, and returns code.
    - [grad_asn] additionally receives a tensor node [g].
    - The unlabeled function argument builds the {!Shape.t} from a debug name and an id.

    NOTE(review): presumably [v] and [g] are the value and gradient nodes of the tensor being
    created, [op_asn] yields its forward code, [grad_asn] yields its backprop code, and the
    [t list] argument gives its subtensors -- confirm against the implementation. *)
val binop :
  label:string list ->
  ?compose_op:Shape.compose_type ->
  op_asn:(v:tn -> t1:t -> t2:t -> projections:projections Lazy.t -> comp) ->
  grad_asn:(v:tn -> g:tn -> t1:t -> t2:t -> projections:projections Lazy.t -> comp) ->
  ?grad_spec:grad_spec ->
  t ->
  t ->
  t
(** Constructor for tensors with two arguments. The callbacks [op_asn] and [grad_asn] receive
    tensors [t1] and [t2] in addition to the tensor nodes and the lazily computed projections.

    NOTE(review): presumably [t1] and [t2] are the two positional tensor arguments, in order, and
    [compose_op] determines how their shapes combine into the shape of the result -- confirm. *)
val unop :
  label:string list ->
  ?transpose_op:Shape.transpose_type ->
  op_asn:(v:tn -> t1:t -> projections:projections Lazy.t -> comp) ->
  grad_asn:(v:tn -> g:tn -> t1:t -> projections:projections Lazy.t -> comp) ->
  ?grad_spec:grad_spec ->
  t ->
  t
(** Constructor for tensors with one argument. The callbacks [op_asn] and [grad_asn] receive a
    tensor [t1] in addition to the tensor nodes and the lazily computed projections.

    NOTE(review): presumably [t1] is the positional tensor argument, and [transpose_op] determines
    how its shape maps to the shape of the result -- confirm. *)
val term :
  label:string list ->
  grad_spec:grad_spec ->
  ?batch_dims:int list ->
  ?input_dims:int list ->
  ?output_dims:int list ->
  ?batch_axes:(string * int) list ->
  ?input_axes:(string * int) list ->
  ?output_axes:(string * int) list ->
  ?deduced:Shape.deduce_within_shape ->
  ?init_op:init_op ->
  ?fetch_op:(v:tn -> fetch_op) ->
  unit ->
  t
(** A terminal: a constant, a parameter, an input of the model. The semantics of the shape
    specification are the same as in {!Shape.make}, and by default the shape will be inferred.

    Unlike in the other constructors, [grad_spec] is mandatory here. The final [unit] argument
    allows the optional arguments to be omitted. NOTE(review): [init_op] and [fetch_op] presumably
    determine how the value node gets populated -- confirm against the implementation. *)
val number : ?label:string list -> ?axis_label:string -> ?grad_spec:grad_spec -> float -> t
(** A number: a tensor with a single axis of one dimension, initialized to the given value.
    [grad_spec] is by default [Prohibit_grad]. NOTE(review): [axis_label] presumably labels that
    single axis -- confirm. *)
val ndarray :
  ?label:string list ->
  ?grad_spec:grad_spec ->
  ?batch_dims:int list ->
  ?input_dims:int list ->
  ?output_dims:int list ->
  ?batch_axes:(string * int) list ->
  ?input_axes:(string * int) list ->
  ?output_axes:(string * int) list ->
  ?strict:bool ->
  float array ->
  t
(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default
    to no axes. [grad_spec] is by default [Prohibit_grad].

    If [strict] is [true] (the default), the given values must fill the tensor's [value] node
    precisely; otherwise, the values will be looped over to populate the [value] node. *)
val param :
  ?more_label:string list ->
  ?input_dims:int list ->
  ?output_dims:int list ->
  ?input_axes:(string * int) list ->
  ?output_axes:(string * int) list ->
  ?deduced:Shape.deduce_within_shape ->
  ?strict:bool ->
  ?values:float array ->
  string ->
  t
(** A tensor with no batch axes; input and output axes are by default inferred. [grad_spec] is set
    to [Require_grad]. The resulting tensor's label is the passed string, appended by [more_label]
    if any.

    NOTE(review): [values] and [strict] presumably play the same role as the values and [strict]
    arguments of {!ndarray} -- confirm against the implementation. *)
val consume_forward_code : t -> comp
(** A forward root is a tensor that is not (currently) used to compute another tensor.

    [consume_forward_code t] ensures [t] is a forward root, removes it from forward roots, and
    checks that there are no other forward roots for tensors with children. Returns code of type
    {!comp}. NOTE(review): presumably the returned code is [t.forward], and a failed check raises
    {!Session_error} -- confirm. *)
val consume_backprop_code : t -> asgns * comp
(** A backprop root is a tensor with a gradient that is not (currently) receiving gradients from
    another tensor. I.e. it is not currently used to compute a tensor with a gradient.

    [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and
    checks that there are no other backprop roots for tensors with children. NOTE(review): the
    returned pair is presumably [(zero_grads, backprop)] of [t.diff], and a failed check presumably
    raises {!Session_error} -- confirm. *)
val iter_embedded : f:(tn -> unit) -> t -> unit
(** [iter_embedded ~f t] applies [f] to all descendant nodes that are embedded, i.e. are members of
    [t.forward.embedded_nodes] or [t.diff.backprop.embedded_nodes] (if any).

    Note: [iter_embedded] should only be called after shape inference finishes. *)
val unsafe_reinitialize : unit -> unit
(** Brings global state to its initialization values. This invalidates any previously defined
    tensors and tensor nodes -- hence [unsafe]: such values must not be used afterwards.

    Also reinitializes the modules: {!Shape}, {!Arrayjit.Tnode},
    {!Arrayjit.Rand.Random_for_tests}. *)
(** {2 Printing.} *)
val header : t -> string
(** Converts the ID, the label and the dimensions of a node to a string. *)

val log_debug_info : from_log_level:int -> t -> unit
(** Logs debug information about the tensor on the default ppx_minidebug runtime.
    NOTE(review): [from_log_level] is presumably the minimal log level at which the information
    gets logged -- confirm. *)
type array_print_style =
  [ `Default
    (** The inner rectangles comprise both an input and an output axis, if available. Similarly,
        the outer rectangle comprises a second-from-end input axis and a second-from-end output
        axis, if available. At least one batch axis is output, when available. The axes that
        couldn't be output are printed at position/dimension [0]. *)
  | `N5_layout of string
    (** The string should provide exclusively non-negative integer pseudo-labels. The numbers
        [0]-[4] represent the priorities of the axes to be printed out, where the priorities
        correspond to, from highest: horizontal, vertical direction of the inner rectangle,
        horizontal, vertical direction of the outer rectangle, repetition (see also
        [Node.pp_print]). The numbers [n >= 5] stand for the actual positions [n - 5] within the
        corresponding axes. *)
  | `Label_layout of (string * int) list
    (** The association from axis labels to integers. The negative numbers [-5] to [-1] represent
        the priorities of the axes to be printed out, where the priorities correspond to, from
        highest: horizontal, vertical direction of the inner rectangle, horizontal, vertical
        direction of the outer rectangle, repetition (as above). The numbers [n >= 0] stand for
        the actual positions within the corresponding axes. Unspecified axes are printed at
        position [0]. *)
  | `Inline
    (** The tensors are printed linearly, in a bracketed manner, optionally prefixed with the
        labels specification. Note that the syntax causes ambiguity for 1-dimensional input axes
        (underscores are used for axes without explicit labels); when there is a 1-dimensional
        input axis, we output the labels specification even if there are no axis labels as a way
        to display the number of axes. The axis nesting is right-to-left (rightmost is innermost).
        The input axes are innermost and the batch axes outermost. The input axes use [,] as a
        separator and [()] as axis delimiters, but the delimiter for the outermost (i.e. leftmost)
        axis is omitted. The output axes use [;] as a separator and [[]] as axis delimiters
        (obligatory). The batch axes use [;] as a separator and [[||]] as axis delimiters
        (obligatory). *) ]
(** The layout used when printing the contents of a tensor. We print out up to 5 axes when
    printing a tensor, as a grid (outer rectangle) of (inner) rectangles, possibly repeated
    (screens). *)
val to_printbox :
  ?single_node:bool ->
  ?entries_per_axis:int ->
  ?with_id:bool ->
  ?spy:bool ->
  ?with_shape:bool ->
  ?with_value:bool ->
  with_grad:bool ->
  depth:int ->
  t ->
  PrintBox.t
(** Renders a tensor as a [PrintBox.t].

    NOTE(review): the following is inferred from the argument names only -- confirm against the
    implementation. Presumably the box is a tree of the tensor and its subtensors down to [depth]
    levels; [with_id], [with_shape], [with_value] and [with_grad] toggle the inclusion of the
    corresponding information; [entries_per_axis] limits how many entries of an axis are shown;
    [single_node] restricts the output to the tensor itself. The meaning of [spy] is not visible
    from here. *)
val print :
  ?spy:bool ->
  with_grad:bool ->
  with_code:bool ->
  ?with_low_level:bool ->
  array_print_style ->
  t ->
  unit
(** Prints a tensor, laying out its contents according to the given {!array_print_style}.

    NOTE(review): inferred from the argument names only -- confirm: [with_grad] presumably
    includes the gradient (when the tensor has one), [with_code] the code of the tensor, and
    [with_low_level] a low-level form of that code. The output destination (presumably standard
    output) and the meaning of [spy] are not visible from here. *)

val print_forward_roots : with_grad:bool -> with_code:bool -> array_print_style -> unit
(** NOTE(review): presumably prints, in the manner of {!print}, every tensor that is currently a
    forward root -- confirm against the implementation. *)
val print_tree :
  ?entries_per_axis:int ->
  ?with_backend_info:bool ->
  ?with_id:bool ->
  ?spy:bool ->
  ?with_shape:bool ->
  ?with_value:bool ->
  with_grad:bool ->
  depth:int ->
  t ->
  unit
(** Prints a tensor together with its subtensors.

    NOTE(review): inferred from the name and the shared arguments only -- confirm: presumably
    prints the box built by {!to_printbox} with the corresponding arguments. The output
    destination and the meaning of [with_backend_info] and [spy] are not visible from here. *)
val debug_name : t -> string
(** A name for the tensor intended for debugging. NOTE(review): presumably the debug name of the
    tensor's [value] node -- confirm against the implementation. *)