diff --git a/internal/server/proto.go b/internal/server/proto.go index b732f866dca6fe694257df060a349d9d1f5cf1a2..8848d34d03989a8bf09c44c24ae0c62c78e7988d 100644 --- a/internal/server/proto.go +++ b/internal/server/proto.go @@ -13,16 +13,22 @@ import ( "github.com/charmbracelet/crush/internal/app" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/proto" + "github.com/charmbracelet/crush/internal/session" ) -func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) { - jsonEncode(w, s.cfg) +type controllerV1 struct { + *Server } -func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) { +func (c *controllerV1) handleGetConfig(w http.ResponseWriter, r *http.Request) { + jsonEncode(w, c.cfg) +} + +func (c *controllerV1) handleGetInstances(w http.ResponseWriter, r *http.Request) { instances := []proto.Instance{} - for _, ins := range s.instances.Seq2() { + for _, ins := range c.instances.Seq2() { instances = append(instances, proto.Instance{ ID: ins.ID(), Path: ins.Path(), @@ -32,12 +38,297 @@ func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) { jsonEncode(w, instances) } -func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) { +func (c *controllerV1) handleGetInstanceLSPDiagnostics(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var lsp *lsp.Client + lspName := r.PathValue("lsp") + for name, client := range ins.LSPClients.Seq2() { + if name == lspName { + lsp = client + break + } + } + + if lsp == nil { + c.logError(r, "LSP client not found", "id", id, "lsp", lspName) + jsonError(w, http.StatusNotFound, "LSP client not found") + return + } + + diagnostics := lsp.GetDiagnostics() + jsonEncode(w, diagnostics) +} + +func (c *controllerV1) handleGetInstanceLSPs(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + lspClients := ins.GetLSPStates() + jsonEncode(w, lspClients) +} + +func (c *controllerV1) handleGetInstanceAgentSessionPromptQueued(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + queued := ins.App.CoderAgent.QueuedPrompts(sid) + jsonEncode(w, queued) +} + +func (c *controllerV1) handlePostInstanceAgentSessionPromptClear(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + ins.App.CoderAgent.ClearQueue(sid) +} + +func (c *controllerV1) handleGetInstanceAgentSessionSummarize(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + if err := ins.App.CoderAgent.Summarize(r.Context(), sid); err != nil { + c.logError(r, "failed to summarize session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to summarize session") + return + } +} + +func (c *controllerV1) handlePostInstanceAgentSessionCancel(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + if ins.App.CoderAgent != nil { + ins.App.CoderAgent.Cancel(sid) + } +} + +func (c *controllerV1) handlePostInstanceAgent(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var msg proto.AgentMessage + if err := json.NewDecoder(r.Body).Decode(&msg); err != nil { + c.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + if ins.App.CoderAgent == nil { + c.logError(r, "coder agent not initialized", "id", id) + jsonError(w, http.StatusBadRequest, "coder agent not initialized") + return + } + + if _, err := ins.App.CoderAgent.Run(r.Context(), msg.SessionID, msg.Prompt, msg.Attachments...); err != nil { + c.logError(r, "failed to enqueue message", "error", err, "id", id, "sid", msg.SessionID) + jsonError(w, http.StatusInternalServerError, "failed to enqueue message") + return + } +} + +func (c *controllerV1) handleGetInstanceAgent(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var agentInfo proto.AgentInfo + if ins.App.CoderAgent != nil { + agentInfo = proto.AgentInfo{ + Model: ins.App.CoderAgent.Model(), + IsBusy: ins.App.CoderAgent.IsBusy(), + } + } + jsonEncode(w, agentInfo) +} + +func (c *controllerV1) handlePostInstanceAgentUpdate(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + if err := ins.App.UpdateAgentModel(); err != nil { + c.logError(r, "failed to update agent model", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to update agent model") + return + } +} + +func (c *controllerV1) handlePostInstanceAgentInit(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + if err := ins.App.InitCoderAgent(); err != nil { + c.logError(r, "failed to initialize coder agent", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to initialize coder agent") + return + } +} + +func (c *controllerV1) handleGetInstanceSessionHistory(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + historyItems, err := ins.App.History.ListBySession(r.Context(), sid) + if err != nil { + c.logError(r, "failed to list history", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to list history") + return + } + + jsonEncode(w, historyItems) +} + +func (c *controllerV1) handleGetInstanceSessionMessages(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + messages, err := ins.App.Messages.List(r.Context(), sid) + if err != nil { + c.logError(r, "failed to list messages", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to list messages") + return + } + + jsonEncode(w, messages) +} + +func (c *controllerV1) handleGetInstanceSession(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sid := r.PathValue("sid") + session, err := ins.App.Sessions.Get(r.Context(), sid) + if err != nil { + c.logError(r, "failed to get session", "error", err, "id", id, "sid", sid) + jsonError(w, http.StatusInternalServerError, "failed to get session") + return + } + + jsonEncode(w, session) +} + +func (c *controllerV1) handlePostInstanceSessions(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + var args session.Session + if err := json.NewDecoder(r.Body).Decode(&args); err != nil { + c.logError(r, "failed to decode request", "error", err) + jsonError(w, http.StatusBadRequest, "failed to decode request") + return + } + + sess, err := ins.App.Sessions.Create(r.Context(), args.Title) + if err != nil { + c.logError(r, "failed to create session", "error", err, "id", id) + jsonError(w, http.StatusInternalServerError, "failed to create session") + return + } + + jsonEncode(w, sess) +} + +func (c *controllerV1) handleGetInstanceSessions(w http.ResponseWriter, r *http.Request) { + id := r.PathValue("id") + ins, ok := c.instances.Get(id) + if !ok { + c.logError(r, "instance not found", "id", id) + jsonError(w, http.StatusNotFound, "instance not found") + return + } + + sessions, err := ins.App.Sessions.List(r.Context()) + if err != nil { + c.logError(r, "failed to list sessions", "error", err) + jsonError(w, http.StatusInternalServerError, "failed to list sessions") + return + } + + jsonEncode(w, sessions) +} + +func (c *controllerV1) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) { flusher := http.NewResponseController(w) id := r.PathValue("id") - ins, ok := s.instances.Get(id) + ins, ok := c.instances.Get(id) if !ok { - s.logError(r, "instance not found", "id", id) + c.logError(r, "instance not found", "id", id) jsonError(w, http.StatusNotFound, "instance not found") return } @@ -53,7 +344,7 @@ func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) case ev := <-ins.App.Events(): data, err := json.Marshal(ev) if err != nil { - s.logError(r, "failed to marshal event", "error", err) + c.logError(r, "failed to marshal event", "error", err) continue } @@ -63,7 +354,7 @@ func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) } } -func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { +func (c *controllerV1) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { var ids []string id := r.URL.Query().Get("id") if id != "" { @@ -73,7 +364,7 @@ func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { // Get IDs from body var args []proto.Instance if err := json.NewDecoder(r.Body).Decode(&args); err != nil { - s.logError(r, "failed to decode request", "error", err) + c.logError(r, "failed to decode request", "error", err) jsonError(w, http.StatusBadRequest, "failed to decode request") return } @@ -86,14 +377,14 @@ func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) { }()...) for _, id := range ids { - s.instances.Del(id) + c.instances.Del(id) } } -func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) { +func (c *controllerV1) handlePostInstances(w http.ResponseWriter, r *http.Request) { var args proto.Instance if err := json.NewDecoder(r.Body).Decode(&args); err != nil { - s.logError(r, "failed to decode request", "error", err) + c.logError(r, "failed to decode request", "error", err) jsonError(w, http.StatusBadRequest, "failed to decode request") return } @@ -102,18 +393,21 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) { hasher := sha256.New() hasher.Write([]byte(filepath.Clean(args.Path))) id := hex.EncodeToString(hasher.Sum(nil)) - if existing, ok := s.instances.Get(id); ok { + if existing, ok := c.instances.Get(id); ok { jsonEncode(w, proto.Instance{ ID: existing.ID(), Path: existing.Path(), - YOLO: existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests, + // TODO: Investigate if this makes sense. + YOLO: existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests, + Debug: existing.cfg.Options.Debug, + DataDir: existing.cfg.Options.DataDirectory, }) return } - cfg, err := config.Init(args.Path, s.cfg.Options.DataDirectory, s.cfg.Options.Debug) + cfg, err := config.Init(args.Path, args.DataDir, args.Debug) if err != nil { - s.logError(r, "failed to initialize config", "error", err) + c.logError(r, "failed to initialize config", "error", err) jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err)) return } @@ -123,16 +417,16 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) { } cfg.Permissions.SkipRequests = args.YOLO - if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil { - s.logError(r, "failed to create data directory", "error", err) + if err := createDotCrushDir(args.DataDir); err != nil { + c.logError(r, "failed to create data directory", "error", err) jsonError(w, http.StatusInternalServerError, "failed to create data directory") return } // Connect to DB; this will also run migrations. - conn, err := db.Connect(ctx, cfg.Options.DataDirectory) + conn, err := db.Connect(ctx, args.DataDir) if err != nil { - s.logError(r, "failed to connect to database", "error", err) + c.logError(r, "failed to connect to database", "error", err) jsonError(w, http.StatusInternalServerError, "failed to connect to database") return } @@ -152,7 +446,7 @@ func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) { cfg: cfg, } - s.instances.Set(id, ins) + c.instances.Set(id, ins) jsonEncode(w, proto.Instance{ ID: id, Path: args.Path, diff --git a/internal/server/server.go b/internal/server/server.go index 7a7aeb468abe2878aad64de54423c3bc88265223..d4188aaaa0f99f419f84e86f5b521697c508e1d4 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -110,12 +110,28 @@ func NewServer(cfg *config.Config, network, address string) *Server { var p http.Protocols p.SetHTTP1(true) p.SetUnencryptedHTTP2(true) + c := &controllerV1{Server: s} mux := http.NewServeMux() - mux.HandleFunc("GET /v1/config", s.handleGetConfig) - mux.HandleFunc("GET /v1/instances", s.handleGetInstances) - mux.HandleFunc("POST /v1/instances", s.handlePostInstances) - mux.HandleFunc("DELETE /v1/instances", s.handleDeleteInstances) - mux.HandleFunc("GET /v1/instances/{id}/events", s.handleGetInstanceEvents) + mux.HandleFunc("GET /v1/config", c.handleGetConfig) + mux.HandleFunc("GET /v1/instances", c.handleGetInstances) + mux.HandleFunc("POST /v1/instances", c.handlePostInstances) + mux.HandleFunc("DELETE /v1/instances", c.handleDeleteInstances) + mux.HandleFunc("GET /v1/instances/{id}/events", c.handleGetInstanceEvents) + mux.HandleFunc("GET /v1/instances/{id}/sessions", c.handleGetInstanceSessions) + mux.HandleFunc("POST /v1/instances/{id}/sessions", c.handlePostInstanceSessions) + mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}", c.handleGetInstanceSession) + mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/history", c.handleGetInstanceSessionHistory) + mux.HandleFunc("GET /v1/instances/{id}/sessions/{sid}/messages", c.handleGetInstanceSessionMessages) + mux.HandleFunc("GET /v1/instances/{id}/lsps", c.handleGetInstanceLSPs) + mux.HandleFunc("GET /v1/instances/{id}/lsps/{lsp}/diagnostics", c.handleGetInstanceLSPDiagnostics) + mux.HandleFunc("GET /v1/instances/{id}/agent", c.handleGetInstanceAgent) + mux.HandleFunc("POST /v1/instances/{id}/agent", c.handlePostInstanceAgent) + mux.HandleFunc("POST /v1/instances/{id}/agent/init", c.handlePostInstanceAgentInit) + mux.HandleFunc("POST /v1/instances/{id}/agent/update", c.handlePostInstanceAgentUpdate) + mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/cancel", c.handlePostInstanceAgentSessionCancel) + mux.HandleFunc("GET /v1/instances/{id}/agent/sessions/{sid}/prompts/queued", c.handleGetInstanceAgentSessionPromptQueued) + mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/prompts/clear", c.handlePostInstanceAgentSessionPromptClear) + mux.HandleFunc("POST /v1/instances/{id}/agent/sessions/{sid}/summarize", c.handleGetInstanceAgentSessionSummarize) s.h = &http.Server{ Protocols: &p, Handler: s.loggingHandler(mux),