Detailed changes
@@ -24,7 +24,6 @@ require (
github.com/charmbracelet/x/exp/golden v0.0.0-20250207160936-21c02780d27a
github.com/disintegration/imageorient v0.0.0-20180920195336-8147d86e83ec
github.com/google/uuid v1.6.0
- github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1
github.com/invopop/jsonschema v0.13.0
github.com/joho/godotenv v1.5.1
github.com/mark3labs/mcp-go v0.40.0
@@ -48,12 +47,6 @@ require (
mvdan.cc/sh/v3 v3.12.1-0.20250902163504-3cf4fd5717a5
)
-require (
- github.com/hashicorp/errwrap v1.0.0 // indirect
- github.com/hashicorp/go-msgpack/v2 v2.1.3 // indirect
- github.com/hashicorp/go-multierror v1.1.1 // indirect
-)
-
require (
cloud.google.com/go v0.116.0 // indirect
cloud.google.com/go/auth v0.13.0 // indirect
@@ -162,14 +162,6 @@ github.com/gorilla/css v1.0.1/go.mod h1:BvnYkspnSzMmwRK+b8/xgNPLiIuNZr6vbZBTPQ2A
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
-github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
-github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
-github.com/hashicorp/go-msgpack/v2 v2.1.3 h1:cB1w4Zrk0O3jQBTcFMKqYQWRFfsSQ/TYKNyUUVyCP2c=
-github.com/hashicorp/go-msgpack/v2 v2.1.3/go.mod h1:SjlwKKFnwBXvxD/I1bEcfJIBbEJ+MCUn39TxymNR5ZU=
-github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=
-github.com/hashicorp/go-multierror v1.1.1/go.mod h1:iw975J/qwKPdAO1clOe2L8331t/9/fmwbPZ6JB6eMoM=
-github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1 h1:Y1sd8ZCCUUlUetCk+3MCpOwdWd+WicHdk2zk2yUM0qw=
-github.com/hashicorp/net-rpc-msgpackrpc/v2 v2.0.1/go.mod h1:wASEfI5dofjm9S9Jp3JM4pfoBZy8Z07JUE2wHNi0zuc=
github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM=
github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg=
github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8=
@@ -94,6 +94,11 @@ func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
return app, nil
}
+// Events returns the application's event channel.
+func (app *App) Events() <-chan tea.Msg {
+ return app.events
+}
+
// Config returns the application configuration.
func (app *App) Config() *config.Config {
return app.config
@@ -18,13 +18,22 @@ const (
LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed"
)
+func (e LSPEventType) MarshalText() ([]byte, error) {
+ return []byte(e), nil
+}
+
+func (e *LSPEventType) UnmarshalText(data []byte) error {
+ *e = LSPEventType(data)
+ return nil
+}
+
// LSPEvent represents an event in the LSP system
type LSPEvent struct {
- Type LSPEventType
- Name string
- State lsp.ServerState
- Error error
- DiagnosticCount int
+ Type LSPEventType `json:"type"`
+ Name string `json:"name"`
+ State lsp.ServerState `json:"state"`
+ Error error `json:"error,omitempty"`
+ DiagnosticCount int `json:"diagnostic_count,omitempty"`
}
// LSPClientInfo holds information about an LSP client's state
@@ -1,17 +1,19 @@
package client
import (
- "net/rpc"
+ "context"
+ "encoding/json"
+ "net"
+ "net/http"
+ "time"
"github.com/charmbracelet/crush/internal/config"
- "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/server"
- msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
)
// Client represents an RPC client connected to a Crush server.
type Client struct {
- rpc *rpc.Client
+ h *http.Client
}
// DefaultClient creates a new [Client] connected to the default server address.
@@ -22,20 +24,34 @@ func DefaultClient() (*Client, error) {
// NewClient creates a new [Client] connected to the server at the given
// network and address.
func NewClient(network, address string) (*Client, error) {
- rpc, err := msgpackrpc.Dial(network, address)
- if err != nil {
- return nil, err
+ var p http.Protocols
+ p.SetHTTP1(true)
+ p.SetUnencryptedHTTP2(true)
+ tr := http.DefaultTransport.(*http.Transport).Clone()
+ tr.Protocols = &p
+ tr.DialContext = func(ctx context.Context, _, _ string) (net.Conn, error) {
+ d := net.Dialer{
+ Timeout: 30 * time.Second,
+ KeepAlive: 30 * time.Second,
+ }
+ return d.DialContext(ctx, network, address)
+ }
+ h := &http.Client{
+ Transport: tr,
}
- return &Client{rpc: rpc}, nil
+ return &Client{h: h}, nil
}
// GetConfig retrieves the server's configuration via RPC.
func (c *Client) GetConfig() (*config.Config, error) {
var cfg config.Config
- var args proto.Args
- err := c.rpc.Call("ServerProto.GetConfig", &args, &cfg)
+ rsp, err := c.h.Get("http://localhost/v1/config")
if err != nil {
return nil, err
}
+ defer rsp.Body.Close()
+ if err := json.NewDecoder(rsp.Body).Decode(&cfg); err != nil {
+ return nil, err
+ }
return &cfg, nil
}
@@ -16,13 +16,13 @@ const (
)
type File struct {
- ID string
- SessionID string
- Path string
- Content string
- Version int64
- CreatedAt int64
- UpdatedAt int64
+ ID string `json:"id"`
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+ Content string `json:"content"`
+ Version int64 `json:"version"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
}
type Service interface {
@@ -39,15 +39,24 @@ const (
AgentEventTypeSummarize AgentEventType = "summarize"
)
+func (t AgentEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+func (t *AgentEventType) UnmarshalText(text []byte) error {
+ *t = AgentEventType(text)
+ return nil
+}
+
type AgentEvent struct {
- Type AgentEventType
- Message message.Message
- Error error
+ Type AgentEventType `json:"type"`
+ Message message.Message `json:"message"`
+ Error error `json:"error,omitempty"`
// When summarizing
- SessionID string
- Progress string
- Done bool
+ SessionID string `json:"session_id,omitempty"`
+ Progress string `json:"progress,omitempty"`
+ Done bool `json:"done,omitempty"`
}
type Service interface {
@@ -34,6 +34,26 @@ const (
MCPStateError
)
+func (s MCPState) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+func (s *MCPState) UnmarshalText(data []byte) error {
+ switch string(data) {
+ case "disabled":
+ *s = MCPStateDisabled
+ case "starting":
+ *s = MCPStateStarting
+ case "connected":
+ *s = MCPStateConnected
+ case "error":
+ *s = MCPStateError
+ default:
+ return fmt.Errorf("unknown mcp state: %s", data)
+ }
+ return nil
+}
+
func (s MCPState) String() string {
switch s {
case MCPStateDisabled:
@@ -56,13 +76,22 @@ const (
MCPEventStateChanged MCPEventType = "state_changed"
)
+func (t MCPEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+func (t *MCPEventType) UnmarshalText(data []byte) error {
+ *t = MCPEventType(data)
+ return nil
+}
+
// MCPEvent represents an event in the MCP system
type MCPEvent struct {
- Type MCPEventType
- Name string
- State MCPState
- Error error
- ToolCount int
+ Type MCPEventType `json:"type"`
+ Name string `json:"name"`
+ State MCPState `json:"state"`
+ Error error `json:"error,omitempty"`
+ ToolCount int `json:"tool_count,omitempty"`
}
// MCPClientInfo holds information about an MCP client's state
@@ -157,6 +157,37 @@ const (
StateDisabled
)
+func (s ServerState) MarshalText() ([]byte, error) {
+ switch s {
+ case StateStarting:
+ return []byte("starting"), nil
+ case StateReady:
+ return []byte("ready"), nil
+ case StateError:
+ return []byte("error"), nil
+ case StateDisabled:
+ return []byte("disabled"), nil
+ default:
+ return nil, fmt.Errorf("unknown server state: %d", s)
+ }
+}
+
+func (s *ServerState) UnmarshalText(data []byte) error {
+ switch strings.ToLower(string(data)) {
+ case "starting":
+ *s = StateStarting
+ case "ready":
+ *s = StateReady
+ case "error":
+ *s = StateError
+ case "disabled":
+ *s = StateDisabled
+ default:
+ return fmt.Errorf("unknown server state: %s", data)
+ }
+ return nil
+}
+
// GetServerState returns the current state of the LSP server
func (c *Client) GetServerState() ServerState {
if val := c.serverState.Load(); val != nil {
@@ -17,6 +17,15 @@ const (
Tool MessageRole = "tool"
)
+func (r MessageRole) MarshalText() ([]byte, error) {
+ return []byte(r), nil
+}
+
+func (r *MessageRole) UnmarshalText(data []byte) error {
+ *r = MessageRole(data)
+ return nil
+}
+
type FinishReason string
const (
@@ -31,6 +40,15 @@ const (
FinishReasonUnknown FinishReason = "unknown"
)
+func (fr FinishReason) MarshalText() ([]byte, error) {
+ return []byte(fr), nil
+}
+
+func (fr *FinishReason) UnmarshalText(data []byte) error {
+ *fr = FinishReason(data)
+ return nil
+}
+
type ContentPart interface {
isPart()
}
@@ -114,14 +132,14 @@ type Finish struct {
func (Finish) isPart() {}
type Message struct {
- ID string
- Role MessageRole
- SessionID string
- Parts []ContentPart
- Model string
- Provider string
- CreatedAt int64
- UpdatedAt int64
+ ID string `json:"id"`
+ Role MessageRole `json:"role"`
+ SessionID string `json:"session_id"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
}
func (m *Message) Content() TextContent {
@@ -13,10 +13,10 @@ import (
)
type CreateMessageParams struct {
- Role MessageRole
- Parts []ContentPart
- Model string
- Provider string
+ Role MessageRole `json:"role"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider,omitempty"`
}
type Service interface {
@@ -1,4 +1,14 @@
package proto
-// Args represents generic arguments that apply to all RPC routines.
-type Args struct{}
+// Instance represents a running app.App instance with its associated resources
+// and state.
+type Instance struct {
+ ID string `json:"id"`
+ Path string `json:"path"`
+ YOLO bool `json:"yolo"`
+}
+
+// Error represents an error response.
+type Error struct {
+ Message string `json:"message"`
+}
@@ -18,11 +18,20 @@ type (
// Event represents an event in the lifecycle of a resource
Event[T any] struct {
- Type EventType
- Payload T
+ Type EventType `json:"type"`
+ Payload T `json:"payload"`
}
Publisher[T any] interface {
Publish(EventType, T)
}
)
+
+func (t EventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+func (t *EventType) UnmarshalText(data []byte) error {
+ *t = EventType(data)
+ return nil
+}
@@ -1,42 +0,0 @@
-package server
-
-import (
- "log/slog"
- "net/rpc"
-
- msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
-)
-
-// ServerCodec is a wrapper around msgpackrpc.ServerCodec that adds logging
-// functionality.
-type ServerCodec struct {
- *msgpackrpc.MsgpackCodec
- logger *slog.Logger
-}
-
-var _ rpc.ServerCodec = (*ServerCodec)(nil)
-
-// ReadRequestHeader reads the request header and logs it.
-func (c *ServerCodec) ReadRequestHeader(r *rpc.Request) error {
- err := c.MsgpackCodec.ReadRequestHeader(r)
- if c.logger != nil {
- c.logger.Debug("rpc request",
- slog.String("service_method", r.ServiceMethod),
- slog.Int("seq", int(r.Seq)),
- )
- }
- return err
-}
-
-// WriteResponse writes the response and logs it.
-func (c *ServerCodec) WriteResponse(r *rpc.Response, body any) error {
- err := c.MsgpackCodec.WriteResponse(r, body)
- if c.logger != nil {
- c.logger.Debug("rpc response",
- slog.String("service_method", r.ServiceMethod),
- slog.String("error", r.Error),
- slog.Int("seq", int(r.Seq)),
- )
- }
- return err
-}
@@ -0,0 +1,51 @@
+package server
+
+import (
+ "log/slog"
+ "net/http"
+ "time"
+)
+
+func (s *Server) loggingHandler(next http.Handler) http.Handler {
+ return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ if s.logger == nil {
+ next.ServeHTTP(w, r)
+ return
+ }
+
+ start := time.Now()
+ lrw := &loggingResponseWriter{ResponseWriter: w, statusCode: http.StatusOK}
+ s.logger.Debug("HTTP request",
+ slog.String("method", r.Method),
+ slog.String("path", r.URL.Path),
+ slog.String("remote_addr", r.RemoteAddr),
+ slog.String("user_agent", r.UserAgent()),
+ )
+
+ next.ServeHTTP(lrw, r)
+ duration := time.Since(start)
+
+ s.logger.Debug("HTTP response",
+ slog.String("method", r.Method),
+ slog.String("path", r.URL.Path),
+ slog.Int("status", lrw.statusCode),
+ slog.Duration("duration", duration),
+ slog.String("remote_addr", r.RemoteAddr),
+ slog.String("user_agent", r.UserAgent()),
+ )
+ })
+}
+
+type loggingResponseWriter struct {
+ http.ResponseWriter
+ statusCode int
+}
+
+func (lrw *loggingResponseWriter) WriteHeader(code int) {
+ lrw.statusCode = code
+ lrw.ResponseWriter.WriteHeader(code)
+}
+
+func (lrw *loggingResponseWriter) Unwrap() http.ResponseWriter {
+ return lrw.ResponseWriter
+}
@@ -1,17 +1,187 @@
package server
import (
+ "crypto/sha256"
+ "encoding/hex"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+ "net/http"
+ "os"
+ "path/filepath"
+
+ "github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
+ "github.com/charmbracelet/crush/internal/db"
"github.com/charmbracelet/crush/internal/proto"
)
-// ServerProto defines the RPC methods exposed by the Crush server.
-type ServerProto struct {
- *Server
+func (s *Server) handleGetConfig(w http.ResponseWriter, r *http.Request) {
+ jsonEncode(w, s.cfg)
+}
+
+func (s *Server) handleGetInstances(w http.ResponseWriter, r *http.Request) {
+ instances := []proto.Instance{}
+ for _, ins := range s.instances.Seq2() {
+ instances = append(instances, proto.Instance{
+ ID: ins.ID(),
+ Path: ins.Path(),
+ YOLO: ins.cfg.Permissions != nil && ins.cfg.Permissions.SkipRequests,
+ })
+ }
+ jsonEncode(w, instances)
+}
+
+func (s *Server) handleGetInstanceEvents(w http.ResponseWriter, r *http.Request) {
+ flusher := http.NewResponseController(w)
+ id := r.PathValue("id")
+ ins, ok := s.instances.Get(id)
+ if !ok {
+ s.logError(r, "instance not found", "id", id)
+ jsonError(w, http.StatusNotFound, "instance not found")
+ return
+ }
+
+ w.Header().Set("Content-Type", "text/event-stream")
+ w.Header().Set("Cache-Control", "no-cache")
+ w.Header().Set("Connection", "keep-alive")
+
+ for {
+ select {
+ case <-r.Context().Done():
+ return
+ case ev := <-ins.App.Events():
+ data, err := json.Marshal(ev)
+ if err != nil {
+ s.logError(r, "failed to marshal event", "error", err)
+ continue
+ }
+
+ fmt.Fprintf(w, "data: %s\n\n", data)
+ flusher.Flush()
+ }
+ }
+}
+
+func (s *Server) handleDeleteInstances(w http.ResponseWriter, r *http.Request) {
+ var ids []string
+ id := r.URL.Query().Get("id")
+ if id != "" {
+ ids = append(ids, id)
+ }
+
+ // Get IDs from body
+ var args []proto.Instance
+ if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
+ s.logError(r, "failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+ ids = append(ids, func() []string {
+ out := make([]string, len(args))
+ for i, arg := range args {
+ out[i] = arg.ID
+ }
+ return out
+ }()...)
+
+ for _, id := range ids {
+ s.instances.Del(id)
+ }
+}
+
+func (s *Server) handlePostInstances(w http.ResponseWriter, r *http.Request) {
+ var args proto.Instance
+ if err := json.NewDecoder(r.Body).Decode(&args); err != nil {
+ s.logError(r, "failed to decode request", "error", err)
+ jsonError(w, http.StatusBadRequest, "failed to decode request")
+ return
+ }
+
+ ctx := r.Context()
+ hasher := sha256.New()
+ hasher.Write([]byte(filepath.Clean(args.Path)))
+ id := hex.EncodeToString(hasher.Sum(nil))
+ if existing, ok := s.instances.Get(id); ok {
+ jsonEncode(w, proto.Instance{
+ ID: existing.ID(),
+ Path: existing.Path(),
+ YOLO: existing.cfg.Permissions != nil && existing.cfg.Permissions.SkipRequests,
+ })
+ return
+ }
+
+ cfg, err := config.Init(args.Path, s.cfg.Options.DataDirectory, s.cfg.Options.Debug)
+ if err != nil {
+ s.logError(r, "failed to initialize config", "error", err)
+ jsonError(w, http.StatusBadRequest, fmt.Sprintf("failed to initialize config: %v", err))
+ return
+ }
+
+ if cfg.Permissions == nil {
+ cfg.Permissions = &config.Permissions{}
+ }
+ cfg.Permissions.SkipRequests = args.YOLO
+
+ if err := createDotCrushDir(cfg.Options.DataDirectory); err != nil {
+ s.logError(r, "failed to create data directory", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to create data directory")
+ return
+ }
+
+ // Connect to DB; this will also run migrations.
+ conn, err := db.Connect(ctx, cfg.Options.DataDirectory)
+ if err != nil {
+ s.logError(r, "failed to connect to database", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to connect to database")
+ return
+ }
+
+ appInstance, err := app.New(ctx, conn, cfg)
+ if err != nil {
+ slog.Error("failed to create app instance", "error", err)
+ jsonError(w, http.StatusInternalServerError, "failed to create app instance")
+ return
+ }
+
+ ins := &Instance{
+ App: appInstance,
+ State: InstanceStateCreated,
+ id: id,
+ path: args.Path,
+ cfg: cfg,
+ }
+
+ s.instances.Set(id, ins)
+ jsonEncode(w, proto.Instance{
+ ID: id,
+ Path: args.Path,
+ YOLO: cfg.Permissions.SkipRequests,
+ })
}
-// GetConfig is an RPC routine that returns the server's configuration.
-func (s *ServerProto) GetConfig(args *proto.Args, reply *config.Config) error {
- *reply = *s.cfg
+func createDotCrushDir(dir string) error {
+ if err := os.MkdirAll(dir, 0o700); err != nil {
+ return fmt.Errorf("failed to create data directory: %q %w", dir, err)
+ }
+
+ gitIgnorePath := filepath.Join(dir, ".gitignore")
+ if _, err := os.Stat(gitIgnorePath); os.IsNotExist(err) {
+ if err := os.WriteFile(gitIgnorePath, []byte("*\n"), 0o644); err != nil {
+ return fmt.Errorf("failed to create .gitignore file: %q %w", gitIgnorePath, err)
+ }
+ }
+
return nil
}
+
+func jsonEncode(w http.ResponseWriter, v any) {
+ w.Header().Set("Content-Type", "application/json")
+ _ = json.NewEncoder(w).Encode(v)
+}
+
+func jsonError(w http.ResponseWriter, status int, message string) {
+ w.Header().Set("Content-Type", "application/json")
+ w.WriteHeader(status)
+ _ = json.NewEncoder(w).Encode(proto.Error{Message: message})
+}
@@ -5,20 +5,16 @@ import (
"fmt"
"log/slog"
"net"
- "net/rpc"
+ "net/http"
"os"
"os/user"
"path/filepath"
"runtime"
"strings"
- "sync/atomic"
- "time"
"github.com/charmbracelet/crush/internal/app"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/csync"
-
- msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc/v2"
)
// ErrServerClosed is returned when the server is closed.
@@ -41,6 +37,8 @@ const (
type Instance struct {
*app.App
State InstanceState
+ ln net.Listener
+ cfg *config.Config
id string
path string
}
@@ -58,15 +56,15 @@ func (i *Instance) Path() string {
// DefaultAddr returns the default address path for the Crush server based on
// the operating system.
func DefaultAddr() string {
- sock := "crush.sock"
+ sockPath := "crush.sock"
user, err := user.Current()
if err == nil && user.Uid != "" {
- sock = fmt.Sprintf("crush-%s.sock", user.Uid)
+ sockPath = fmt.Sprintf("crush-%s.sock", user.Uid)
}
if runtime.GOOS == "windows" {
- return fmt.Sprintf(`\\.\pipe\%s`, sock)
+ return fmt.Sprintf(`\\.\pipe\%s`, sockPath)
}
- return filepath.Join(os.TempDir(), sock)
+ return filepath.Join(os.TempDir(), sockPath)
}
// Server represents a Crush server instance bound to a specific address.
@@ -74,14 +72,13 @@ type Server struct {
// Addr can be a TCP address, a Unix socket path, or a Windows named pipe.
Addr string
+ h *http.Server
+ ln net.Listener
+
// instances is a map of running applications managed by the server.
instances *csync.Map[string, *Instance]
- // listeners is the network listener for the server.
- listeners *csync.Map[*net.Listener, struct{}]
cfg *config.Config
logger *slog.Logger
-
- shutdown atomic.Bool
}
// SetLogger sets the logger for the server.
@@ -109,44 +106,33 @@ func NewServer(cfg *config.Config, network, address string) *Server {
s.Addr = address
s.cfg = cfg
s.instances = csync.NewMap[string, *Instance]()
- rpc.Register(&ServerProto{s})
+
+ var p http.Protocols
+ p.SetHTTP1(true)
+ p.SetUnencryptedHTTP2(true)
+ mux := http.NewServeMux()
+ mux.HandleFunc("GET /v1/config", s.handleGetConfig)
+ mux.HandleFunc("GET /v1/instances", s.handleGetInstances)
+ mux.HandleFunc("POST /v1/instances", s.handlePostInstances)
+ mux.HandleFunc("DELETE /v1/instances", s.handleDeleteInstances)
+ mux.HandleFunc("GET /v1/instances/{id}/events", s.handleGetInstanceEvents)
+ s.h = &http.Server{
+ Protocols: &p,
+ Handler: s.loggingHandler(mux),
+ }
return s
}
// Serve accepts incoming connections on the listener.
func (s *Server) Serve(ln net.Listener) error {
- if s.listeners == nil {
- s.listeners = csync.NewMap[*net.Listener, struct{}]()
- }
- s.listeners.Set(&ln, struct{}{})
-
- var tempDelay time.Duration // how long to sleep on accept failure
- for {
- conn, err := ln.Accept()
- if err != nil {
- if s.shuttingDown() {
- return ErrServerClosed
- }
- if ne, ok := err.(net.Error); ok && ne.Temporary() {
- if tempDelay == 0 {
- tempDelay = 5 * time.Millisecond
- } else {
- tempDelay *= 2
- }
- if max := 1 * time.Second; tempDelay > max {
- tempDelay = max
- }
- time.Sleep(tempDelay)
- continue
- }
- return fmt.Errorf("failed to accept connection: %w", err)
- }
- go s.handleConn(conn)
- }
+ return s.h.Serve(ln)
}
// ListenAndServe starts the server and begins accepting connections.
func (s *Server) ListenAndServe() error {
+ if s.ln != nil {
+ return fmt.Errorf("server already started")
+ }
ln, err := listen("unix", s.Addr)
if err != nil {
return fmt.Errorf("failed to listen on %s: %w", s.Addr, err)
@@ -154,63 +140,33 @@ func (s *Server) ListenAndServe() error {
return s.Serve(ln)
}
+func (s *Server) closeListener() {
+ if s.ln != nil {
+ s.ln.Close()
+ s.ln = nil
+ }
+}
+
// Close force close all listeners and connections.
func (s *Server) Close() error {
- s.shutdown.Store(true)
- var firstErr error
- for k := range s.listeners.Seq2() {
- if err := (*k).Close(); err != nil && firstErr == nil {
- firstErr = err
- }
- s.listeners.Del(k)
- }
- return firstErr
+ defer func() { s.closeListener() }()
+ return s.h.Close()
}
// Shutdown gracefully shuts down the server without interrupting active
// connections. It stops accepting new connections and waits for existing
// connections to finish.
func (s *Server) Shutdown(ctx context.Context) error {
- // TODO: implement graceful shutdown
- return s.Close()
-}
-
-func (s *Server) handleConn(conn net.Conn) {
- s.info("accepted connection", "remote_addr", conn.LocalAddr())
- codec := &ServerCodec{
- MsgpackCodec: msgpackrpc.NewCodec(true, true, conn),
- logger: s.logger.With(
- slog.String("remote_addr", conn.RemoteAddr().String()),
- slog.String("local_addr", conn.LocalAddr().String()),
- ),
- }
- rpc.ServeCodec(codec)
-}
-
-func (s *Server) shuttingDown() bool {
- return s.shutdown.Load()
-}
-
-func (s *Server) info(msg string, args ...any) {
- if s.logger != nil {
- s.logger.Info(msg, args...)
- }
-}
-
-func (s *Server) debug(msg string, args ...any) {
- if s.logger != nil {
- s.logger.Debug(msg, args...)
- }
-}
-
-func (s *Server) error(msg string, args ...any) {
- if s.logger != nil {
- s.logger.Error(msg, args...)
- }
+ defer func() { s.closeListener() }()
+ return s.h.Shutdown(ctx)
}
-func (s *Server) warn(msg string, args ...any) {
+func (s *Server) logError(r *http.Request, msg string, args ...any) {
if s.logger != nil {
- s.logger.Warn(msg, args...)
+ s.logger.With(
+ slog.String("method", r.Method),
+ slog.String("url", r.URL.String()),
+ slog.String("remote_addr", r.RemoteAddr),
+ ).Error(msg, args...)
}
}
@@ -10,16 +10,16 @@ import (
)
type Session struct {
- ID string
- ParentSessionID string
- Title string
- MessageCount int64
- PromptTokens int64
- CompletionTokens int64
- SummaryMessageID string
- Cost float64
- CreatedAt int64
- UpdatedAt int64
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ SummaryMessageID string `json:"summary_message_id"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
}
type Service interface {