forked from karpathy/llm.c
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request karpathy#637 from karpathy/feature/outliers
add outlier detector, test for it, and start tracking z score of loss
- Loading branch information
Showing
5 changed files
with
191 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
/* | ||
Tests our OutlierDetector | ||
compile and run as (from dev/test directory) | ||
gcc -O3 -I../../llmc -o test_outlier_detector test_outlier_detector.c -lm && ./test_outlier_detector | ||
*/ | ||
|
||
#include <stdlib.h> | ||
#include "../../llmc/outlier_detector.h" | ||
|
||
int main(void) { | ||
OutlierDetector detector; | ||
init_detector(&detector); | ||
|
||
srand(1337); // init rng | ||
|
||
// generate OUTLIER_DETECTOR_WINDOW_SIZE * 2 random numbers between -1 and 1 | ||
for (int i = 0; i < OUTLIER_DETECTOR_WINDOW_SIZE * 2; i++) { | ||
double val = (double)rand() / RAND_MAX * 2 - 1; // Random number between -1 and 1 | ||
double zscore = update_detector(&detector, val); | ||
|
||
printf("Step %d: Value = %.4f, zscore = %.4f\n", i, val, zscore); | ||
|
||
// check that the first OUTLIER_DETECTOR_WINDOW_SIZE values return nan | ||
if (i < OUTLIER_DETECTOR_WINDOW_SIZE) { | ||
if (!isnan(zscore)) { | ||
printf("Error: Expected nan, got %.4f\n", zscore); | ||
return EXIT_FAILURE; | ||
} | ||
} else { | ||
// check that the zscore is within reasonable bounds | ||
if (zscore < -3.0 || zscore > 3.0) { | ||
printf("Error: Z-score %.4f is outside of expected range\n", zscore); | ||
return EXIT_FAILURE; | ||
} | ||
} | ||
} | ||
|
||
// simulate an outlier | ||
double outlier = 10.0; // <--- loss spike | ||
double zscore = update_detector(&detector, outlier); | ||
printf("Outlier Step: Value = %.4f, zscore = %.4f\n", outlier, zscore); | ||
|
||
// check that the z-score here is large | ||
if (zscore < 5.0) { | ||
printf("Error: Z-score %.4f is not large enough for an outlier\n", zscore); | ||
return EXIT_FAILURE; | ||
} | ||
|
||
printf("OK\n"); | ||
return EXIT_SUCCESS; | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
/* | ||
Simple OutlierDetector that we can use to monitor the loss and grad norm | ||
Internally, it keeps track of a window of measurements and each time we | ||
add a measurement, it returns the z-score of the new value with respect to | ||
the window of measurements. This can be used to detect outliers in the data. | ||
We use double so that the detector doesn't drift too much, because we | ||
update the mean and variance with += on each step for efficiency. We could | ||
reconsider this choice in the future, as the compute cost here is minimal. | ||
*/ | ||
|
||
#include <stdio.h> | ||
#include <math.h> | ||
|
||
// the window size is fixed at compile time so that the detector
// never needs a dynamic memory allocation
#define OUTLIER_DETECTOR_WINDOW_SIZE 128

// State of one detector: the window of measurements plus the running
// totals that the z-score is derived from.
typedef struct {
    double buffer[OUTLIER_DETECTOR_WINDOW_SIZE]; // the stored measurements
    int count;      // measurements stored so far, stops at the window size
    int index;      // slot of the oldest measurement once the window is full
    double sum;     // running sum of the measurements in the window
    double sum_sq;  // running sum of their squares
} OutlierDetector;
|
||
// Puts the detector into its empty state: every buffer slot and both
// running sums are set to 0.0, and count and index are set to 0.
void init_detector(OutlierDetector *detector) {
    *detector = (OutlierDetector){
        .buffer = {0.0},
        .count = 0,
        .index = 0,
        .sum = 0.0,
        .sum_sq = 0.0,
    };
}
|
||
/*
Adds new_value to the detector and returns its z-score with respect to the
window of measurements (the window already contains new_value at that point).
- While fewer than OUTLIER_DETECTOR_WINDOW_SIZE values have been added, the
  value is only recorded and nan is returned.
- Once the window is full, the oldest value is replaced by new_value and
  (new_value - mean) / std_dev is returned, or 0.0 if the window has no spread.
Note: sum and sum_sq are maintained incrementally, so a nan or inf new_value
remains in them even after that value has left the window.
*/
double update_detector(OutlierDetector *detector, double new_value) {

    if (detector->count < OUTLIER_DETECTOR_WINDOW_SIZE) {
        // here we are still building up a window of observations
        detector->buffer[detector->count] = new_value;
        detector->sum += new_value;
        detector->sum_sq += new_value * new_value;
        detector->count++;
        return nan(""); // not enough data yet
    }

    // we've filled the window, so now we can start detecting outliers

    // pop the oldest value from the window
    double old_value = detector->buffer[detector->index];
    detector->sum -= old_value;
    detector->sum_sq -= old_value * old_value;
    // push the new value into the window
    detector->buffer[detector->index] = new_value;
    detector->sum += new_value;
    detector->sum_sq += new_value * new_value;
    // move the index to the next position
    detector->index = (detector->index + 1) % OUTLIER_DETECTOR_WINDOW_SIZE;

    // calculate the z-score of the new value
    double mean = detector->sum / OUTLIER_DETECTOR_WINDOW_SIZE;
    double variance = (detector->sum_sq / OUTLIER_DETECTOR_WINDOW_SIZE) - (mean * mean);
    // E[x^2] - mean^2 can round to a slightly negative number when the values
    // in the window are (nearly) identical. sqrt of that would be nan, so a
    // non-positive variance is treated as "no spread".
    if (variance <= 0.0) {
        return 0.0;
    }
    double std_dev = sqrt(variance);

    return (new_value - mean) / std_dev;
}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters