diff --git a/internal/server/proto.go b/internal/server/proto.go index 8848d34d03989a8bf09c44c24ae0c62c78e7988d..e12de149904db415b4d9fffb35d5b32ed4408af9 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -138,6 +138,34 @@ func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWrite } } +func (c *controllerV1) handleGetInstanceAgentSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + se, err := ins.App.Sessions.Get(r.Context(), sid) + if err != nil { + c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to get session") + return + } + + var isSessionBusy bool + if ins.App.CoderAgent != nil { + isSessionBusy = ins.App.CoderAgent.IsSessionBusy(sid) + } + + jsonEncode(w, proto.AgentSession{ + Session: se, + IsBusy: isSessionBusy, + }) +} + func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) { id := r.PathValue("id") ins, ok := c.instances.Get(id) @@ -147,6 +175,8 @@ func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Re return } + w.Header().Set("Accept", "application/json") + var msg proto.AgentMessage if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { c.logError(r, "failed to decode request", "error", err) @@ -160,7 +190,10 @@ func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Re return } - if _, err := ins.App.CoderAgent.Run(r.Context(), msg.SessionID, msg.Prompt, msg.Attachments...); err != nil { + // NOTE: This needs to be on the server's context because the agent runs + // the request asynchronously. + // TODO: Look into this one more and make it work synchronously. + if _, err := ins.App.CoderAgent.Run(c.ctx, msg.SessionID, msg.Prompt, msg.Attachments...); err != nil { c.logError(r, "failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID) jsonError(w, http.StatusInternalServerError, "failed to enqueue message") return @@ -270,7 +303,7 @@ func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.R sid := r.PathValue("sid") session, err := ins.App.Sessions.Get(r.Context(), sid) if err != nil { - c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid) + c.logError(r, "failedto get session", "error", err, "id", id, "sid", sid) jsonError(w, http.StatusInternalServerError, "failed to get session") return } @@ -323,6 +356,68 @@ func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http. jsonEncode(w, sessions) } +func (c *controllerV1) handlePostInstancePermissionsGrant(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var req proto.PermissionGrant + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + switch req.Action { + case proto.PermissionAllow: + ins.App.Permissions.Grant(req.Permission) + case proto.PermissionAllowForSession: + ins.App.Permissions.GrantPersistent(req.Permission) + case proto.PermissionDeny: + ins.App.Permissions.Deny(req.Permission) + default: + c.logError(r, "invalid permission action", "action", req.Action) + jsonError(w, http.StatusBadRequest, "invalid permission action") + return + } +} + +func (c *controllerV1) handlePostInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var req proto.PermissionSkipRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + c.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + ins.App.Permissions.SetSkipRequests(req.Skip) +} + +func (c *controllerV1) handleGetInstancePermissionsSkip(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + skip := ins.App.Permissions.SkipRequests() + jsonEncode(w, proto.PermissionSkipRequest{Skip: skip}) +} + func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) { flusher := http.NewResponseController(w) id := r.PathValue("id") @@ -340,8 +435,10 @@ func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Re for { select { case <-r.Context().Done(): + c.logDebug(r, "stopping event stream") return case ev := <-ins.App.Events(): + c.logDebug(r, "sending event", "event", fmt.Sprintf("%T %+v", ev, ev)) data, err := json.Marshal(ev) if err != nil { c.logError(r, "failed to marshal event", "error", err) @@ -354,6 +451,18 @@ func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Re } } +func (c *controllerV1) handleGetInstanceConfig(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + jsonEncode(w, ins.cfg) +} + func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { var ids []string id := r.URL.Query().Get("id") @@ -377,6 +486,10 @@ func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Requ }()...) for _, id := range ids { + ins, ok := c.instances.Get(id) + if ok { + ins.App.Shutdown() + } c.instances.Del(id) } } @@ -389,7 +502,6 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques return } - ctx := r.Context() hasher := sha256.New() hasher.Write([]byte(filepath.Clean(args.Path))) id := hex.EncodeToString(hasher.Sum(nil)) @@ -417,21 +529,21 @@ func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Reques } cfg.Permissions.SkipRequests = args.YOLO - if err := createDotCrushDir(args.DataDir); err != nil { + if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { c.logError(r, "failed to create data directory", "error", err) jsonError(w, http.StatusInternalServerError, "failed to create data directory") return } // Connect to DB; this will also run migrations. - conn, err := db.Connect(ctx, args.DataDir) + conn, err := db.Connect(c.ctx, cfg.Options.DataDirectory) if err != nil { c.logError(r, "failed to connect to database", "error", err) jsonError(w, http.StatusInternalServerError, "failed to connect to database") return } - appInstance, err := app.New(ctx, conn, cfg) + appInstance, err := app.New(c.ctx, conn, cfg) if err != nil { slog.Error("failed to create app instance", "error", err) jsonError(w, http.StatusInternalServerError, "failed to create app instance") diff --git a/internal/server/server.go b/internal/server/server.go index d4188aaaa0f99f419f84e86f5b521697c508e1d4..fe497768ebe03247c5c55931f06f6225d0b9cdee 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -72,8 +72,9 @@ type Server struct { // Addr can be a TCP address, a Unix socket path, or a Windows named pipe. Addr string - h *http.Server - ln net.Listener + h *http.Server + ln net.Listener + ctx context.Context // instances is a map of running applications managed by the server. instances *csync.Map[string, *Instance] @@ -106,6 +107,7 @@ func NewServer(cfg *config.Config, network, address string) *Server { s.Addr = address s.cfg = cfg s.instances = csync.NewMap[string, *Instance]() + s.ctx = context.Background() var p http.Protocols p.SetHTTP1(true) @@ -116,6 +118,7 @@ func NewServer(cfg *config.Config, network, address string) *Server { mux.HandleFunc("GET /v1/instances", c.handleGetInstances) mux.HandleFunc("POST /v1/instances", c.handlePostInstances) mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances) + mux.HandleFunc("GET /v1/instances/{id}/config", c.handleGetInstanceConfig) mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents) mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions) mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions) @@ -124,10 +127,14 @@ func NewServer(cfg *config.Config, network, address string) *Server { mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages) mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs) mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics) + mux.HandleFunc("GET /v1/instances/{id}/permissions/skip", c.handleGetInstancePermissionsSkip) + mux.HandleFunc("POST /v1/instances/{id}/permissions/skip", c.handlePostInstancePermissionsSkip) + mux.HandleFunc("POST /v1/instances/{id}/permissions/grant", c.handlePostInstancePermissionsGrant) mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent) mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent) mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit) mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate) + mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}", c.handleGetInstanceAgentSession) mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel) mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued) mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear) @@ -177,6 +184,16 @@ func (s *Server) Shutdown(ctx context.Context) error { return s.h.Shutdown(ctx) } +func (s *Server) logDebug(r *http.Request, msg string, args ...any) { + if s.logger != nil { + s.logger.With( + slog.String("method", r.Method), + slog.String("url", r.URL.String()), + slog.String("remote_addr", r.RemoteAddr), + ).Debug(msg, args...) + } +} + func (s *Server) logError(r *http.Request, msg string, args ...any) { if s.logger != nil { s.logger.With(