-
Notifications
You must be signed in to change notification settings - Fork 42
/
server.go
303 lines (253 loc) · 8.88 KB
/
server.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
// Copyright 2017-2023 Block, Inc.
package rce
import (
"crypto/tls"
"errors"
"log"
"net"
"github.com/square/rce-agent/cmd"
pb "github.com/square/rce-agent/pb"
"golang.org/x/net/context"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials"
)
var (
// ErrInvalidServerConfigAllowAnyCommand is returned by Server.StartServer() when
// ServerConfig.AllowAnyCommand is true but ServerConfig.AllowedCommands is non-nil.
ErrInvalidServerConfigAllowAnyCommand = errors.New("invalid ServerConfig: AllowAnyCommand is true but AllowedCommands is non-nil")
// ErrInvalidServerConfigDisableSecurity is returned by Server.StartServer()
// when ServerConfig.AllowAnyCommand is true and ServerConfig.TLS is nil but
// ServerConfig.DisableSecurity is false.
ErrInvalidServerConfigDisableSecurity = errors.New("invalid ServerConfig: AllowAnyCommand enabled but TLS is nil")
// ErrCommandNotAllowed is safeguard error returned by the internal gRPC server when
// ServerConfig.AllowedCommands is nil and ServerConfig.AllowAnyCommand is false.
// This should not happen because these values are validated in Server.StartServer()
// before starting the internal gRPC server. If this error occurs, there is a bug
// in ServerConfig validation code.
ErrCommandNotAllowed = errors.New("command not allowed")
)
// A Server executes a whitelist of commands when called by clients.
type Server interface {
// Start the gRPC server, non-blocking.
StartServer() error
// Stop the gRPC server gracefully.
StopServer() error
pb.RCEAgentServer
}
// ServerConfig configures a Server.
type ServerConfig struct {
// ----------------------------------------------------------------------
// Required values
// Addr is the required host:post listen address.
Addr string
// AllowedCommands is the list of commands the server is allowed to run.
// By default, no commands are allowed; commands must be explicitly allowed.
AllowedCommands cmd.Runnable
// ----------------------------------------------------------------------
// Optional values
// AllowAnyCommand allows any commands if AllowedCommands is nil.
// This is not recommended. If true, TLS must be specified (non-nil);
// or, to enable AllowAnyCommand without TLS, DisableSecurity must be true.
AllowAnyCommand bool
// DisableSecurity allows AllowAnyCommand without TLS: an insecure server that
// can execute any command from any client.
//
// This option should not be used.
DisableSecurity bool
// TLS specifies the TLS configuration for secure and verified communication.
// Use TLSFiles.TLSConfig() to load TLS files and configure for server and
// client verification.
TLS *tls.Config
}
func NewServerWithConfig(cfg ServerConfig) Server {
// Set log flags here so other pkgs can't override in their init().
log.SetFlags(log.Ldate | log.Lmicroseconds | log.Lshortfile | log.LUTC)
s := &server{
cfg: cfg,
// --
repo: cmd.NewRepo(),
}
// Create a gRPC server and register this agent a implementing the
// RCEAgentServer interface and protocol
var grpcServer *grpc.Server
if cfg.TLS != nil {
opt := grpc.Creds(credentials.NewTLS(cfg.TLS))
grpcServer = grpc.NewServer(opt)
} else {
grpcServer = grpc.NewServer()
}
s.grpcServer = grpcServer
return s
}
// Internal implementation of pb.RCEAgentServer interface.
type server struct {
cfg ServerConfig
// --
repo cmd.Repo // running commands
grpcServer *grpc.Server // gRPC server instance of this agent
}
// NewServer makes a new Server that listens on laddr and runs the whitelist
// of commands. If tlsConfig is nil, the sever is insecure.
func NewServer(laddr string, tlsConfig *tls.Config, whitelist cmd.Runnable) Server {
return NewServerWithConfig(ServerConfig{
Addr: laddr,
AllowedCommands: whitelist,
TLS: tlsConfig,
})
}
func (s *server) StartServer() error {
// Validate the combination of server configs that disable security
if s.cfg.AllowAnyCommand {
// To allow any command, there can't be any allow list
if s.cfg.AllowedCommands != nil {
return ErrInvalidServerConfigAllowAnyCommand
} else {
log.Printf("WARNING: all commands are allowed!\n")
}
// And to to allow any command without TLS, the user must explicitly
// disble all security
if s.cfg.TLS == nil && !s.cfg.DisableSecurity {
return ErrInvalidServerConfigDisableSecurity
} else {
log.Printf("WARNING: all security is disabled!\n")
}
}
// Register the RCEAgent service with the gRPC server.
pb.RegisterRCEAgentServer(s.grpcServer, s)
lis, err := net.Listen("tcp", s.cfg.Addr)
if err != nil {
return err
}
go s.grpcServer.Serve(lis)
if s.cfg.TLS != nil {
log.Printf("secure server listening on %s", s.cfg.Addr)
} else {
log.Printf("insecure server listening on %s", s.cfg.Addr)
}
return nil
}
func (s *server) StopServer() error {
s.grpcServer.GracefulStop()
log.Printf("server stopped on %s", s.cfg.Addr)
return nil
}
// //////////////////////////////////////////////////////////////////////////
// pb.RCEAgentServer interface methods
// //////////////////////////////////////////////////////////////////////////
func (s *server) Start(ctx context.Context, c *pb.Command) (*pb.ID, error) {
id := &pb.ID{} // @todo we return this on error, but should be "return nil, <err>"
var rceCmd *cmd.Cmd // from AllowedCommands or an arbitrary if AllowAnyCommand
var path string // for logging below
if s.cfg.AllowedCommands != nil {
spec, err := s.cfg.AllowedCommands.FindByName(c.Name)
if err != nil {
log.Printf("unknown command: %s", c.Name)
return id, grpc.Errorf(codes.InvalidArgument, "unknown command: %s", c.Name)
}
// Append cmd request args to cmd spec args
rceCmd = cmd.NewCmd(spec, append(spec.Args(), c.Arguments...))
path = spec.Path()
} else if s.cfg.AllowAnyCommand {
// Make a spec for this arbitrary command
spec := cmd.Spec{
Name: c.Name, // any command, like "/usr/local/bin/gofmt"
Exec: append([]string{c.Name}, c.Arguments...),
}
rceCmd = cmd.NewCmd(spec, c.Arguments)
path = c.Name
} else {
return id, ErrCommandNotAllowed
}
if err := s.repo.Add(rceCmd); err != nil {
// This should never happen
log.Printf("duplicate command: %+v", rceCmd)
return id, grpc.Errorf(codes.AlreadyExists, "duplicate command: %s", rceCmd.Id)
}
log.Printf("cmd=%s: start: %s path: %s args: %v", rceCmd.Id, c.Name, path, rceCmd.Args)
rceCmd.Cmd.Start()
id.ID = rceCmd.Id
return id, nil
}
func (s *server) Wait(ctx context.Context, id *pb.ID) (*pb.Status, error) {
log.Printf("cmd=%s: wait", id.ID)
defer log.Printf("cmd=%s: wait return", id.ID)
cmd := s.repo.Get(id.ID)
if cmd == nil {
return nil, notFound(id)
}
// Reap the command
defer s.repo.Remove(id.ID)
// Wait for command or ctx to finish
select {
case <-cmd.Cmd.Done():
case <-ctx.Done():
}
// Get final status of command and convert to pb.Status. If ctx was canceled
// and command still running, its status will indicate this and ctx.Err()
// below will return an error, else it will return nil.
return mapStatus(cmd), ctx.Err()
}
func (s *server) GetStatus(ctx context.Context, id *pb.ID) (*pb.Status, error) {
log.Printf("cmd=%s: status", id.ID)
cmd := s.repo.Get(id.ID)
if cmd == nil {
return nil, notFound(id)
}
return mapStatus(cmd), nil
}
func (s *server) Stop(ctx context.Context, id *pb.ID) (*pb.Empty, error) {
log.Printf("cmd=%s: stop", id.ID)
cmd := s.repo.Get(id.ID)
if cmd == nil {
return nil, notFound(id)
}
cmd.Cmd.Stop()
return &pb.Empty{}, nil
}
func (s *server) Running(empty *pb.Empty, stream pb.RCEAgent_RunningServer) error {
log.Println("list running")
for _, id := range s.repo.All() {
if err := stream.Send(&pb.ID{ID: id}); err != nil {
return err
}
}
return nil
}
func notFound(id *pb.ID) error {
return grpc.Errorf(codes.NotFound, "command ID %s not found", id.ID)
}
func mapStatus(cmd *cmd.Cmd) *pb.Status {
cmdStatus := cmd.Cmd.Status()
var errMsg string
if cmdStatus.Error != nil {
errMsg = cmdStatus.Error.Error()
}
// Make a pb.Status struct by adding and mapping some fields
pbStatus := &pb.Status{
ID: cmd.Id, // add
Name: cmd.Name, // add
ExitCode: int64(cmdStatus.Exit), // map
Error: errMsg, // map
PID: int64(cmdStatus.PID), // map
StartTime: cmdStatus.StartTs, // map
StopTime: cmdStatus.StopTs, // map
Args: cmd.Args, // map
Stdout: cmdStatus.Stdout, // same
Stderr: cmdStatus.Stderr, // same
}
// Map go-cmd status to pb state
switch {
case cmdStatus.StartTs == 0 && cmdStatus.StopTs == 0:
pbStatus.State = pb.STATE_PENDING
case cmdStatus.StartTs > 0 && cmdStatus.StopTs == 0:
pbStatus.State = pb.STATE_RUNNING
case cmdStatus.StopTs > 0 && cmdStatus.Exit == 0:
pbStatus.State = pb.STATE_COMPLETE
case cmdStatus.StopTs > 0 && cmdStatus.Exit != 0:
pbStatus.State = pb.STATE_FAIL
default:
pbStatus.State = pb.STATE_UNKNOWN
}
return pbStatus
}