Rename base unit as Variable
rounakdatta committed Apr 10, 2024
1 parent 57a5787 commit 5fa5e2f
Showing 4 changed files with 32 additions and 29 deletions.
2 changes: 1 addition & 1 deletion lib/dune
@@ -1,4 +1,4 @@
 (library
  (name smolgrad)
  (public_name smolgrad)
- (modules neuron))
+ (modules variable))
2 changes: 1 addition & 1 deletion lib/neuron.ml → lib/variable.ml
@@ -1,4 +1,4 @@
-module Neuron = struct
+module Variable = struct
   type t = {
     data : float;
     mutable grad : float;
…
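
Not part of the commit, but for orientation: a plausible completion of the record this diff truncates, inferred from the accessors and `create` signature declared in lib/variable.mli below. The `op` and `dependencies` field names are guesses, since the rest of the file is collapsed here.

(* Sketch only: fields beyond data and grad are assumptions based on variable.mli. *)
type t = {
  data : float;          (* the scalar value held by this node *)
  mutable grad : float;  (* accumulated gradient, mutated during backpropagation *)
  op : string;           (* the operation that produced this node, e.g. "+" *)
  dependencies : t list; (* parent nodes in the computation graph *)
}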
9 changes: 6 additions & 3 deletions lib/neuron.mli → lib/variable.mli
@@ -1,16 +1,19 @@
-module Neuron : sig
+(* this is essentially the core variable which is differentiable at heart *)
+(* these form the base data structure of weights and biases of a neuron *)
+module Variable : sig
   type t
 
   (* getter for the data *)
   val data : t -> float
 
-  (* getter for the gradient or weight *)
+  (* getter for the gradient *)
   val grad : t -> float
 
   (* getter for the dependencies of the node *)
   val dependencies : t -> t list
 
-  (* Constructor; constructs a unit neuron of a value and an operator. *)
+  (* Constructor; constructs a unit variable with a value, *)
+  (* and optionally the creation operation and the dependencies. *)
   val create : ?op:string -> ?deps:t list -> float -> t
 
   (* Handles the gradient flows in addition operation. *)
…
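
A minimal usage sketch of the renamed interface (not from the commit), assuming only the signatures shown above and that `create` stores the float it is given:

(* Sketch: build Variables by hand and read them back through the accessors. *)
open Smolgrad.Variable

let () =
  let x = Variable.create 3.0 in
  let y = Variable.create ~op:"+" ~deps:[ x ] 5.0 in
  (* data returns the stored value; dependencies returns the deps list *)
  Printf.printf "data = %.1f\n" (Variable.data y);
  Printf.printf "deps = %d\n" (List.length (Variable.dependencies y))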
48 changes: 24 additions & 24 deletions test/smolgrad_tests.ml
@@ -1,68 +1,68 @@
-open Smolgrad.Neuron
+open Smolgrad.Variable
 
 let test_simple_operation () =
-  let a = Neuron.create 4.0 in
-  let b = Neuron.create 2.0 in
+  let a = Variable.create 4.0 in
+  let b = Variable.create 2.0 in
 
-  let abba = Neuron.(a + b) in
+  let abba = Variable.(a + b) in
   Alcotest.(check (float 0.0))
     "Nodes add up correctly"
     6.0
-    (Neuron.data abba);
+    (Variable.data abba);
 ;;
 
-(* here we open the Neuron module wide open locally, thereby allowing the clean custom `+` operator usage *)
+(* here we open the Variable module wide open locally, thereby allowing the clean custom `+` operator usage *)
 (* we'll avoid this pattern elsewhere in the tests *)
 let test_custom_operator () =
-  let open Neuron in
+  let open Variable in
   let a = create 4.0 in
   let b = create 2.0 in
 
   let abba = a + b in
   Alcotest.(check (float 0.0))
     "Nodes add up correctly with custom operator"
     6.0
-    (Neuron.data abba);
+    (Variable.data abba);
 ;;
 
 let test_graph_construction () =
-  let a = Neuron.create 4.0 in
-  let b = Neuron.create 2.0 in
+  let a = Variable.create 4.0 in
+  let b = Variable.create 2.0 in
 
-  let c = Neuron.(a * b + b ** 3.0) in
-  (* c essentially depends on (a * b) as the first node and b ** 3.0 as the second node *)
+  let c = Variable.(a * b + b ** 3.0) in
+  (* c essentially depends on (a * b) as the first node and (b ** 3.0) as the second node *)
   (* interesting how BODMAS rules automagically apply here for the operators *)
 
-  let d = Neuron.(c + a) in
-  (* d essentially depends on c and a; simple *)
+  let d = Variable.(c + a) in
+  (* d essentially depends on c and a - simple and obvious *)
 
   Alcotest.(check (list (float 0.0)))
     "Dependency graph is constructed correctly for c"
-    (List.map (fun x -> Neuron.data x) [Neuron.(a * b); Neuron.(b ** 3.0)])
-    (List.map (fun x -> Neuron.data x) (Neuron.dependencies c));
+    (List.map (fun x -> Variable.data x) [Variable.(a * b); Variable.(b ** 3.0)])
+    (List.map (fun x -> Variable.data x) (Variable.dependencies c));
 
   Alcotest.(check (list (float 0.0)))
     "Dependency graph is constructed correctly for d"
-    (List.map (fun x -> Neuron.data x) [c; a])
-    (List.map (fun x -> Neuron.data x) (Neuron.dependencies d));
+    (List.map (fun x -> Variable.data x) [c; a])
+    (List.map (fun x -> Variable.data x) (Variable.dependencies d));
 ;;
 
 let test_backpropagation () =
-  let a = Neuron.create (-4.0) in
-  let b = Neuron.create 2.0 in
-  let c = Neuron.(a * b + b ** 3.0) in
+  let a = Variable.create (-4.0) in
+  let b = Variable.create 2.0 in
+  let c = Variable.(a * b + b ** 3.0) in
 
-  Neuron.backpropagate c;
+  Variable.backpropagate c;
 
   Alcotest.(check (float 0.0))
     "Backpropagation yields correct gradient for a for a complex graph"
     2.0
-    (Neuron.grad a);
+    (Variable.grad a);
 
   Alcotest.(check (float 0.0))
     "Backpropagation yields correct gradient for b a complex graph"
     8.0
-    (Neuron.grad b);
+    (Variable.grad b);
 ;;
 
 let () =
…
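
As a sanity check on the expected values in test_backpropagation (not part of the commit): for c = a * b + b ** 3.0, the partial derivatives are dc/da = b and dc/db = a + 3 * b^2, so at a = -4.0, b = 2.0 they come to 2.0 and 8.0, matching the assertions above. A tiny standalone check:

(* Hand-computed gradients for c = a * b + b^3 at a = -4.0, b = 2.0. *)
let () =
  let a = -4.0 and b = 2.0 in
  let dc_da = b in                        (* dc/da = b *)
  let dc_db = a +. (3.0 *. (b ** 2.0)) in (* dc/db = a + 3 * b^2 *)
  (* prints dc/da = 2.0, dc/db = 8.0 *)
  Printf.printf "dc/da = %.1f, dc/db = %.1f\n" dc_da dc_db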
