diff --git a/README.md b/README.md index 7fa5e445..f6152f77 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ # Polygon Go Client -![Coverage](https://img.shields.io/badge/Coverage-76.7%25-brightgreen) +![Coverage](https://img.shields.io/badge/Coverage-77.4%25-brightgreen) diff --git a/websocket/config.go b/websocket/config.go index 14c48b7c..a8faf499 100644 --- a/websocket/config.go +++ b/websocket/config.go @@ -30,6 +30,12 @@ type Config struct { // If this flag is `true`, it's up to the caller to handle all message types including auth and subscription responses. BypassRawDataRouting bool + // ReconnectCallback is a callback that is triggered on automatic reconnects by the websocket client. + // This can be useful for implementing additional logic around reconnect paths e.g. logging, metrics + // or managing the connection. The callback function takes as input an error type which will be non-nil + // if the reconnect attempt has failed and is being retried, and will be nil on reconnect success. + ReconnectCallback func(error) + // Log is an optional logger. Any logger implementation can be used as long as it // implements the basic Logger interface. Omitting this will disable client logging. Log Logger diff --git a/websocket/polygon.go b/websocket/polygon.go index 405c2276..56f92088 100644 --- a/websocket/polygon.go +++ b/websocket/polygon.go @@ -48,7 +48,8 @@ type Client struct { output chan any err chan error - log Logger + reconnectCallback func(error) + log Logger } // New creates a client for the Polygon WebSocket API. @@ -70,6 +71,7 @@ func New(config Config) (*Client, error) { output: make(chan any, 100000), err: make(chan error), log: config.Log, + reconnectCallback: config.ReconnectCallback, } uri, err := url.Parse(string(c.feed)) @@ -246,6 +248,9 @@ func (c *Client) reconnect() { notify := func(err error, _ time.Duration) { c.log.Errorf(err.Error()) + if c.reconnectCallback != nil { + c.reconnectCallback(err) + } } err := backoff.RetryNotify(c.connect(true), c.backoff, notify) if err != nil { @@ -253,6 +258,11 @@ func (c *Client) reconnect() { c.log.Errorf(err.Error()) c.close(false) c.err <- err + } else { + // Callback on success. + if c.reconnectCallback != nil { + c.reconnectCallback(nil) + } } } diff --git a/websocket/polygon_test.go b/websocket/polygon_test.go index fdfbd547..03ba1167 100644 --- a/websocket/polygon_test.go +++ b/websocket/polygon_test.go @@ -152,3 +152,32 @@ func TestConnectRetryFailure(t *testing.T) { assert.NotNil(t, err) c.Close() } + +func TestReconnectCallback(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(connect)) + defer s.Close() + + reconnectCallbackCount := 0 + log := logrus.New() + log.SetLevel(logrus.DebugLevel) + u := "ws" + strings.TrimPrefix(s.URL, "http") + var retries uint64 = 0 + c, err := New(Config{ + APIKey: "good", + Feed: Feed(u), + Market: Market(""), + Log: log, + MaxRetries: &retries, + ReconnectCallback: func(err error) { + assert.Nil(t, err) + reconnectCallbackCount++ + }, + }) + assert.NotNil(t, c) + assert.Nil(t, err) + err = c.Connect() + assert.Nil(t, err) + c.reconnect() + c.Close() + assert.Equal(t, 1, reconnectCallbackCount) +}