From d3fe51ba3aa74ff0a19535121c4a449f45870922 Mon Sep 17 00:00:00 2001 From: Ayman Bagabas Date: Tue, 23 Sep 2025 16:27:41 -0400 Subject: [PATCH] refactor: migrate to proto package and remove global config access --- internal/app/app.go | 4 + internal/app/lsp.go | 10 +- internal/app/lsp_events.go | 38 +--- internal/history/file.go | 11 +- internal/llm/agent/agent.go | 92 ++++----- internal/llm/agent/mcp-tools.go | 96 ++------- internal/llm/prompt/coder.go | 15 +- internal/llm/prompt/prompt.go | 6 +- internal/llm/prompt/task.go | 6 +- internal/llm/provider/anthropic.go | 10 +- internal/llm/provider/azure.go | 3 +- internal/llm/provider/bedrock.go | 2 +- internal/llm/provider/gemini.go | 8 +- internal/llm/provider/openai.go | 8 +- internal/llm/provider/provider.go | 31 +-- internal/llm/provider/vertexai.go | 3 +- internal/lsp/client.go | 34 ++- internal/lsp/client_test.go | 5 +- internal/lsp/handlers.go | 6 +- internal/message/attachment.go | 47 ----- internal/message/message.go | 168 ++++----------- internal/permission/permission.go | 32 +-- internal/proto/agent.go | 70 +++++++ internal/proto/history.go | 11 + internal/proto/mcp.go | 73 +++++++ .../{message/content.go => proto/message.go} | 195 ++++++++++++++++-- internal/proto/permission.go | 28 +++ internal/proto/proto.go | 87 +++++++- internal/proto/session.go | 14 ++ internal/pubsub/broker.go | 2 + internal/pubsub/events.go | 172 ++++++++++++++- internal/session/session.go | 14 +- 32 files changed, 821 insertions(+), 480 deletions(-) delete mode 100644 internal/message/attachment.go create mode 100644 internal/proto/agent.go create mode 100644 internal/proto/history.go create mode 100644 internal/proto/mcp.go rename internal/{message/content.go => proto/message.go} (71%) create mode 100644 internal/proto/permission.go create mode 100644 internal/proto/session.go diff --git a/internal/app/app.go b/internal/app/app.go index dc75f19f30ee0a4b8dee37f5d58102fc5e1a7c11..60f7c67bfe1ae23b32dfcb8a200120f870f76368 100644 --- a/internal/app/app.go +++ b/internal/app/app.go @@ -52,6 +52,9 @@ type App struct { // New initializes a new applcation instance. func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) { + // Attach config to context for use in services. + ctx = config.WithContext(ctx, cfg) + q := db.New(conn) sessions := session.NewService(q) messages := message.NewService(q) @@ -275,6 +278,7 @@ func (app *App) InitCoderAgent() error { var err error app.CoderAgent, err = agent.NewAgent( app.globalCtx, + app.config, coderAgentCfg, app.Permissions, app.Sessions, diff --git a/internal/app/lsp.go b/internal/app/lsp.go index 07db7f3420a5e30172da0698d95cabc307998ee2..4c4885cd597a3b39cac9326765758028dd7ffb3a 100644 --- a/internal/app/lsp.go +++ b/internal/app/lsp.go @@ -22,12 +22,12 @@ func (app *App) initLSPClients(ctx context.Context) { } // createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher -func (app *App) createAndStartLSPClient(ctx context.Context, name string, config config.LSPConfig) { - slog.Info("Creating LSP client", "name", name, "command", config.Command, "fileTypes", config.FileTypes, "args", config.Args) +func (app *App) createAndStartLSPClient(ctx context.Context, name string, lspCfg config.LSPConfig) { + slog.Info("Creating LSP client", "name", name, "command", lspCfg.Command, "fileTypes", lspCfg.FileTypes, "args", lspCfg.Args) // Check if any root markers exist in the working directory (config now has defaults) - if !lsp.HasRootMarkers(app.config.WorkingDir(), config.RootMarkers) { - slog.Info("Skipping LSP client - no root markers found", "name", name, "rootMarkers", config.RootMarkers) + if !lsp.HasRootMarkers(app.config.WorkingDir(), lspCfg.RootMarkers) { + slog.Info("Skipping LSP client - no root markers found", "name", name, "rootMarkers", lspCfg.RootMarkers) app.updateLSPState(name, lsp.StateDisabled, nil, 0) return } @@ -36,7 +36,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config app.updateLSPState(name, lsp.StateStarting, nil, 0) // Create LSP client. - lspClient, err := lsp.New(ctx, name, config) + lspClient, err := lsp.New(ctx, app.config, name, lspCfg) if err != nil { slog.Error("Failed to create LSP client for", name, err) app.updateLSPState(name, lsp.StateError, err, 0) diff --git a/internal/app/lsp_events.go b/internal/app/lsp_events.go index 9338357e8facd1ff14fdedaf16315aa7e99dd82a..9270ddc635c7d9419c940fda035b8e30f7628ac9 100644 --- a/internal/app/lsp_events.go +++ b/internal/app/lsp_events.go @@ -6,44 +6,20 @@ import ( "time" "github.com/charmbracelet/crush/internal/lsp" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" ) -// LSPEventType represents the type of LSP event -type LSPEventType string +type ( + LSPClientInfo = proto.LSPClientInfo + LSPEvent = proto.LSPEvent +) const ( - LSPEventStateChanged LSPEventType = "state_changed" - LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed" + LSPEventStateChanged = proto.LSPEventStateChanged + LSPEventDiagnosticsChanged = proto.LSPEventDiagnosticsChanged ) -func (e LSPEventType) MarshalText() ([]byte, error) { - return []byte(e), nil -} - -func (e *LSPEventType) UnmarshalText(data []byte) error { - *e = LSPEventType(data) - return nil -} - -// LSPEvent represents an event in the LSP system -type LSPEvent struct { - Type LSPEventType `json:"type"` - Name string `json:"name"` - State lsp.ServerState `json:"state"` - Error error `json:"error,omitempty"` - DiagnosticCount int `json:"diagnostic_count,omitempty"` -} - -// LSPClientInfo holds information about an LSP client's state -type LSPClientInfo struct { - Name string `json:"name"` - State lsp.ServerState `json:"state"` - Error error `json:"error,omitempty"` - DiagnosticCount int `json:"diagnostic_count,omitempty"` - ConnectedAt time.Time `json:"connected_at"` -} - // SubscribeLSPEvents returns a channel for LSP events func (a *App) SubscribeLSPEvents(ctx context.Context) <-chan pubsub.Event[LSPEvent] { return a.lspBroker.Subscribe(ctx) diff --git a/internal/history/file.go b/internal/history/file.go index f7c5a04b715785ac992dd0283d230daf3e5114cc..ff96b34a03e99572f211cad898e51af9625dd113 100644 --- a/internal/history/file.go +++ b/internal/history/file.go @@ -7,6 +7,7 @@ import ( "strings" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) @@ -15,15 +16,7 @@ const ( InitialVersion = 0 ) -type File struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Path string `json:"path"` - Content string `json:"content"` - Version int64 `json:"version"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` -} +type File = proto.File type Service interface { pubsub.Suscriber[File] diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index e9419d99f63c5ba7b22096ea2a5e4992dd3ad5d1..3e9d97980ef5eee5ab66bad1e8f5dbcd3be5caea 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -20,6 +20,7 @@ import ( "github.com/charmbracelet/crush/internal/lsp" "github.com/charmbracelet/crush/internal/message" "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/session" "github.com/charmbracelet/crush/internal/shell" @@ -31,34 +32,17 @@ var ( ErrSessionBusy = errors.New("session is currently processing another request") ) -type AgentEventType string +type ( + AgentEventType = proto.AgentEventType + AgentEvent = proto.AgentEvent +) const ( - AgentEventTypeError AgentEventType = "error" - AgentEventTypeResponse AgentEventType = "response" - AgentEventTypeSummarize AgentEventType = "summarize" + AgentEventTypeError = proto.AgentEventTypeError + AgentEventTypeResponse = proto.AgentEventTypeResponse + AgentEventTypeSummarize = proto.AgentEventTypeSummarize ) -func (t AgentEventType) MarshalText() ([]byte, error) { - return []byte(t), nil -} - -func (t *AgentEventType) UnmarshalText(text []byte) error { - *t = AgentEventType(text) - return nil -} - -type AgentEvent struct { - Type AgentEventType `json:"type"` - Message message.Message `json:"message"` - Error error `json:"error,omitempty"` - - // When summarizing - SessionID string `json:"session_id,omitempty"` - Progress string `json:"progress,omitempty"` - Done bool `json:"done,omitempty"` -} - type Service interface { pubsub.Suscriber[AgentEvent] Model() catwalk.Model @@ -79,6 +63,7 @@ type agent struct { sessions session.Service messages message.Service mcpTools []McpTool + cfg *config.Config tools *csync.LazySlice[tools.BaseTool] // We need this to be able to update it when model changes @@ -102,6 +87,7 @@ var agentPromptMap = map[string]prompt.PromptID{ func NewAgent( ctx context.Context, + cfg *config.Config, agentCfg config.Agent, // These services are needed in the tools permissions permission.Service, @@ -110,16 +96,14 @@ func NewAgent( history history.Service, lspClients *csync.Map[string, *lsp.Client], ) (Service, error) { - cfg := config.Get() - var agentToolFn func() (tools.BaseTool, error) if agentCfg.ID == "coder" { agentToolFn = func() (tools.BaseTool, error) { - taskAgentCfg := config.Get().Agents["task"] + taskAgentCfg := cfg.Agents["task"] if taskAgentCfg.ID == "" { return nil, fmt.Errorf("task agent not found in config") } - taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients) + taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients) if err != nil { return nil, fmt.Errorf("failed to create task agent: %w", err) } @@ -127,11 +111,11 @@ func NewAgent( } } - providerCfg := config.Get().GetProviderForModel(agentCfg.Model) + providerCfg := cfg.GetProviderForModel(agentCfg.Model) if providerCfg == nil { return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name) } - model := config.Get().GetModelByType(agentCfg.Model) + model := cfg.GetModelByType(agentCfg.Model) if model == nil { return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name) @@ -143,9 +127,9 @@ func NewAgent( } opts := []provider.ProviderClientOption{ provider.WithModel(agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)), + provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)), } - agentProvider, err := provider.NewProvider(*providerCfg, opts...) + agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...) if err != nil { return nil, err } @@ -168,18 +152,18 @@ func NewAgent( titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)), } - titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...) + titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...) if err != nil { return nil, err } summarizeOpts := []provider.ProviderClientOption{ provider.WithModel(config.SelectedModelTypeLarge), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)), } - summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...) + summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...) if err != nil { return nil, err } @@ -246,11 +230,12 @@ func NewAgent( activeRequests: csync.NewMap[string, context.CancelFunc](), tools: csync.NewLazySlice(toolFn), promptQueue: csync.NewMap[string, []string](), + cfg: cfg, }, nil } func (a *agent) Model() catwalk.Model { - return *config.Get().GetModelByType(a.agentCfg.Model) + return *a.cfg.GetModelByType(a.agentCfg.Model) } func (a *agent) Cancel(sessionID string) { @@ -400,7 +385,6 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac } func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent { - cfg := config.Get() // List existing messages; if none, start title generation asynchronously. msgs, err := a.messages.List(ctx, sessionID) if err != nil { @@ -459,7 +443,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } return a.err(fmt.Errorf("failed to process events: %w", err)) } - if cfg.Options.Debug { + if a.cfg.Options.Debug { slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults) } if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil { @@ -866,7 +850,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) return } - shell := shell.GetPersistentShell(config.Get().WorkingDir()) + shell := shell.GetPersistentShell(a.cfg.WorkingDir()) summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir() event = AgentEvent{ Type: AgentEventTypeSummarize, @@ -968,10 +952,8 @@ func (a *agent) CancelAll() { } func (a *agent) UpdateModel() error { - cfg := config.Get() - // Get current provider configuration - currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model) + currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model) if currentProviderCfg == nil || currentProviderCfg.ID == "" { return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name) } @@ -979,7 +961,7 @@ func (a *agent) UpdateModel() error { // Check if provider has changed if string(currentProviderCfg.ID) != a.providerID { // Provider changed, need to recreate the main provider - model := cfg.GetModelByType(a.agentCfg.Model) + model := a.cfg.GetModelByType(a.agentCfg.Model) if model.ID == "" { return fmt.Errorf("model not found for agent %s", a.agentCfg.Name) } @@ -991,10 +973,10 @@ func (a *agent) UpdateModel() error { opts := []provider.ProviderClientOption{ provider.WithModel(a.agentCfg.Model), - provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)), + provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)), } - newProvider, err := provider.NewProvider(*currentProviderCfg, opts...) + newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...) if err != nil { return fmt.Errorf("failed to create new provider: %w", err) } @@ -1005,9 +987,9 @@ func (a *agent) UpdateModel() error { } // Check if providers have changed for title (small) and summarize (large) - smallModelCfg := cfg.Models[config.SelectedModelTypeSmall] + smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall] var smallModelProviderCfg config.ProviderConfig - for p := range cfg.Providers.Seq() { + for p := range a.cfg.Providers.Seq() { if p.ID == smallModelCfg.Provider { smallModelProviderCfg = p break @@ -1017,9 +999,9 @@ func (a *agent) UpdateModel() error { return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider) } - largeModelCfg := cfg.Models[config.SelectedModelTypeLarge] + largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge] var largeModelProviderCfg config.ProviderConfig - for p := range cfg.Providers.Seq() { + for p := range a.cfg.Providers.Seq() { if p.ID == largeModelCfg.Provider { largeModelProviderCfg = p break @@ -1038,10 +1020,10 @@ func (a *agent) UpdateModel() error { // Recreate title provider titleOpts := []provider.ProviderClientOption{ provider.WithModel(config.SelectedModelTypeSmall), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)), provider.WithMaxTokens(maxTitleTokens), } - newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...) + newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...) if err != nil { return fmt.Errorf("failed to create new title provider: %w", err) } @@ -1049,15 +1031,15 @@ func (a *agent) UpdateModel() error { // Recreate summarize provider if provider changed (now large model) if string(largeModelProviderCfg.ID) != a.summarizeProviderID { - largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge) + largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge) if largeModel == nil { return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID) } summarizeOpts := []provider.ProviderClientOption{ provider.WithModel(config.SelectedModelTypeLarge), - provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)), + provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)), } - newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...) + newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...) if err != nil { return fmt.Errorf("failed to create new summarize provider: %w", err) } diff --git a/internal/llm/agent/mcp-tools.go b/internal/llm/agent/mcp-tools.go index faf67591c692f22540ee368ed75e7a4f8c56d00d..def891ce75a047a8f4b75e02a7fed0829165af8d 100644 --- a/internal/llm/agent/mcp-tools.go +++ b/internal/llm/agent/mcp-tools.go @@ -18,6 +18,7 @@ import ( "github.com/charmbracelet/crush/internal/home" "github.com/charmbracelet/crush/internal/llm/tools" "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/charmbracelet/crush/internal/version" "github.com/mark3labs/mcp-go/client" @@ -25,75 +26,20 @@ import ( "github.com/mark3labs/mcp-go/mcp" ) -// MCPState represents the current state of an MCP client -type MCPState int - -const ( - MCPStateDisabled MCPState = iota - MCPStateStarting - MCPStateConnected - MCPStateError +type ( + MCPState = proto.MCPState + MCPEventType = proto.MCPEventType + MCPEvent = proto.MCPEvent ) -func (s MCPState) MarshalText() ([]byte, error) { - return []byte(s.String()), nil -} - -func (s *MCPState) UnmarshalText(data []byte) error { - switch string(data) { - case "disabled": - *s = MCPStateDisabled - case "starting": - *s = MCPStateStarting - case "connected": - *s = MCPStateConnected - case "error": - *s = MCPStateError - default: - return fmt.Errorf("unknown mcp state: %s", data) - } - return nil -} - -func (s MCPState) String() string { - switch s { - case MCPStateDisabled: - return "disabled" - case MCPStateStarting: - return "starting" - case MCPStateConnected: - return "connected" - case MCPStateError: - return "error" - default: - return "unknown" - } -} - -// MCPEventType represents the type of MCP event -type MCPEventType string - const ( - MCPEventStateChanged MCPEventType = "state_changed" -) - -func (t MCPEventType) MarshalText() ([]byte, error) { - return []byte(t), nil -} - -func (t *MCPEventType) UnmarshalText(data []byte) error { - *t = MCPEventType(data) - return nil -} + MCPStateDisabled = proto.MCPStateDisabled + MCPStateStarting = proto.MCPStateStarting + MCPStateConnected = proto.MCPStateConnected + MCPStateError = proto.MCPStateError -// MCPEvent represents an event in the MCP system -type MCPEvent struct { - Type MCPEventType `json:"type"` - Name string `json:"name"` - State MCPState `json:"state"` - Error error `json:"error,omitempty"` - ToolCount int `json:"tool_count,omitempty"` -} + MCPEventStateChanged = proto.MCPEventStateChanged +) // MCPClientInfo holds information about an MCP client's state type MCPClientInfo struct { @@ -117,7 +63,7 @@ type McpTool struct { mcpName string tool mcp.Tool permissions permission.Service - workingDir string + cfg *config.Config } func (b *McpTool) Name() string { @@ -141,13 +87,13 @@ func (b *McpTool) Info() tools.ToolInfo { } } -func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) { +func runTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (tools.ToolResponse, error) { var args map[string]any if err := json.Unmarshal([]byte(input), &args); err != nil { return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil } - c, err := getOrRenewClient(ctx, name) + c, err := getOrRenewClient(ctx, cfg, name) if err != nil { return tools.NewTextErrorResponse(err.Error()), nil } @@ -172,13 +118,13 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To return tools.NewTextResponse(strings.Join(output, "\n")), nil } -func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) { +func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*client.Client, error) { c, ok := mcpClients.Get(name) if !ok { return nil, fmt.Errorf("mcp '%s' not available", name) } - m := config.Get().MCP[name] + m := cfg.MCP[name] state, _ := mcpStates.Get(name) timeout := mcpTimeout(m) @@ -210,7 +156,7 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes permission.CreatePermissionRequest{ SessionID: sessionID, ToolCallID: params.ID, - Path: b.workingDir, + Path: b.cfg.WorkingDir(), ToolName: b.Info().Name, Action: "execute", Description: permissionDescription, @@ -221,10 +167,10 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes return tools.ToolResponse{}, permission.ErrorPermissionDenied } - return runTool(ctx, b.mcpName, b.tool.Name, params.Input) + return runTool(ctx, b.cfg, b.mcpName, b.tool.Name, params.Input) } -func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool { +func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, cfg *config.Config) []tools.BaseTool { result, err := c.ListTools(ctx, mcp.ListToolsRequest{}) if err != nil { slog.Error("error listing tools", "error", err) @@ -239,7 +185,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service, mcpName: name, tool: tool, permissions: permissions, - workingDir: workingDir, + cfg: cfg, }) } return mcpTools @@ -348,7 +294,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con } mcpClients.Set(name, c) - tools := getTools(ctx, name, permissions, c, cfg.WorkingDir()) + tools := getTools(ctx, name, permissions, c, cfg) updateMCPState(name, MCPStateConnected, nil, c, len(tools)) result.Append(tools...) }(name, m) diff --git a/internal/llm/prompt/coder.go b/internal/llm/prompt/coder.go index 90e5a17191f346a5df53622e1826bc04214ddbfc..2f6b4064176ba1c03753703f492cce08a0c6ab31 100644 --- a/internal/llm/prompt/coder.go +++ b/internal/llm/prompt/coder.go @@ -14,7 +14,7 @@ import ( "github.com/charmbracelet/crush/internal/llm/tools" ) -func CoderPrompt(p string, contextFiles ...string) string { +func CoderPrompt(cfg *config.Config, p string, contextFiles ...string) string { var basePrompt string basePrompt = string(anthropicCoderPrompt) @@ -28,11 +28,11 @@ func CoderPrompt(p string, contextFiles ...string) string { if ok, _ := strconv.ParseBool(os.Getenv("CRUSH_CODER_V2")); ok { basePrompt = string(coderV2Prompt) } - envInfo := getEnvironmentInfo() + envInfo := getEnvironmentInfo(cfg) - basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation()) + basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation(cfg)) - contextContent := getContextFromPaths(config.Get().WorkingDir(), contextFiles) + contextContent := getContextFromPaths(cfg.WorkingDir(), contextFiles) if contextContent != "" { return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent) } @@ -48,8 +48,8 @@ var geminiCoderPrompt []byte //go:embed v2.md var coderV2Prompt []byte -func getEnvironmentInfo() string { - cwd := config.Get().WorkingDir() +func getEnvironmentInfo(cfg *config.Config) string { + cwd := cfg.WorkingDir() isGit := isGitRepo(cwd) platform := runtime.GOOS date := time.Now().Format("1/2/2006") @@ -72,8 +72,7 @@ func isGitRepo(dir string) bool { return err == nil } -func lspInformation() string { - cfg := config.Get() +func lspInformation(cfg *config.Config) string { hasLSP := false for _, v := range cfg.LSP { if !v.Disabled { diff --git a/internal/llm/prompt/prompt.go b/internal/llm/prompt/prompt.go index 919686a7d248d6ac2f02ae21ff4a323b26fc536f..54a8b446fb0ff8411228722bc55c4bbb627723b8 100644 --- a/internal/llm/prompt/prompt.go +++ b/internal/llm/prompt/prompt.go @@ -22,15 +22,15 @@ const ( PromptDefault PromptID = "default" ) -func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string { +func GetPrompt(cfg *config.Config, promptID PromptID, provider string, contextPaths ...string) string { basePrompt := "" switch promptID { case PromptCoder: - basePrompt = CoderPrompt(provider, contextPaths...) + basePrompt = CoderPrompt(cfg, provider, contextPaths...) case PromptTitle: basePrompt = TitlePrompt() case PromptTask: - basePrompt = TaskPrompt() + basePrompt = TaskPrompt(cfg) case PromptSummarizer: basePrompt = SummarizerPrompt() default: diff --git a/internal/llm/prompt/task.go b/internal/llm/prompt/task.go index e4f021d4ab7ef9f49873bc6893a231d72f2f3994..e5c6472da9a28ed22d1837d5b8a11eec6dccc576 100644 --- a/internal/llm/prompt/task.go +++ b/internal/llm/prompt/task.go @@ -2,14 +2,16 @@ package prompt import ( "fmt" + + "github.com/charmbracelet/crush/internal/config" ) -func TaskPrompt() string { +func TaskPrompt(cfg *config.Config) string { agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question. Notes: 1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is .", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...". 2. When relevant, share file names and code snippets relevant to the query 3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.` - return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo()) + return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo(cfg)) } diff --git a/internal/llm/provider/anthropic.go b/internal/llm/provider/anthropic.go index a5355b09e235d791d178a445ba98095974acbef4..52cb173232477ae118321c0343eaefdf70bccd58 100644 --- a/internal/llm/provider/anthropic.go +++ b/internal/llm/provider/anthropic.go @@ -80,13 +80,13 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a } if opts.baseURL != "" { - resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) + resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL) if err == nil && resolvedBaseURL != "" { anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL)) } } - if config.Get().Options.Debug { + if opts.cfg.Options.Debug { httpClient := log.NewHTTPClient() anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(httpClient)) } @@ -223,7 +223,7 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason { } func (a *anthropicClient) isThinkingEnabled() bool { - cfg := config.Get() + cfg := a.providerOptions.cfg modelConfig := cfg.Models[config.SelectedModelTypeLarge] if a.providerOptions.modelType == config.SelectedModelTypeSmall { modelConfig = cfg.Models[config.SelectedModelTypeSmall] @@ -234,7 +234,7 @@ func (a *anthropicClient) isThinkingEnabled() bool { func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams { model := a.providerOptions.model(a.providerOptions.modelType) var thinkingParam anthropic.ThinkingConfigParamUnion - cfg := config.Get() + cfg := a.providerOptions.cfg modelConfig := cfg.Models[config.SelectedModelTypeLarge] if a.providerOptions.modelType == config.SelectedModelTypeSmall { modelConfig = cfg.Models[config.SelectedModelTypeSmall] @@ -493,7 +493,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err } if apiErr.StatusCode == 401 { - a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey) + a.providerOptions.apiKey, err = a.providerOptions.cfg.Resolve(a.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/azure.go b/internal/llm/provider/azure.go index 9042d66876c6f22bd9c06a5f52f6b4502e32c0f2..ed3c47e0eecd7e600db97e6997267bd058cf40df 100644 --- a/internal/llm/provider/azure.go +++ b/internal/llm/provider/azure.go @@ -1,7 +1,6 @@ package provider import ( - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/log" "github.com/openai/openai-go" "github.com/openai/openai-go/azure" @@ -24,7 +23,7 @@ func newAzureClient(opts providerClientOptions) AzureClient { azure.WithEndpoint(opts.baseURL, apiVersion), } - if config.Get().Options.Debug { + if opts.cfg.Options.Debug { httpClient := log.NewHTTPClient() reqOpts = append(reqOpts, option.WithHTTPClient(httpClient)) } diff --git a/internal/llm/provider/bedrock.go b/internal/llm/provider/bedrock.go index 526d11b5597859853be9314ed618748e3ae40f38..7dc0221bf47af907700f37f25db167269dd1a7fe 100644 --- a/internal/llm/provider/bedrock.go +++ b/internal/llm/provider/bedrock.go @@ -33,7 +33,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient { } opts.model = func(modelType config.SelectedModelType) catwalk.Model { - model := config.Get().GetModelByType(modelType) + model := opts.cfg.GetModelByType(modelType) // Prefix the model name with region regionPrefix := region[:2] diff --git a/internal/llm/provider/gemini.go b/internal/llm/provider/gemini.go index 256e21bf7d59216a41be4603c1475dc9e24bdeea..12dd9c7736d0db2eb57497561cb33144852124e3 100644 --- a/internal/llm/provider/gemini.go +++ b/internal/llm/provider/gemini.go @@ -44,7 +44,7 @@ func createGeminiClient(opts providerClientOptions) (*genai.Client, error) { APIKey: opts.apiKey, Backend: genai.BackendGeminiAPI, } - if config.Get().Options.Debug { + if opts.cfg.Options.Debug { cc.HTTPClient = log.NewHTTPClient() } client, err := genai.NewClient(context.Background(), cc) @@ -178,7 +178,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too // Convert messages geminiMessages := g.convertMessages(messages) model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() + cfg := g.providerOptions.cfg modelConfig := cfg.Models[config.SelectedModelTypeLarge] if g.providerOptions.modelType == config.SelectedModelTypeSmall { @@ -274,7 +274,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t geminiMessages := g.convertMessages(messages) model := g.providerOptions.model(g.providerOptions.modelType) - cfg := config.Get() + cfg := g.providerOptions.cfg modelConfig := cfg.Models[config.SelectedModelTypeLarge] if g.providerOptions.modelType == config.SelectedModelTypeSmall { @@ -433,7 +433,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error) // Check for token expiration (401 Unauthorized) if contains(errMsg, "unauthorized", "invalid api key", "api key expired") { - g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey) + g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/openai.go b/internal/llm/provider/openai.go index 8df3989abbacbb7e46c59a0c750df8a7879789c1..d190f1eb921174e196baf4b4da38944e753f185f 100644 --- a/internal/llm/provider/openai.go +++ b/internal/llm/provider/openai.go @@ -42,13 +42,13 @@ func createOpenAIClient(opts providerClientOptions) openai.Client { openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey)) } if opts.baseURL != "" { - resolvedBaseURL, err := config.Get().Resolve(opts.baseURL) + resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL) if err == nil && resolvedBaseURL != "" { openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL)) } } - if config.Get().Options.Debug { + if opts.cfg.Options.Debug { httpClient := log.NewHTTPClient() openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient)) } @@ -217,7 +217,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason { func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams { model := o.providerOptions.model(o.providerOptions.modelType) - cfg := config.Get() + cfg := o.providerOptions.cfg modelConfig := cfg.Models[config.SelectedModelTypeLarge] if o.providerOptions.modelType == config.SelectedModelTypeSmall { @@ -514,7 +514,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error) if errors.As(err, &apiErr) { // Check for token expiration (401 Unauthorized) if apiErr.StatusCode == 401 { - o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey) + o.providerOptions.apiKey, err = o.providerOptions.cfg.Resolve(o.providerOptions.config.APIKey) if err != nil { return false, 0, fmt.Errorf("failed to resolve API key: %w", err) } diff --git a/internal/llm/provider/provider.go b/internal/llm/provider/provider.go index 3705645517cd10803ede285f8d2935f43575b746..3c5dade89bf3d1046d22ff4fb932a9bdd4cc938d 100644 --- a/internal/llm/provider/provider.go +++ b/internal/llm/provider/provider.go @@ -62,6 +62,8 @@ type Provider interface { } type providerClientOptions struct { + cfg *config.Config + baseURL string config config.ProviderConfig apiKey string @@ -139,40 +141,41 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption { } } -func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { +func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) { restore := config.PushPopCrushEnv() defer restore() - resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey) + resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey) if err != nil { - return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err) + return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err) } // Resolve extra headers resolvedExtraHeaders := make(map[string]string) - for key, value := range cfg.ExtraHeaders { - resolvedValue, err := config.Get().Resolve(value) + for key, value := range pcfg.ExtraHeaders { + resolvedValue, err := cfg.Resolve(value) if err != nil { - return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err) + return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err) } resolvedExtraHeaders[key] = resolvedValue } clientOptions := providerClientOptions{ - baseURL: cfg.BaseURL, - config: cfg, + cfg: cfg, + baseURL: pcfg.BaseURL, + config: pcfg, apiKey: resolvedAPIKey, extraHeaders: resolvedExtraHeaders, - extraBody: cfg.ExtraBody, - extraParams: cfg.ExtraParams, - systemPromptPrefix: cfg.SystemPromptPrefix, + extraBody: pcfg.ExtraBody, + extraParams: pcfg.ExtraParams, + systemPromptPrefix: pcfg.SystemPromptPrefix, model: func(tp config.SelectedModelType) catwalk.Model { - return *config.Get().GetModelByType(tp) + return *cfg.GetModelByType(tp) }, } for _, o := range opts { o(&clientOptions) } - switch cfg.Type { + switch pcfg.Type { case catwalk.TypeAnthropic: return &baseProvider[AnthropicClient]{ options: clientOptions, @@ -204,5 +207,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi client: newVertexAIClient(clientOptions), }, nil } - return nil, fmt.Errorf("provider not supported: %s", cfg.Type) + return nil, fmt.Errorf("provider not supported: %s", pcfg.Type) } diff --git a/internal/llm/provider/vertexai.go b/internal/llm/provider/vertexai.go index 871ff092b058af70833ba615260efcdbc09f2514..759c0364dd81002008cdb556b916858757c9e69d 100644 --- a/internal/llm/provider/vertexai.go +++ b/internal/llm/provider/vertexai.go @@ -5,7 +5,6 @@ import ( "log/slog" "strings" - "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/log" "google.golang.org/genai" ) @@ -20,7 +19,7 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient { Location: location, Backend: genai.BackendVertexAI, } - if config.Get().Options.Debug { + if opts.cfg.Options.Debug { cc.HTTPClient = log.NewHTTPClient() } client, err := genai.NewClient(context.Background(), cc) diff --git a/internal/lsp/client.go b/internal/lsp/client.go index 16eeebb97989472b8743d1e777eb9cba89b04527..50eb11d1b9d9700c76218dc88672a3b8076df53f 100644 --- a/internal/lsp/client.go +++ b/internal/lsp/client.go @@ -23,6 +23,7 @@ import ( type Client struct { client *powernap.Client + cfg *config.Config name string // File types this LSP server handles (e.g., .go, .rs, .py) @@ -45,7 +46,7 @@ type Client struct { } // New creates a new LSP client using the powernap implementation. -func New(ctx context.Context, name string, config config.LSPConfig) (*Client, error) { +func New(ctx context.Context, cfg *config.Config, name string, lspCfg config.LSPConfig) (*Client, error) { // Convert working directory to file URI workDir, err := os.Getwd() if err != nil { @@ -56,16 +57,16 @@ func New(ctx context.Context, name string, config config.LSPConfig) (*Client, er // Create powernap client config clientConfig := powernap.ClientConfig{ - Command: home.Long(config.Command), - Args: config.Args, + Command: home.Long(lspCfg.Command), + Args: lspCfg.Args, RootURI: rootURI, Environment: func() map[string]string { env := make(map[string]string) - maps.Copy(env, config.Env) + maps.Copy(env, lspCfg.Env) return env }(), - Settings: config.Options, - InitOptions: config.InitOptions, + Settings: lspCfg.Options, + InitOptions: lspCfg.InitOptions, WorkspaceFolders: []protocol.WorkspaceFolder{ { URI: rootURI, @@ -81,12 +82,13 @@ func New(ctx context.Context, name string, config config.LSPConfig) (*Client, er } client := &Client{ + cfg: cfg, client: powernapClient, name: name, - fileTypes: config.FileTypes, + fileTypes: lspCfg.FileTypes, diagnostics: csync.NewVersionedMap[protocol.DocumentURI, []protocol.Diagnostic](), openFiles: csync.NewMap[string, *OpenFileInfo](), - config: config, + config: lspCfg, } // Initialize server state @@ -214,8 +216,6 @@ func (c *Client) SetDiagnosticsCallback(callback func(name string, count int)) { // WaitForServerReady waits for the server to be ready func (c *Client) WaitForServerReady(ctx context.Context) error { - cfg := config.Get() - // Set initial state c.SetServerState(StateStarting) @@ -227,7 +227,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { ticker := time.NewTicker(500 * time.Millisecond) defer ticker.Stop() - if cfg != nil && cfg.Options.DebugLSP { + if c.cfg != nil && c.cfg.Options.DebugLSP { slog.Debug("Waiting for LSP server to be ready...") } @@ -241,7 +241,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { case <-ticker.C: // Check if client is running if !c.client.IsRunning() { - if cfg != nil && cfg.Options.DebugLSP { + if c.cfg != nil && c.cfg.Options.DebugLSP { slog.Debug("LSP server not ready yet", "server", c.name) } continue @@ -249,7 +249,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error { // Server is ready c.SetServerState(StateReady) - if cfg != nil && cfg.Options.DebugLSP { + if c.cfg != nil && c.cfg.Options.DebugLSP { slog.Debug("LSP server is ready") } return nil @@ -349,14 +349,13 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error { // // NOTE: this is only ever called on LSP shutdown. func (c *Client) CloseFile(ctx context.Context, filepath string) error { - cfg := config.Get() uri := string(protocol.URIFromPath(filepath)) if _, exists := c.openFiles.Get(uri); !exists { return nil // Already closed } - if cfg.Options.DebugLSP { + if c.cfg.Options.DebugLSP { slog.Debug("Closing file", "file", filepath) } @@ -378,7 +377,6 @@ func (c *Client) IsFileOpen(filepath string) bool { // CloseAllFiles closes all currently open files. func (c *Client) CloseAllFiles(ctx context.Context) { - cfg := config.Get() filesToClose := make([]string, 0, c.openFiles.Len()) // First collect all URIs that need to be closed @@ -395,12 +393,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) { // Then close them all for _, filePath := range filesToClose { err := c.CloseFile(ctx, filePath) - if err != nil && cfg != nil && cfg.Options.DebugLSP { + if err != nil && c.cfg != nil && c.cfg.Options.DebugLSP { slog.Warn("Error closing file", "file", filePath, "error", err) } } - if cfg != nil && cfg.Options.DebugLSP { + if c.cfg != nil && c.cfg.Options.DebugLSP { slog.Debug("Closed all files", "files", filesToClose) } } diff --git a/internal/lsp/client_test.go b/internal/lsp/client_test.go index 99ef0ca3143e5b8689ba3b63fd5c172456a46c24..e05d6cec25a4b60979b11fab4b422bd83bf5ccff 100644 --- a/internal/lsp/client_test.go +++ b/internal/lsp/client_test.go @@ -11,7 +11,8 @@ func TestPowernapClient(t *testing.T) { ctx := context.Background() // Create a simple config for testing - cfg := config.LSPConfig{ + var cfg config.Config + lspCfg := config.LSPConfig{ Command: "echo", // Use echo as a dummy command that won't fail Args: []string{"hello"}, FileTypes: []string{"go"}, @@ -20,7 +21,7 @@ func TestPowernapClient(t *testing.T) { // Test creating a powernap client - this will likely fail with echo // but we can still test the basic structure - client, err := New(ctx, "test", cfg) + client, err := New(ctx, &cfg, "test", lspCfg) if err != nil { // Expected to fail with echo command, skip the rest t.Skipf("Powernap client creation failed as expected with dummy command: %v", err) diff --git a/internal/lsp/handlers.go b/internal/lsp/handlers.go index b386e0780f6f6db6db13be380496c60a6e3c457e..dd4228937e3eb9f0e67217fdb186fb179a392940 100644 --- a/internal/lsp/handlers.go +++ b/internal/lsp/handlers.go @@ -79,9 +79,9 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche } // HandleServerMessage handles server messages -func HandleServerMessage(_ context.Context, method string, params json.RawMessage) { - cfg := config.Get() - if !cfg.Options.DebugLSP { +func HandleServerMessage(ctx context.Context, method string, params json.RawMessage) { + cfg, ok := config.FromContext(ctx) + if !ok || !cfg.Options.DebugLSP { return } diff --git a/internal/message/attachment.go b/internal/message/attachment.go deleted file mode 100644 index 9f6e64172c6a8912694d50e4eb029e1f8d54a3a9..0000000000000000000000000000000000000000 --- a/internal/message/attachment.go +++ /dev/null @@ -1,47 +0,0 @@ -package message - -import ( - "encoding/base64" - "encoding/json" -) - -type Attachment struct { - FilePath string `json:"file_path"` - FileName string `json:"file_name"` - MimeType string `json:"mime_type"` - Content []byte `json:"content"` -} - -// MarshalJSON implements the [json.Marshaler] interface. -func (a Attachment) MarshalJSON() ([]byte, error) { - // Encode the content as a base64 string - type Alias Attachment - return json.Marshal(&struct { - Content string `json:"content"` - *Alias - }{ - Content: base64.StdEncoding.EncodeToString(a.Content), - Alias: (*Alias)(&a), - }) -} - -// UnmarshalJSON implements the [json.Unmarshaler] interface. -func (a *Attachment) UnmarshalJSON(data []byte) error { - // Decode the content from a base64 string - type Alias Attachment - aux := &struct { - Content string `json:"content"` - *Alias - }{ - Alias: (*Alias)(a), - } - if err := json.Unmarshal(data, &aux); err != nil { - return err - } - content, err := base64.StdEncoding.DecodeString(aux.Content) - if err != nil { - return err - } - a.Content = content - return nil -} diff --git a/internal/message/message.go b/internal/message/message.go index 106aa8846cee88e9bb17804de72d3d7c6743e873..00336f39b202e3cc9263a26616902846faf4d6fb 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -3,21 +3,42 @@ package message import ( "context" "database/sql" - "encoding/json" - "fmt" "time" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) -type CreateMessageParams struct { - Role MessageRole `json:"role"` - Parts []ContentPart `json:"parts"` - Model string `json:"model"` - Provider string `json:"provider,omitempty"` -} +type ( + CreateMessageParams = proto.CreateMessageParams + Message = proto.Message + Attachment = proto.Attachment + ToolCall = proto.ToolCall + ToolResult = proto.ToolResult + ContentPart = proto.ContentPart + TextContent = proto.TextContent + BinaryContent = proto.BinaryContent + FinishReason = proto.FinishReason + Finish = proto.Finish +) + +const ( + Assistant = proto.Assistant + User = proto.User + System = proto.System + Tool = proto.Tool + + FinishReasonEndTurn = proto.FinishReasonEndTurn + FinishReasonMaxTokens = proto.FinishReasonMaxTokens + FinishReasonToolUse = proto.FinishReasonToolUse + FinishReasonCanceled = proto.FinishReasonCanceled + FinishReasonError = proto.FinishReasonError + FinishReasonPermissionDenied = proto.FinishReasonPermissionDenied + + FinishReasonUnknown = proto.FinishReasonUnknown +) type Service interface { pubsub.Suscriber[Message] @@ -55,12 +76,12 @@ func (s *service) Delete(ctx context.Context, id string) error { } func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) { - if params.Role != Assistant { - params.Parts = append(params.Parts, Finish{ + if params.Role != proto.Assistant { + params.Parts = append(params.Parts, proto.Finish{ Reason: "stop", }) } - partsJSON, err := marshallParts(params.Parts) + partsJSON, err := proto.MarshallParts(params.Parts) if err != nil { return Message{}, err } @@ -100,7 +121,7 @@ func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) e } func (s *service) Update(ctx context.Context, message Message) error { - parts, err := marshallParts(message.Parts) + parts, err := proto.MarshallParts(message.Parts) if err != nil { return err } @@ -146,14 +167,14 @@ func (s *service) List(ctx context.Context, sessionID string) ([]Message, error) } func (s *service) fromDBItem(item db.Message) (Message, error) { - parts, err := unmarshallParts([]byte(item.Parts)) + parts, err := proto.UnmarshallParts([]byte(item.Parts)) if err != nil { return Message{}, err } return Message{ ID: item.ID, SessionID: item.SessionID, - Role: MessageRole(item.Role), + Role: proto.MessageRole(item.Role), Parts: parts, Model: item.Model.String, Provider: item.Provider.String, @@ -161,122 +182,3 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { UpdatedAt: item.UpdatedAt, }, nil } - -type partType string - -const ( - reasoningType partType = "reasoning" - textType partType = "text" - imageURLType partType = "image_url" - binaryType partType = "binary" - toolCallType partType = "tool_call" - toolResultType partType = "tool_result" - finishType partType = "finish" -) - -type partWrapper struct { - Type partType `json:"type"` - Data ContentPart `json:"data"` -} - -func marshallParts(parts []ContentPart) ([]byte, error) { - wrappedParts := make([]partWrapper, len(parts)) - - for i, part := range parts { - var typ partType - - switch part.(type) { - case ReasoningContent: - typ = reasoningType - case TextContent: - typ = textType - case ImageURLContent: - typ = imageURLType - case BinaryContent: - typ = binaryType - case ToolCall: - typ = toolCallType - case ToolResult: - typ = toolResultType - case Finish: - typ = finishType - default: - return nil, fmt.Errorf("unknown part type: %T", part) - } - - wrappedParts[i] = partWrapper{ - Type: typ, - Data: part, - } - } - return json.Marshal(wrappedParts) -} - -func unmarshallParts(data []byte) ([]ContentPart, error) { - temp := []json.RawMessage{} - - if err := json.Unmarshal(data, &temp); err != nil { - return nil, err - } - - parts := make([]ContentPart, 0) - - for _, rawPart := range temp { - var wrapper struct { - Type partType `json:"type"` - Data json.RawMessage `json:"data"` - } - - if err := json.Unmarshal(rawPart, &wrapper); err != nil { - return nil, err - } - - switch wrapper.Type { - case reasoningType: - part := ReasoningContent{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - case textType: - part := TextContent{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - case imageURLType: - part := ImageURLContent{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - case binaryType: - part := BinaryContent{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - case toolCallType: - part := ToolCall{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - case toolResultType: - part := ToolResult{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - case finishType: - part := Finish{} - if err := json.Unmarshal(wrapper.Data, &part); err != nil { - return nil, err - } - parts = append(parts, part) - default: - return nil, fmt.Errorf("unknown part type: %s", wrapper.Type) - } - } - - return parts, nil -} diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 77b2526a592d0d194f75fb71af05477ae75df80b..bf91436f25c6d387e46d0efffa8149bc52936d83 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -9,38 +9,18 @@ import ( "sync" "github.com/charmbracelet/crush/internal/csync" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) var ErrorPermissionDenied = errors.New("permission denied") -type CreatePermissionRequest struct { - SessionID string `json:"session_id"` - ToolCallID string `json:"tool_call_id"` - ToolName string `json:"tool_name"` - Description string `json:"description"` - Action string `json:"action"` - Params any `json:"params"` - Path string `json:"path"` -} - -type PermissionNotification struct { - ToolCallID string `json:"tool_call_id"` - Granted bool `json:"granted"` - Denied bool `json:"denied"` -} - -type PermissionRequest struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - ToolCallID string `json:"tool_call_id"` - ToolName string `json:"tool_name"` - Description string `json:"description"` - Action string `json:"action"` - Params any `json:"params"` - Path string `json:"path"` -} +type ( + PermissionRequest = proto.PermissionRequest + PermissionNotification = proto.PermissionNotification + CreatePermissionRequest = proto.CreatePermissionRequest +) type Service interface { pubsub.Suscriber[PermissionRequest] diff --git a/internal/proto/agent.go b/internal/proto/agent.go new file mode 100644 index 0000000000000000000000000000000000000000..b9e29ba4f32a382c64c0c58d10c8cf8f8813a7b0 --- /dev/null +++ b/internal/proto/agent.go @@ -0,0 +1,70 @@ +package proto + +import ( + "encoding/json" + "errors" +) + +type AgentEventType string + +const ( + AgentEventTypeError AgentEventType = "error" + AgentEventTypeResponse AgentEventType = "response" + AgentEventTypeSummarize AgentEventType = "summarize" +) + +func (t AgentEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +func (t *AgentEventType) UnmarshalText(text []byte) error { + *t = AgentEventType(text) + return nil +} + +type AgentEvent struct { + Type AgentEventType `json:"type"` + Message Message `json:"message"` + Error error `json:"error,omitempty"` + + // When summarizing + SessionID string `json:"session_id,omitempty"` + Progress string `json:"progress,omitempty"` + Done bool `json:"done,omitempty"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (e AgentEvent) MarshalJSON() ([]byte, error) { + type Alias AgentEvent + return json.Marshal(&struct { + Error string `json:"error,omitempty"` + Alias + }{ + Error: func() string { + if e.Error != nil { + return e.Error.Error() + } + return "" + }(), + Alias: (Alias)(e), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (e *AgentEvent) UnmarshalJSON(data []byte) error { + type Alias AgentEvent + aux := &struct { + Error string `json:"error,omitempty"` + Alias + }{ + Alias: (Alias)(*e), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + *e = AgentEvent(aux.Alias) + if aux.Error != "" { + e.Error = errors.New(aux.Error) + } + return nil +} diff --git a/internal/proto/history.go b/internal/proto/history.go new file mode 100644 index 0000000000000000000000000000000000000000..6eb2872aed6f3f59e014ea55d0c7705bac899783 --- /dev/null +++ b/internal/proto/history.go @@ -0,0 +1,11 @@ +package proto + +type File struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + Path string `json:"path"` + Content string `json:"content"` + Version int64 `json:"version"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} diff --git a/internal/proto/mcp.go b/internal/proto/mcp.go new file mode 100644 index 0000000000000000000000000000000000000000..5d0df3de366616e9984c81f44022aeed7ddb54c2 --- /dev/null +++ b/internal/proto/mcp.go @@ -0,0 +1,73 @@ +package proto + +import "fmt" + +// MCPState represents the current state of an MCP client +type MCPState int + +const ( + MCPStateDisabled MCPState = iota + MCPStateStarting + MCPStateConnected + MCPStateError +) + +func (s MCPState) MarshalText() ([]byte, error) { + return []byte(s.String()), nil +} + +func (s *MCPState) UnmarshalText(data []byte) error { + switch string(data) { + case "disabled": + *s = MCPStateDisabled + case "starting": + *s = MCPStateStarting + case "connected": + *s = MCPStateConnected + case "error": + *s = MCPStateError + default: + return fmt.Errorf("unknown mcp state: %s", data) + } + return nil +} + +func (s MCPState) String() string { + switch s { + case MCPStateDisabled: + return "disabled" + case MCPStateStarting: + return "starting" + case MCPStateConnected: + return "connected" + case MCPStateError: + return "error" + default: + return "unknown" + } +} + +// MCPEventType represents the type of MCP event +type MCPEventType string + +const ( + MCPEventStateChanged MCPEventType = "state_changed" +) + +func (t MCPEventType) MarshalText() ([]byte, error) { + return []byte(t), nil +} + +func (t *MCPEventType) UnmarshalText(data []byte) error { + *t = MCPEventType(data) + return nil +} + +// MCPEvent represents an event in the MCP system +type MCPEvent struct { + Type MCPEventType `json:"type"` + Name string `json:"name"` + State MCPState `json:"state"` + Error error `json:"error,omitempty"` + ToolCount int `json:"tool_count,omitempty"` +} diff --git a/internal/message/content.go b/internal/proto/message.go similarity index 71% rename from internal/message/content.go rename to internal/proto/message.go index e226034f574b9561a3b0d2ea4adaf3b0267608f3..6b4decf291c71cc2fdbeefa44d6c8a03ce505c32 100644 --- a/internal/message/content.go +++ b/internal/proto/message.go @@ -1,14 +1,32 @@ -package message +package proto import ( "encoding/base64" "encoding/json" + "fmt" "slices" "time" "github.com/charmbracelet/catwalk/pkg/catwalk" ) +type CreateMessageParams struct { + Role MessageRole `json:"role"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider,omitempty"` +} + +type Message struct { + ID string `json:"id"` + Role MessageRole `json:"role"` + SessionID string `json:"session_id"` + Parts []ContentPart `json:"parts"` + Model string `json:"model"` + Provider string `json:"provider"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} type MessageRole string const ( @@ -132,22 +150,11 @@ type Finish struct { func (Finish) isPart() {} -type Message struct { - ID string `json:"id"` - Role MessageRole `json:"role"` - SessionID string `json:"session_id"` - Parts []ContentPart `json:"parts"` - Model string `json:"model"` - Provider string `json:"provider"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` -} - // MarshalJSON implements the [json.Marshaler] interface. func (m Message) MarshalJSON() ([]byte, error) { // We need to handle the Parts specially since they're ContentPart interfaces // which can't be directly marshaled by the standard JSON package. - parts, err := marshallParts(m.Parts) + parts, err := MarshallParts(m.Parts) if err != nil { return nil, err } @@ -179,7 +186,7 @@ func (m *Message) UnmarshalJSON(data []byte) error { } // Unmarshal the parts using our custom function - parts, err := unmarshallParts([]byte(aux.Parts)) + parts, err := UnmarshallParts([]byte(aux.Parts)) if err != nil { return err } @@ -448,3 +455,163 @@ func (m *Message) AddImageURL(url, detail string) { func (m *Message) AddBinary(mimeType string, data []byte) { m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data}) } + +type partType string + +const ( + reasoningType partType = "reasoning" + textType partType = "text" + imageURLType partType = "image_url" + binaryType partType = "binary" + toolCallType partType = "tool_call" + toolResultType partType = "tool_result" + finishType partType = "finish" +) + +type partWrapper struct { + Type partType `json:"type"` + Data ContentPart `json:"data"` +} + +func MarshallParts(parts []ContentPart) ([]byte, error) { + wrappedParts := make([]partWrapper, len(parts)) + + for i, part := range parts { + var typ partType + + switch part.(type) { + case ReasoningContent: + typ = reasoningType + case TextContent: + typ = textType + case ImageURLContent: + typ = imageURLType + case BinaryContent: + typ = binaryType + case ToolCall: + typ = toolCallType + case ToolResult: + typ = toolResultType + case Finish: + typ = finishType + default: + return nil, fmt.Errorf("unknown part type: %T", part) + } + + wrappedParts[i] = partWrapper{ + Type: typ, + Data: part, + } + } + return json.Marshal(wrappedParts) +} + +func UnmarshallParts(data []byte) ([]ContentPart, error) { + temp := []json.RawMessage{} + + if err := json.Unmarshal(data, &temp); err != nil { + return nil, err + } + + parts := make([]ContentPart, 0) + + for _, rawPart := range temp { + var wrapper struct { + Type partType `json:"type"` + Data json.RawMessage `json:"data"` + } + + if err := json.Unmarshal(rawPart, &wrapper); err != nil { + return nil, err + } + + switch wrapper.Type { + case reasoningType: + part := ReasoningContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case textType: + part := TextContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case imageURLType: + part := ImageURLContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + case binaryType: + part := BinaryContent{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolCallType: + part := ToolCall{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case toolResultType: + part := ToolResult{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + case finishType: + part := Finish{} + if err := json.Unmarshal(wrapper.Data, &part); err != nil { + return nil, err + } + parts = append(parts, part) + default: + return nil, fmt.Errorf("unknown part type: %s", wrapper.Type) + } + } + + return parts, nil +} + +type Attachment struct { + FilePath string `json:"file_path"` + FileName string `json:"file_name"` + MimeType string `json:"mime_type"` + Content []byte `json:"content"` +} + +// MarshalJSON implements the [json.Marshaler] interface. +func (a Attachment) MarshalJSON() ([]byte, error) { + // Encode the content as a base64 string + type Alias Attachment + return json.Marshal(&struct { + Content string `json:"content"` + *Alias + }{ + Content: base64.StdEncoding.EncodeToString(a.Content), + Alias: (*Alias)(&a), + }) +} + +// UnmarshalJSON implements the [json.Unmarshaler] interface. +func (a *Attachment) UnmarshalJSON(data []byte) error { + // Decode the content from a base64 string + type Alias Attachment + aux := &struct { + Content string `json:"content"` + *Alias + }{ + Alias: (*Alias)(a), + } + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + content, err := base64.StdEncoding.DecodeString(aux.Content) + if err != nil { + return err + } + a.Content = content + return nil +} diff --git a/internal/proto/permission.go b/internal/proto/permission.go new file mode 100644 index 0000000000000000000000000000000000000000..5f03fabe2617103119982eb6a70cc3ef75a75f76 --- /dev/null +++ b/internal/proto/permission.go @@ -0,0 +1,28 @@ +package proto + +type CreatePermissionRequest struct { + SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Description string `json:"description"` + Action string `json:"action"` + Params any `json:"params"` + Path string `json:"path"` +} + +type PermissionNotification struct { + ToolCallID string `json:"tool_call_id"` + Granted bool `json:"granted"` + Denied bool `json:"denied"` +} + +type PermissionRequest struct { + ID string `json:"id"` + SessionID string `json:"session_id"` + ToolCallID string `json:"tool_call_id"` + ToolName string `json:"tool_name"` + Description string `json:"description"` + Action string `json:"action"` + Params any `json:"params"` + Path string `json:"path"` +} diff --git a/internal/proto/proto.go b/internal/proto/proto.go index 08c3b3699a129f0e06dc01964656b4fbd615ae49..1852cbaf0b71862a73f6e6d5fbf7d10a24a3a09c 100644 --- a/internal/proto/proto.go +++ b/internal/proto/proto.go @@ -1,8 +1,10 @@ package proto import ( + "time" + "github.com/charmbracelet/catwalk/pkg/catwalk" - "github.com/charmbracelet/crush/internal/message" + "github.com/charmbracelet/crush/internal/lsp" ) // Instance represents a running app.App instance with its associated resources @@ -33,7 +35,84 @@ func (a AgentInfo) IsZero() bool { // AgentMessage represents a message sent to the agent. type AgentMessage struct { - SessionID string `json:"session_id"` - Prompt string `json:"prompt"` - Attachments []message.Attachment `json:"attachments,omitempty"` + SessionID string `json:"session_id"` + Prompt string `json:"prompt"` + Attachments []Attachment `json:"attachments,omitempty"` +} + +// AgentSession represents a session with its busy status. +type AgentSession struct { + Session + IsBusy bool `json:"is_busy"` +} + +// IsZero checks if the AgentSession is zero-valued. +func (a AgentSession) IsZero() bool { + return a == AgentSession{} +} + +type PermissionAction string + +// Permission responses +const ( + PermissionAllow PermissionAction = "allow" + PermissionAllowForSession PermissionAction = "allow_session" + PermissionDeny PermissionAction = "deny" +) + +// MarshalText implements the [encoding.TextMarshaler] interface. +func (p PermissionAction) MarshalText() ([]byte, error) { + return []byte(p), nil +} + +// UnmarshalText implements the [encoding.TextUnmarshaler] interface. +func (p *PermissionAction) UnmarshalText(text []byte) error { + *p = PermissionAction(text) + return nil +} + +// PermissionGrant represents a permission grant request. +type PermissionGrant struct { + Permission PermissionRequest `json:"permission"` + Action PermissionAction `json:"action"` +} + +// PermissionSkipRequest represents a request to skip permission prompts. +type PermissionSkipRequest struct { + Skip bool `json:"skip"` +} + +// LSPEventType represents the type of LSP event +type LSPEventType string + +const ( + LSPEventStateChanged LSPEventType = "state_changed" + LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed" +) + +func (e LSPEventType) MarshalText() ([]byte, error) { + return []byte(e), nil +} + +func (e *LSPEventType) UnmarshalText(data []byte) error { + *e = LSPEventType(data) + return nil +} + +// LSPEvent represents an event in the LSP system +type LSPEvent struct { + Type LSPEventType `json:"type"` + Name string `json:"name"` + State lsp.ServerState `json:"state"` + Error error `json:"error,omitempty"` + DiagnosticCount int `json:"diagnostic_count,omitempty"` +} + +// LSPClientInfo holds information about an LSP client's state +type LSPClientInfo struct { + Name string `json:"name"` + State lsp.ServerState `json:"state"` + Error error `json:"error,omitempty"` + DiagnosticCount int `json:"diagnostic_count,omitempty"` + ConnectedAt time.Time `json:"connected_at"` } diff --git a/internal/proto/session.go b/internal/proto/session.go new file mode 100644 index 0000000000000000000000000000000000000000..21e7a050c8d1b22dc93ebff357939f8cd82672d5 --- /dev/null +++ b/internal/proto/session.go @@ -0,0 +1,14 @@ +package proto + +type Session struct { + ID string `json:"id"` + ParentSessionID string `json:"parent_session_id"` + Title string `json:"title"` + MessageCount int64 `json:"message_count"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID string `json:"summary_message_id"` + Cost float64 `json:"cost"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` +} diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 80948d3d515a4fb5dad0d4dc36adbbff4e502993..8cb3c4cecc372b8053e299622665da2469edfffa 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -2,6 +2,7 @@ package pubsub import ( "context" + "log/slog" "sync" ) @@ -113,6 +114,7 @@ func (b *Broker[T]) Publish(t EventType, payload T) { default: // Channel is full, subscriber is slow - skip this event // This prevents blocking the publisher + slog.Warn("Skipping event for slow subscriber", "event_type", t) } } } diff --git a/internal/pubsub/events.go b/internal/pubsub/events.go index af3df38bdc8cde1f7255c26ef887934412ba537b..85351e045f5dfb03a87fd33ea288befd659095cb 100644 --- a/internal/pubsub/events.go +++ b/internal/pubsub/events.go @@ -1,6 +1,13 @@ package pubsub -import "context" +import ( + "context" + "encoding/json" + "fmt" + "log/slog" + + "github.com/charmbracelet/crush/internal/proto" +) const ( CreatedEvent EventType = "created" @@ -13,6 +20,13 @@ type Suscriber[T any] interface { } type ( + PayloadType = string + + Payload struct { + Type PayloadType `json:"type"` + Payload json.RawMessage `json:"payload"` + } + // EventType identifies the type of event EventType string @@ -27,6 +41,17 @@ type ( } ) +const ( + PayloadTypeLSPEvent PayloadType = "lsp_event" + PayloadTypeMCPEvent PayloadType = "mcp_event" + PayloadTypePermissionRequest PayloadType = "permission_request" + PayloadTypePermissionNotification PayloadType = "permission_notification" + PayloadTypeMessage PayloadType = "message" + PayloadTypeSession PayloadType = "session" + PayloadTypeFile PayloadType = "file" + PayloadTypeAgentEvent PayloadType = "agent_event" +) + func (t EventType) MarshalText() ([]byte, error) { return []byte(t), nil } @@ -35,3 +60,148 @@ func (t *EventType) UnmarshalText(data []byte) error { *t = EventType(data) return nil } + +func (e Event[T]) MarshalJSON() ([]byte, error) { + type Alias Event[T] + + var ( + typ string + bts []byte + err error + ) + switch any(e.Payload).(type) { + case proto.LSPEvent: + typ = "lsp_event" + bts, err = json.Marshal(e.Payload) + case proto.MCPEvent: + typ = "mcp_event" + bts, err = json.Marshal(e.Payload) + case proto.PermissionRequest: + typ = "permission_request" + bts, err = json.Marshal(e.Payload) + case proto.PermissionNotification: + typ = "permission_notification" + bts, err = json.Marshal(e.Payload) + case proto.Message: + typ = "message" + bts, err = json.Marshal(e.Payload) + case proto.Session: + typ = "session" + bts, err = json.Marshal(e.Payload) + case proto.File: + typ = "file" + bts, err = json.Marshal(e.Payload) + case proto.AgentEvent: + typ = "agent_event" + bts, err = json.Marshal(e.Payload) + default: + panic(fmt.Sprintf("unknown payload type: %T", e.Payload)) + } + + if err != nil { + return nil, err + } + + p, err := json.Marshal(&Payload{ + Type: typ, + Payload: bts, + }) + if err != nil { + return nil, err + } + + b, err := json.Marshal(&struct { + Payload json.RawMessage `json:"payload"` + *Alias + }{ + Payload: json.RawMessage(p), + Alias: (*Alias)(&e), + }) + + // slog.Info("marshalled event", "event", fmt.Sprintf("%q", string(b))) + + return b, err +} + +func (e *Event[T]) UnmarshalJSON(data []byte) error { + // slog.Info("unmarshalling event", "data", fmt.Sprintf("%q", string(data))) + + type Alias Event[T] + aux := &struct { + Payload json.RawMessage `json:"payload"` + *Alias + }{ + Alias: (*Alias)(e), + } + + if err := json.Unmarshal(data, &aux); err != nil { + return err + } + + e.Type = aux.Type + + slog.Info("unmarshalled event payload", "aux", fmt.Sprintf("%q", aux.Payload)) + + var wp Payload + if err := json.Unmarshal(aux.Payload, &wp); err != nil { + return err + } + + var pl any + switch wp.Type { + case "lsp_event": + var p proto.LSPEvent + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "mcp_event": + var p proto.MCPEvent + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "permission_request": + var p proto.PermissionRequest + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "permission_notification": + var p proto.PermissionNotification + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "message": + var p proto.Message + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "session": + var p proto.Session + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "file": + var p proto.File + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + case "agent_event": + var p proto.AgentEvent + if err := json.Unmarshal(wp.Payload, &p); err != nil { + return err + } + pl = p + default: + panic(fmt.Sprintf("unknown payload type: %q", wp.Type)) + } + + e.Payload = T(pl.(T)) + + return nil +} diff --git a/internal/session/session.go b/internal/session/session.go index 7b57a37c3bed304ef11211f511e6d993bc497ef4..31eaa7709dc1a73a97647eb4df1516591a174de6 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -5,22 +5,12 @@ import ( "database/sql" "github.com/charmbracelet/crush/internal/db" + "github.com/charmbracelet/crush/internal/proto" "github.com/charmbracelet/crush/internal/pubsub" "github.com/google/uuid" ) -type Session struct { - ID string `json:"id"` - ParentSessionID string `json:"parent_session_id"` - Title string `json:"title"` - MessageCount int64 `json:"message_count"` - PromptTokens int64 `json:"prompt_tokens"` - CompletionTokens int64 `json:"completion_tokens"` - SummaryMessageID string `json:"summary_message_id"` - Cost float64 `json:"cost"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` -} +type Session = proto.Session type Service interface { pubsub.Suscriber[Session]