diff --git a/mixer/websocket/client.go b/mixer/websocket/client.go index 44b50c34..1a0037f0 100644 --- a/mixer/websocket/client.go +++ b/mixer/websocket/client.go @@ -35,7 +35,7 @@ type Client struct { // NewClient creates a websocket client. func NewClient(rootPath string, cfg tp.PeerConfig, globalLeftPlugin ...tp.Plugin) *Client { - globalLeftPlugin = append(globalLeftPlugin, NewDialPlugin(rootPath)) + globalLeftPlugin = append([]tp.Plugin{NewDialPlugin(rootPath)}, globalLeftPlugin...) peer := tp.NewPeer(cfg, globalLeftPlugin...) return &Client{ Peer: peer, diff --git a/mixer/websocket/server.go b/mixer/websocket/server.go index f19e2440..b12399ac 100644 --- a/mixer/websocket/server.go +++ b/mixer/websocket/server.go @@ -183,6 +183,7 @@ func (w *serverHandler) handler(conn *ws.Conn) { sess, err := w.peer.ServeConn(conn, w.protoFunc) if err != nil { tp.Errorf("serverHandler: %v", err) + return } <-sess.CloseNotify() } diff --git a/mixer/websocket/websocket_test.go b/mixer/websocket/websocket_test.go index 7fa7832c..205fe12f 100644 --- a/mixer/websocket/websocket_test.go +++ b/mixer/websocket/websocket_test.go @@ -9,6 +9,7 @@ import ( tp "github.com/henrylee2cn/teleport" ws "github.com/henrylee2cn/teleport/mixer/websocket" "github.com/henrylee2cn/teleport/mixer/websocket/jsonSubProto" + "github.com/henrylee2cn/teleport/plugin/auth" ) type Arg struct { @@ -98,3 +99,67 @@ func TestCustomizedWebsocket(t *testing.T) { t.Logf("10/2=%d", result) time.Sleep(time.Second) } + +func TestJSONWebsocketAuth(t *testing.T) { + srv := ws.NewServer( + "/", + tp.PeerConfig{ListenPort: 9090}, + authChecker, + ) + srv.RouteCall(new(P)) + go srv.ListenAndServe() + + time.Sleep(time.Second * 1) + + cli := ws.NewClient( + "/", + tp.PeerConfig{}, + authBearer, + ) + sess, err := cli.Dial(":9090") + if err != nil { + t.Fatal(err) + } + var result int + rerr := sess.Call("/p/divide", &Arg{ + A: 10, + B: 2, + }, &result, + ).Rerror() + if rerr != nil { + t.Fatal(rerr) + } + t.Logf("10/2=%d", result) + time.Sleep(time.Second) +} + +const clientAuthInfo = "client-auth-info-12345" + +var authBearer = auth.NewBearerPlugin( + func(sess auth.Session, fn auth.SendOnce) (rerr *tp.Rerror) { + var ret string + rerr = fn(clientAuthInfo, &ret) + if rerr.HasError() { + return + } + tp.Infof("auth info: %s, result: %s", clientAuthInfo, ret) + return + }, + tp.WithBodyCodec('s'), +) + +var authChecker = auth.NewCheckerPlugin( + func(sess auth.Session, fn auth.RecvOnce) (ret interface{}, rerr *tp.Rerror) { + var authInfo string + rerr = fn(&authInfo) + if rerr.HasError() { + return + } + tp.Infof("auth info: %v", authInfo) + if clientAuthInfo != authInfo { + return nil, tp.NewRerror(403, "auth fail", "auth fail detail") + } + return "pass", nil + }, + tp.WithBodyCodec('s'), +)