feat(server): add endpoints for managing permissions and fetching config

Ayman Bagabas created

Change summary

internal/server/proto.go  | 124 +++++++++++++++++++++++++++++++++++++++-
internal/server/server.go |  21 ++++++
2 files changed, 137 insertions(+), 8 deletions(-)

Detailed changes

internal/server/proto.go 🔗

@@ -138,6 +138,34 @@ func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWrite
 	}
 }
 
+func (c *controllerV1) handleGetInstanceAgentSession(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	sid := r.PathValue("sid")
+	se, err := ins.App.Sessions.Get(r.Context(), sid)
+	if err != nil {
+		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
+		jsonError(w, http.StatusInternalServerError, "failed to get session")
+		return
+	}
+
+	var isSessionBusy bool
+	if ins.App.CoderAgent != nil {
+		isSessionBusy = ins.App.CoderAgent.IsSessionBusy(sid)
+	}
+
+	jsonEncode(w, proto.AgentSession{
+		Session: se,
+		IsBusy:  isSessionBusy,
+	})
+}
+
 func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) {
 	id := r.PathValue("id")
 	ins, ok := c.instances.Get(id)
@@ -147,6 +175,8 @@ func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Re
 		return
 	}
 
+	w.Header().Set("Accept", "application/json")
+
 	var msg proto.AgentMessage
 	if err := json.NewDecoder(r.Body).Decode(&msg); err != nil {
 		c.logError(r, "failed to decode request", "error", err)
@@ -160,7 +190,10 @@ func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Re
 		return
 	}
 
-	if _, err := ins.App.CoderAgent.Run(r.Context(), msg.SessionID, msg.Prompt, msg.Attachments...); err != nil {
+	// NOTE: This needs to be on the server's context because the agent runs
+	// the request asynchronously.
+	// TODO: Look into this one more and make it work synchronously.
+	if _, err := ins.App.CoderAgent.Run(c.ctx, msg.SessionID, msg.Prompt, msg.Attachments...); err != nil {
 		c.logError(r, "failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID)
 		jsonError(w, http.StatusInternalServerError, "failed to enqueue message")
 		return
@@ -270,7 +303,7 @@ func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.R
 	sid := r.PathValue("sid")
 	session, err := ins.App.Sessions.Get(r.Context(), sid)
 	if err != nil {
-		c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid)
+		c.logError(r, "failedto get session", "error", err, "id", id, "sid", sid)
 		jsonError(w, http.StatusInternalServerError, "failed to get session")
 		return
 	}
@@ -323,6 +356,68 @@ func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http.
 	jsonEncode(w, sessions)
 }
 
+func (c *controllerV1) handlePostInstancePermissionsGrant(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var req proto.PermissionGrant
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		c.logError(r, "failed to decode request", "error", err)
+		jsonError(w, http.StatusBadRequest, "failed to decode request")
+		return
+	}
+
+	switch req.Action {
+	case proto.PermissionAllow:
+		ins.App.Permissions.Grant(req.Permission)
+	case proto.PermissionAllowForSession:
+		ins.App.Permissions.GrantPersistent(req.Permission)
+	case proto.PermissionDeny:
+		ins.App.Permissions.Deny(req.Permission)
+	default:
+		c.logError(r, "invalid permission action", "action", req.Action)
+		jsonError(w, http.StatusBadRequest, "invalid permission action")
+		return
+	}
+}
+
+func (c *controllerV1) handlePostInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	var req proto.PermissionSkipRequest
+	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
+		c.logError(r, "failed to decode request", "error", err)
+		jsonError(w, http.StatusBadRequest, "failed to decode request")
+		return
+	}
+
+	ins.App.Permissions.SetSkipRequests(req.Skip)
+}
+
+func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	skip := ins.App.Permissions.SkipRequests()
+	jsonEncode(w, proto.PermissionSkipRequest{Skip: skip})
+}
+
 func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
 	flusher := http.NewResponseController(w)
 	id := r.PathValue("id")
@@ -340,8 +435,10 @@ func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Re
 	for {
 		select {
 		case <-r.Context().Done():
+			c.logDebug(r, "stopping event stream")
 			return
 		case ev := <-ins.App.Events():
+			c.logDebug(r, "sending event", "event", fmt.Sprintf("%T %+v", ev, ev))
 			data, err := json.Marshal(ev)
 			if err != nil {
 				c.logError(r, "failed to marshal event", "error", err)
@@ -354,6 +451,18 @@ func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Re
 	}
 }
 
+func (c *controllerV1) handleGetInstanceConfig(w http.ResponseWriter, r *http.Request) {
+	id := r.PathValue("id")
+	ins, ok := c.instances.Get(id)
+	if !ok {
+		c.logError(r, "instance not found", "id", id)
+		jsonError(w, http.StatusNotFound, "instance not found")
+		return
+	}
+
+	jsonEncode(w, ins.cfg)
+}
+
 func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
 	var ids []string
 	id := r.URL.Query().Get("id")
@@ -377,6 +486,10 @@ func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Requ
 	}()...)
 
 	for _, id := range ids {
+		ins, ok := c.instances.Get(id)
+		if ok {
+			ins.App.Shutdown()
+		}
 		c.instances.Del(id)
 	}
 }
@@ -389,7 +502,6 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
 		return
 	}
 
-	ctx := r.Context()
 	hasher := sha256.New()
 	hasher.Write([]byte(filepath.Clean(args.Path)))
 	id := hex.EncodeToString(hasher.Sum(nil))
@@ -417,21 +529,21 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques
 	}
 	cfg.Permissions.SkipRequests = args.YOLO
 
-	if err := createDotCrushDir(args.DataDir); err != nil {
+	if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
 		c.logError(r, "failed to create data directory", "error", err)
 		jsonError(w, http.StatusInternalServerError, "failed to create data directory")
 		return
 	}
 
 	// Connect to DB; this will also run migrations.
-	conn, err := db.Connect(ctx, args.DataDir)
+	conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory)
 	if err != nil {
 		c.logError(r, "failed to connect to database", "error", err)
 		jsonError(w, http.StatusInternalServerError, "failed to connect to database")
 		return
 	}
 
-	appInstance, err := app.New(ctx, conn, cfg)
+	appInstance, err := app.New(c.ctx, conn, cfg)
 	if err != nil {
 		slog.Error("failed to create app instance", "error", err)
 		jsonError(w, http.StatusInternalServerError, "failed to create app instance")

internal/server/server.go 🔗

@@ -72,8 +72,9 @@ type Server struct {
 	// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
 	Addr string
 
-	h  *http.Server
-	ln net.Listener
+	h   *http.Server
+	ln  net.Listener
+	ctx context.Context
 
 	// instances is a map of running applications managed by the server.
 	instances *csync.Map[string, *Instance]
@@ -106,6 +107,7 @@ func NewServer(cfg *config.Config, network, address string) *Server {
 	s.Addr = address
 	s.cfg = cfg
 	s.instances = csync.NewMap[string, *Instance]()
+	s.ctx = context.Background()
 
 	var p http.Protocols
 	p.SetHTTP1(true)
@@ -116,6 +118,7 @@ func NewServer(cfg *config.Config, network, address string) *Server {
 	mux.HandleFunc("GET /v1/instances", c.handleGetInstances)
 	mux.HandleFunc("POST /v1/instances", c.handlePostInstances)
 	mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances)
+	mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig)
 	mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents)
 	mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions)
 	mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions)
@@ -124,10 +127,14 @@ func NewServer(cfg *config.Config, network, address string) *Server {
 	mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages)
 	mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs)
 	mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics)
+	mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip)
+	mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip)
+	mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant)
 	mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent)
 	mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent)
 	mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit)
 	mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate)
+	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession)
 	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel)
 	mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued)
 	mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear)
@@ -177,6 +184,16 @@ func (s *Server) Shutdown(ctx context.Context) error {
 	return s.h.Shutdown(ctx)
 }
 
+func (s *Server) logDebug(r *http.Request, msg string, args ...any) {
+	if s.logger != nil {
+		s.logger.With(
+			slog.String("method", r.Method),
+			slog.String("url", r.URL.String()),
+			slog.String("remote_addr", r.RemoteAddr),
+		).Debug(msg, args...)
+	}
+}
+
 func (s *Server) logError(r *http.Request, msg string, args ...any) {
 	if s.logger != nil {
 		s.logger.With(