Detailed changes
@@ -52,6 +52,9 @@ type App struct {
// New initializes a new applcation instance.
func New(ctx context.Context, conn *sql.DB, cfg *config.Config) (*App, error) {
+ // Attach config to context for use in services.
+ ctx = config.WithContext(ctx, cfg)
+
q := db.New(conn)
sessions := session.NewService(q)
messages := message.NewService(q)
@@ -275,6 +278,7 @@ func (app *App) InitCoderAgent() error {
var err error
app.CoderAgent, err = agent.NewAgent(
app.globalCtx,
+ app.config,
coderAgentCfg,
app.Permissions,
app.Sessions,
@@ -22,12 +22,12 @@ func (app *App) initLSPClients(ctx context.Context) {
}
// createAndStartLSPClient creates a new LSP client, initializes it, and starts its workspace watcher
-func (app *App) createAndStartLSPClient(ctx context.Context, name string, config config.LSPConfig) {
- slog.Info("Creating LSP client", "name", name, "command", config.Command, "fileTypes", config.FileTypes, "args", config.Args)
+func (app *App) createAndStartLSPClient(ctx context.Context, name string, lspCfg config.LSPConfig) {
+ slog.Info("Creating LSP client", "name", name, "command", lspCfg.Command, "fileTypes", lspCfg.FileTypes, "args", lspCfg.Args)
// Check if any root markers exist in the working directory (config now has defaults)
- if !lsp.HasRootMarkers(app.config.WorkingDir(), config.RootMarkers) {
- slog.Info("Skipping LSP client - no root markers found", "name", name, "rootMarkers", config.RootMarkers)
+ if !lsp.HasRootMarkers(app.config.WorkingDir(), lspCfg.RootMarkers) {
+ slog.Info("Skipping LSP client - no root markers found", "name", name, "rootMarkers", lspCfg.RootMarkers)
app.updateLSPState(name, lsp.StateDisabled, nil, 0)
return
}
@@ -36,7 +36,7 @@ func (app *App) createAndStartLSPClient(ctx context.Context, name string, config
app.updateLSPState(name, lsp.StateStarting, nil, 0)
// Create LSP client.
- lspClient, err := lsp.New(ctx, name, config)
+ lspClient, err := lsp.New(ctx, app.config, name, lspCfg)
if err != nil {
slog.Error("Failed to create LSP client for", name, err)
app.updateLSPState(name, lsp.StateError, err, 0)
@@ -6,44 +6,20 @@ import (
"time"
"github.com/charmbracelet/crush/internal/lsp"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
)
-// LSPEventType represents the type of LSP event
-type LSPEventType string
+type (
+ LSPClientInfo = proto.LSPClientInfo
+ LSPEvent = proto.LSPEvent
+)
const (
- LSPEventStateChanged LSPEventType = "state_changed"
- LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed"
+ LSPEventStateChanged = proto.LSPEventStateChanged
+ LSPEventDiagnosticsChanged = proto.LSPEventDiagnosticsChanged
)
-func (e LSPEventType) MarshalText() ([]byte, error) {
- return []byte(e), nil
-}
-
-func (e *LSPEventType) UnmarshalText(data []byte) error {
- *e = LSPEventType(data)
- return nil
-}
-
-// LSPEvent represents an event in the LSP system
-type LSPEvent struct {
- Type LSPEventType `json:"type"`
- Name string `json:"name"`
- State lsp.ServerState `json:"state"`
- Error error `json:"error,omitempty"`
- DiagnosticCount int `json:"diagnostic_count,omitempty"`
-}
-
-// LSPClientInfo holds information about an LSP client's state
-type LSPClientInfo struct {
- Name string `json:"name"`
- State lsp.ServerState `json:"state"`
- Error error `json:"error,omitempty"`
- DiagnosticCount int `json:"diagnostic_count,omitempty"`
- ConnectedAt time.Time `json:"connected_at"`
-}
-
// SubscribeLSPEvents returns a channel for LSP events
func (a *App) SubscribeLSPEvents(ctx context.Context) <-chan pubsub.Event[LSPEvent] {
return a.lspBroker.Subscribe(ctx)
@@ -7,6 +7,7 @@ import (
"strings"
"github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
@@ -15,15 +16,7 @@ const (
InitialVersion = 0
)
-type File struct {
- ID string `json:"id"`
- SessionID string `json:"session_id"`
- Path string `json:"path"`
- Content string `json:"content"`
- Version int64 `json:"version"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
-}
+type File = proto.File
type Service interface {
pubsub.Suscriber[File]
@@ -20,6 +20,7 @@ import (
"github.com/charmbracelet/crush/internal/lsp"
"github.com/charmbracelet/crush/internal/message"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/session"
"github.com/charmbracelet/crush/internal/shell"
@@ -31,34 +32,17 @@ var (
ErrSessionBusy = errors.New("session is currently processing another request")
)
-type AgentEventType string
+type (
+ AgentEventType = proto.AgentEventType
+ AgentEvent = proto.AgentEvent
+)
const (
- AgentEventTypeError AgentEventType = "error"
- AgentEventTypeResponse AgentEventType = "response"
- AgentEventTypeSummarize AgentEventType = "summarize"
+ AgentEventTypeError = proto.AgentEventTypeError
+ AgentEventTypeResponse = proto.AgentEventTypeResponse
+ AgentEventTypeSummarize = proto.AgentEventTypeSummarize
)
-func (t AgentEventType) MarshalText() ([]byte, error) {
- return []byte(t), nil
-}
-
-func (t *AgentEventType) UnmarshalText(text []byte) error {
- *t = AgentEventType(text)
- return nil
-}
-
-type AgentEvent struct {
- Type AgentEventType `json:"type"`
- Message message.Message `json:"message"`
- Error error `json:"error,omitempty"`
-
- // When summarizing
- SessionID string `json:"session_id,omitempty"`
- Progress string `json:"progress,omitempty"`
- Done bool `json:"done,omitempty"`
-}
-
type Service interface {
pubsub.Suscriber[AgentEvent]
Model() catwalk.Model
@@ -79,6 +63,7 @@ type agent struct {
sessions session.Service
messages message.Service
mcpTools []McpTool
+ cfg *config.Config
tools *csync.LazySlice[tools.BaseTool]
// We need this to be able to update it when model changes
@@ -102,6 +87,7 @@ var agentPromptMap = map[string]prompt.PromptID{
func NewAgent(
ctx context.Context,
+ cfg *config.Config,
agentCfg config.Agent,
// These services are needed in the tools
permissions permission.Service,
@@ -110,16 +96,14 @@ func NewAgent(
history history.Service,
lspClients *csync.Map[string, *lsp.Client],
) (Service, error) {
- cfg := config.Get()
-
var agentToolFn func() (tools.BaseTool, error)
if agentCfg.ID == "coder" {
agentToolFn = func() (tools.BaseTool, error) {
- taskAgentCfg := config.Get().Agents["task"]
+ taskAgentCfg := cfg.Agents["task"]
if taskAgentCfg.ID == "" {
return nil, fmt.Errorf("task agent not found in config")
}
- taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
+ taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients)
if err != nil {
return nil, fmt.Errorf("failed to create task agent: %w", err)
}
@@ -127,11 +111,11 @@ func NewAgent(
}
}
- providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
+ providerCfg := cfg.GetProviderForModel(agentCfg.Model)
if providerCfg == nil {
return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
}
- model := config.Get().GetModelByType(agentCfg.Model)
+ model := cfg.GetModelByType(agentCfg.Model)
if model == nil {
return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
@@ -143,9 +127,9 @@ func NewAgent(
}
opts := []provider.ProviderClientOption{
provider.WithModel(agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
+ provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)),
}
- agentProvider, err := provider.NewProvider(*providerCfg, opts...)
+ agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...)
if err != nil {
return nil, err
}
@@ -168,18 +152,18 @@ func NewAgent(
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
}
- titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
+ titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...)
if err != nil {
return nil, err
}
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeLarge),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)),
}
- summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
+ summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...)
if err != nil {
return nil, err
}
@@ -246,11 +230,12 @@ func NewAgent(
activeRequests: csync.NewMap[string, context.CancelFunc](),
tools: csync.NewLazySlice(toolFn),
promptQueue: csync.NewMap[string, []string](),
+ cfg: cfg,
}, nil
}
func (a *agent) Model() catwalk.Model {
- return *config.Get().GetModelByType(a.agentCfg.Model)
+ return *a.cfg.GetModelByType(a.agentCfg.Model)
}
func (a *agent) Cancel(sessionID string) {
@@ -400,7 +385,6 @@ func (a *agent) Run(ctx context.Context, sessionID string, content string, attac
}
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
- cfg := config.Get()
// List existing messages; if none, start title generation asynchronously.
msgs, err := a.messages.List(ctx, sessionID)
if err != nil {
@@ -459,7 +443,7 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
}
return a.err(fmt.Errorf("failed to process events: %w", err))
}
- if cfg.Options.Debug {
+ if a.cfg.Options.Debug {
slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
}
if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
@@ -866,7 +850,7 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
a.Publish(pubsub.CreatedEvent, event)
return
}
- shell := shell.GetPersistentShell(config.Get().WorkingDir())
+ shell := shell.GetPersistentShell(a.cfg.WorkingDir())
summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
event = AgentEvent{
Type: AgentEventTypeSummarize,
@@ -968,10 +952,8 @@ func (a *agent) CancelAll() {
}
func (a *agent) UpdateModel() error {
- cfg := config.Get()
-
// Get current provider configuration
- currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
+ currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model)
if currentProviderCfg == nil || currentProviderCfg.ID == "" {
return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
}
@@ -979,7 +961,7 @@ func (a *agent) UpdateModel() error {
// Check if provider has changed
if string(currentProviderCfg.ID) != a.providerID {
// Provider changed, need to recreate the main provider
- model := cfg.GetModelByType(a.agentCfg.Model)
+ model := a.cfg.GetModelByType(a.agentCfg.Model)
if model.ID == "" {
return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
}
@@ -991,10 +973,10 @@ func (a *agent) UpdateModel() error {
opts := []provider.ProviderClientOption{
provider.WithModel(a.agentCfg.Model),
- provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
+ provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)),
}
- newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
+ newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...)
if err != nil {
return fmt.Errorf("failed to create new provider: %w", err)
}
@@ -1005,9 +987,9 @@ func (a *agent) UpdateModel() error {
}
// Check if providers have changed for title (small) and summarize (large)
- smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
+ smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall]
var smallModelProviderCfg config.ProviderConfig
- for p := range cfg.Providers.Seq() {
+ for p := range a.cfg.Providers.Seq() {
if p.ID == smallModelCfg.Provider {
smallModelProviderCfg = p
break
@@ -1017,9 +999,9 @@ func (a *agent) UpdateModel() error {
return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
}
- largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
+ largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
var largeModelProviderCfg config.ProviderConfig
- for p := range cfg.Providers.Seq() {
+ for p := range a.cfg.Providers.Seq() {
if p.ID == largeModelCfg.Provider {
largeModelProviderCfg = p
break
@@ -1038,10 +1020,10 @@ func (a *agent) UpdateModel() error {
// Recreate title provider
titleOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeSmall),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
provider.WithMaxTokens(maxTitleTokens),
}
- newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
+ newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...)
if err != nil {
return fmt.Errorf("failed to create new title provider: %w", err)
}
@@ -1049,15 +1031,15 @@ func (a *agent) UpdateModel() error {
// Recreate summarize provider if provider changed (now large model)
if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
- largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
+ largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge)
if largeModel == nil {
return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
}
summarizeOpts := []provider.ProviderClientOption{
provider.WithModel(config.SelectedModelTypeLarge),
- provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
+ provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)),
}
- newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
+ newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...)
if err != nil {
return fmt.Errorf("failed to create new summarize provider: %w", err)
}
@@ -18,6 +18,7 @@ import (
"github.com/charmbracelet/crush/internal/home"
"github.com/charmbracelet/crush/internal/llm/tools"
"github.com/charmbracelet/crush/internal/permission"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/charmbracelet/crush/internal/version"
"github.com/mark3labs/mcp-go/client"
@@ -25,75 +26,20 @@ import (
"github.com/mark3labs/mcp-go/mcp"
)
-// MCPState represents the current state of an MCP client
-type MCPState int
-
-const (
- MCPStateDisabled MCPState = iota
- MCPStateStarting
- MCPStateConnected
- MCPStateError
+type (
+ MCPState = proto.MCPState
+ MCPEventType = proto.MCPEventType
+ MCPEvent = proto.MCPEvent
)
-func (s MCPState) MarshalText() ([]byte, error) {
- return []byte(s.String()), nil
-}
-
-func (s *MCPState) UnmarshalText(data []byte) error {
- switch string(data) {
- case "disabled":
- *s = MCPStateDisabled
- case "starting":
- *s = MCPStateStarting
- case "connected":
- *s = MCPStateConnected
- case "error":
- *s = MCPStateError
- default:
- return fmt.Errorf("unknown mcp state: %s", data)
- }
- return nil
-}
-
-func (s MCPState) String() string {
- switch s {
- case MCPStateDisabled:
- return "disabled"
- case MCPStateStarting:
- return "starting"
- case MCPStateConnected:
- return "connected"
- case MCPStateError:
- return "error"
- default:
- return "unknown"
- }
-}
-
-// MCPEventType represents the type of MCP event
-type MCPEventType string
-
const (
- MCPEventStateChanged MCPEventType = "state_changed"
-)
-
-func (t MCPEventType) MarshalText() ([]byte, error) {
- return []byte(t), nil
-}
-
-func (t *MCPEventType) UnmarshalText(data []byte) error {
- *t = MCPEventType(data)
- return nil
-}
+ MCPStateDisabled = proto.MCPStateDisabled
+ MCPStateStarting = proto.MCPStateStarting
+ MCPStateConnected = proto.MCPStateConnected
+ MCPStateError = proto.MCPStateError
-// MCPEvent represents an event in the MCP system
-type MCPEvent struct {
- Type MCPEventType `json:"type"`
- Name string `json:"name"`
- State MCPState `json:"state"`
- Error error `json:"error,omitempty"`
- ToolCount int `json:"tool_count,omitempty"`
-}
+ MCPEventStateChanged = proto.MCPEventStateChanged
+)
// MCPClientInfo holds information about an MCP client's state
type MCPClientInfo struct {
@@ -117,7 +63,7 @@ type McpTool struct {
mcpName string
tool mcp.Tool
permissions permission.Service
- workingDir string
+ cfg *config.Config
}
func (b *McpTool) Name() string {
@@ -141,13 +87,13 @@ func (b *McpTool) Info() tools.ToolInfo {
}
}
-func runTool(ctx context.Context, name, toolName string, input string) (tools.ToolResponse, error) {
+func runTool(ctx context.Context, cfg *config.Config, name, toolName string, input string) (tools.ToolResponse, error) {
var args map[string]any
if err := json.Unmarshal([]byte(input), &args); err != nil {
return tools.NewTextErrorResponse(fmt.Sprintf("error parsing parameters: %s", err)), nil
}
- c, err := getOrRenewClient(ctx, name)
+ c, err := getOrRenewClient(ctx, cfg, name)
if err != nil {
return tools.NewTextErrorResponse(err.Error()), nil
}
@@ -172,13 +118,13 @@ func runTool(ctx context.Context, name, toolName string, input string) (tools.To
return tools.NewTextResponse(strings.Join(output, "\n")), nil
}
-func getOrRenewClient(ctx context.Context, name string) (*client.Client, error) {
+func getOrRenewClient(ctx context.Context, cfg *config.Config, name string) (*client.Client, error) {
c, ok := mcpClients.Get(name)
if !ok {
return nil, fmt.Errorf("mcp '%s' not available", name)
}
- m := config.Get().MCP[name]
+ m := cfg.MCP[name]
state, _ := mcpStates.Get(name)
timeout := mcpTimeout(m)
@@ -210,7 +156,7 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
permission.CreatePermissionRequest{
SessionID: sessionID,
ToolCallID: params.ID,
- Path: b.workingDir,
+ Path: b.cfg.WorkingDir(),
ToolName: b.Info().Name,
Action: "execute",
Description: permissionDescription,
@@ -221,10 +167,10 @@ func (b *McpTool) Run(ctx context.Context, params tools.ToolCall) (tools.ToolRes
return tools.ToolResponse{}, permission.ErrorPermissionDenied
}
- return runTool(ctx, b.mcpName, b.tool.Name, params.Input)
+ return runTool(ctx, b.cfg, b.mcpName, b.tool.Name, params.Input)
}
-func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, workingDir string) []tools.BaseTool {
+func getTools(ctx context.Context, name string, permissions permission.Service, c *client.Client, cfg *config.Config) []tools.BaseTool {
result, err := c.ListTools(ctx, mcp.ListToolsRequest{})
if err != nil {
slog.Error("error listing tools", "error", err)
@@ -239,7 +185,7 @@ func getTools(ctx context.Context, name string, permissions permission.Service,
mcpName: name,
tool: tool,
permissions: permissions,
- workingDir: workingDir,
+ cfg: cfg,
})
}
return mcpTools
@@ -348,7 +294,7 @@ func doGetMCPTools(ctx context.Context, permissions permission.Service, cfg *con
}
mcpClients.Set(name, c)
- tools := getTools(ctx, name, permissions, c, cfg.WorkingDir())
+ tools := getTools(ctx, name, permissions, c, cfg)
updateMCPState(name, MCPStateConnected, nil, c, len(tools))
result.Append(tools...)
}(name, m)
@@ -14,7 +14,7 @@ import (
"github.com/charmbracelet/crush/internal/llm/tools"
)
-func CoderPrompt(p string, contextFiles ...string) string {
+func CoderPrompt(cfg *config.Config, p string, contextFiles ...string) string {
var basePrompt string
basePrompt = string(anthropicCoderPrompt)
@@ -28,11 +28,11 @@ func CoderPrompt(p string, contextFiles ...string) string {
if ok, _ := strconv.ParseBool(os.Getenv("CRUSH_CODER_V2")); ok {
basePrompt = string(coderV2Prompt)
}
- envInfo := getEnvironmentInfo()
+ envInfo := getEnvironmentInfo(cfg)
- basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation())
+ basePrompt = fmt.Sprintf("%s\n\n%s\n%s", basePrompt, envInfo, lspInformation(cfg))
- contextContent := getContextFromPaths(config.Get().WorkingDir(), contextFiles)
+ contextContent := getContextFromPaths(cfg.WorkingDir(), contextFiles)
if contextContent != "" {
return fmt.Sprintf("%s\n\n# Project-Specific Context\n Make sure to follow the instructions in the context below\n%s", basePrompt, contextContent)
}
@@ -48,8 +48,8 @@ var geminiCoderPrompt []byte
//go:embed v2.md
var coderV2Prompt []byte
-func getEnvironmentInfo() string {
- cwd := config.Get().WorkingDir()
+func getEnvironmentInfo(cfg *config.Config) string {
+ cwd := cfg.WorkingDir()
isGit := isGitRepo(cwd)
platform := runtime.GOOS
date := time.Now().Format("1/2/2006")
@@ -72,8 +72,7 @@ func isGitRepo(dir string) bool {
return err == nil
}
-func lspInformation() string {
- cfg := config.Get()
+func lspInformation(cfg *config.Config) string {
hasLSP := false
for _, v := range cfg.LSP {
if !v.Disabled {
@@ -22,15 +22,15 @@ const (
PromptDefault PromptID = "default"
)
-func GetPrompt(promptID PromptID, provider string, contextPaths ...string) string {
+func GetPrompt(cfg *config.Config, promptID PromptID, provider string, contextPaths ...string) string {
basePrompt := ""
switch promptID {
case PromptCoder:
- basePrompt = CoderPrompt(provider, contextPaths...)
+ basePrompt = CoderPrompt(cfg, provider, contextPaths...)
case PromptTitle:
basePrompt = TitlePrompt()
case PromptTask:
- basePrompt = TaskPrompt()
+ basePrompt = TaskPrompt(cfg)
case PromptSummarizer:
basePrompt = SummarizerPrompt()
default:
@@ -2,14 +2,16 @@ package prompt
import (
"fmt"
+
+ "github.com/charmbracelet/crush/internal/config"
)
-func TaskPrompt() string {
+func TaskPrompt(cfg *config.Config) string {
agentPrompt := `You are an agent for Crush. Given the user's prompt, you should use the tools available to you to answer the user's question.
Notes:
1. IMPORTANT: You should be concise, direct, and to the point, since your responses will be displayed on a command line interface. Answer the user's question directly, without elaboration, explanation, or details. One word answers are best. Avoid introductions, conclusions, and explanations. You MUST avoid text before/after your response, such as "The answer is <answer>.", "Here is the content of the file..." or "Based on the information provided, the answer is..." or "Here is what I will do next...".
2. When relevant, share file names and code snippets relevant to the query
3. Any file paths you return in your final response MUST be absolute. DO NOT use relative paths.`
- return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo())
+ return fmt.Sprintf("%s\n%s\n", agentPrompt, getEnvironmentInfo(cfg))
}
@@ -80,13 +80,13 @@ func createAnthropicClient(opts providerClientOptions, tp AnthropicClientType) a
}
if opts.baseURL != "" {
- resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
+ resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
if err == nil && resolvedBaseURL != "" {
anthropicClientOptions = append(anthropicClientOptions, option.WithBaseURL(resolvedBaseURL))
}
}
- if config.Get().Options.Debug {
+ if opts.cfg.Options.Debug {
httpClient := log.NewHTTPClient()
anthropicClientOptions = append(anthropicClientOptions, option.WithHTTPClient(httpClient))
}
@@ -223,7 +223,7 @@ func (a *anthropicClient) finishReason(reason string) message.FinishReason {
}
func (a *anthropicClient) isThinkingEnabled() bool {
- cfg := config.Get()
+ cfg := a.providerOptions.cfg
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if a.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
@@ -234,7 +234,7 @@ func (a *anthropicClient) isThinkingEnabled() bool {
func (a *anthropicClient) preparedMessages(messages []anthropic.MessageParam, tools []anthropic.ToolUnionParam) anthropic.MessageNewParams {
model := a.providerOptions.model(a.providerOptions.modelType)
var thinkingParam anthropic.ThinkingConfigParamUnion
- cfg := config.Get()
+ cfg := a.providerOptions.cfg
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if a.providerOptions.modelType == config.SelectedModelTypeSmall {
modelConfig = cfg.Models[config.SelectedModelTypeSmall]
@@ -493,7 +493,7 @@ func (a *anthropicClient) shouldRetry(attempts int, err error) (bool, int64, err
}
if apiErr.StatusCode == 401 {
- a.providerOptions.apiKey, err = config.Get().Resolve(a.providerOptions.config.APIKey)
+ a.providerOptions.apiKey, err = a.providerOptions.cfg.Resolve(a.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -1,7 +1,6 @@
package provider
import (
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/log"
"github.com/openai/openai-go"
"github.com/openai/openai-go/azure"
@@ -24,7 +23,7 @@ func newAzureClient(opts providerClientOptions) AzureClient {
azure.WithEndpoint(opts.baseURL, apiVersion),
}
- if config.Get().Options.Debug {
+ if opts.cfg.Options.Debug {
httpClient := log.NewHTTPClient()
reqOpts = append(reqOpts, option.WithHTTPClient(httpClient))
}
@@ -33,7 +33,7 @@ func newBedrockClient(opts providerClientOptions) BedrockClient {
}
opts.model = func(modelType config.SelectedModelType) catwalk.Model {
- model := config.Get().GetModelByType(modelType)
+ model := opts.cfg.GetModelByType(modelType)
// Prefix the model name with region
regionPrefix := region[:2]
@@ -44,7 +44,7 @@ func createGeminiClient(opts providerClientOptions) (*genai.Client, error) {
APIKey: opts.apiKey,
Backend: genai.BackendGeminiAPI,
}
- if config.Get().Options.Debug {
+ if opts.cfg.Options.Debug {
cc.HTTPClient = log.NewHTTPClient()
}
client, err := genai.NewClient(context.Background(), cc)
@@ -178,7 +178,7 @@ func (g *geminiClient) send(ctx context.Context, messages []message.Message, too
// Convert messages
geminiMessages := g.convertMessages(messages)
model := g.providerOptions.model(g.providerOptions.modelType)
- cfg := config.Get()
+ cfg := g.providerOptions.cfg
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
@@ -274,7 +274,7 @@ func (g *geminiClient) stream(ctx context.Context, messages []message.Message, t
geminiMessages := g.convertMessages(messages)
model := g.providerOptions.model(g.providerOptions.modelType)
- cfg := config.Get()
+ cfg := g.providerOptions.cfg
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if g.providerOptions.modelType == config.SelectedModelTypeSmall {
@@ -433,7 +433,7 @@ func (g *geminiClient) shouldRetry(attempts int, err error) (bool, int64, error)
// Check for token expiration (401 Unauthorized)
if contains(errMsg, "unauthorized", "invalid api key", "api key expired") {
- g.providerOptions.apiKey, err = config.Get().Resolve(g.providerOptions.config.APIKey)
+ g.providerOptions.apiKey, err = g.providerOptions.cfg.Resolve(g.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -42,13 +42,13 @@ func createOpenAIClient(opts providerClientOptions) openai.Client {
openaiClientOptions = append(openaiClientOptions, option.WithAPIKey(opts.apiKey))
}
if opts.baseURL != "" {
- resolvedBaseURL, err := config.Get().Resolve(opts.baseURL)
+ resolvedBaseURL, err := opts.cfg.Resolve(opts.baseURL)
if err == nil && resolvedBaseURL != "" {
openaiClientOptions = append(openaiClientOptions, option.WithBaseURL(resolvedBaseURL))
}
}
- if config.Get().Options.Debug {
+ if opts.cfg.Options.Debug {
httpClient := log.NewHTTPClient()
openaiClientOptions = append(openaiClientOptions, option.WithHTTPClient(httpClient))
}
@@ -217,7 +217,7 @@ func (o *openaiClient) finishReason(reason string) message.FinishReason {
func (o *openaiClient) preparedParams(messages []openai.ChatCompletionMessageParamUnion, tools []openai.ChatCompletionToolParam) openai.ChatCompletionNewParams {
model := o.providerOptions.model(o.providerOptions.modelType)
- cfg := config.Get()
+ cfg := o.providerOptions.cfg
modelConfig := cfg.Models[config.SelectedModelTypeLarge]
if o.providerOptions.modelType == config.SelectedModelTypeSmall {
@@ -514,7 +514,7 @@ func (o *openaiClient) shouldRetry(attempts int, err error) (bool, int64, error)
if errors.As(err, &apiErr) {
// Check for token expiration (401 Unauthorized)
if apiErr.StatusCode == 401 {
- o.providerOptions.apiKey, err = config.Get().Resolve(o.providerOptions.config.APIKey)
+ o.providerOptions.apiKey, err = o.providerOptions.cfg.Resolve(o.providerOptions.config.APIKey)
if err != nil {
return false, 0, fmt.Errorf("failed to resolve API key: %w", err)
}
@@ -62,6 +62,8 @@ type Provider interface {
}
type providerClientOptions struct {
+ cfg *config.Config
+
baseURL string
config config.ProviderConfig
apiKey string
@@ -139,40 +141,41 @@ func WithMaxTokens(maxTokens int64) ProviderClientOption {
}
}
-func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
+func NewProvider(cfg *config.Config, pcfg config.ProviderConfig, opts ...ProviderClientOption) (Provider, error) {
restore := config.PushPopCrushEnv()
defer restore()
- resolvedAPIKey, err := config.Get().Resolve(cfg.APIKey)
+ resolvedAPIKey, err := cfg.Resolve(pcfg.APIKey)
if err != nil {
- return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", cfg.ID, err)
+ return nil, fmt.Errorf("failed to resolve API key for provider %s: %w", pcfg.ID, err)
}
// Resolve extra headers
resolvedExtraHeaders := make(map[string]string)
- for key, value := range cfg.ExtraHeaders {
- resolvedValue, err := config.Get().Resolve(value)
+ for key, value := range pcfg.ExtraHeaders {
+ resolvedValue, err := cfg.Resolve(value)
if err != nil {
- return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, cfg.ID, err)
+ return nil, fmt.Errorf("failed to resolve extra header %s for provider %s: %w", key, pcfg.ID, err)
}
resolvedExtraHeaders[key] = resolvedValue
}
clientOptions := providerClientOptions{
- baseURL: cfg.BaseURL,
- config: cfg,
+ cfg: cfg,
+ baseURL: pcfg.BaseURL,
+ config: pcfg,
apiKey: resolvedAPIKey,
extraHeaders: resolvedExtraHeaders,
- extraBody: cfg.ExtraBody,
- extraParams: cfg.ExtraParams,
- systemPromptPrefix: cfg.SystemPromptPrefix,
+ extraBody: pcfg.ExtraBody,
+ extraParams: pcfg.ExtraParams,
+ systemPromptPrefix: pcfg.SystemPromptPrefix,
model: func(tp config.SelectedModelType) catwalk.Model {
- return *config.Get().GetModelByType(tp)
+ return *cfg.GetModelByType(tp)
},
}
for _, o := range opts {
o(&clientOptions)
}
- switch cfg.Type {
+ switch pcfg.Type {
case catwalk.TypeAnthropic:
return &baseProvider[AnthropicClient]{
options: clientOptions,
@@ -204,5 +207,5 @@ func NewProvider(cfg config.ProviderConfig, opts ...ProviderClientOption) (Provi
client: newVertexAIClient(clientOptions),
}, nil
}
- return nil, fmt.Errorf("provider not supported: %s", cfg.Type)
+ return nil, fmt.Errorf("provider not supported: %s", pcfg.Type)
}
@@ -5,7 +5,6 @@ import (
"log/slog"
"strings"
- "github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/log"
"google.golang.org/genai"
)
@@ -20,7 +19,7 @@ func newVertexAIClient(opts providerClientOptions) VertexAIClient {
Location: location,
Backend: genai.BackendVertexAI,
}
- if config.Get().Options.Debug {
+ if opts.cfg.Options.Debug {
cc.HTTPClient = log.NewHTTPClient()
}
client, err := genai.NewClient(context.Background(), cc)
@@ -23,6 +23,7 @@ import (
type Client struct {
client *powernap.Client
+ cfg *config.Config
name string
// File types this LSP server handles (e.g., .go, .rs, .py)
@@ -45,7 +46,7 @@ type Client struct {
}
// New creates a new LSP client using the powernap implementation.
-func New(ctx context.Context, name string, config config.LSPConfig) (*Client, error) {
+func New(ctx context.Context, cfg *config.Config, name string, lspCfg config.LSPConfig) (*Client, error) {
// Convert working directory to file URI
workDir, err := os.Getwd()
if err != nil {
@@ -56,16 +57,16 @@ func New(ctx context.Context, name string, config config.LSPConfig) (*Client, er
// Create powernap client config
clientConfig := powernap.ClientConfig{
- Command: home.Long(config.Command),
- Args: config.Args,
+ Command: home.Long(lspCfg.Command),
+ Args: lspCfg.Args,
RootURI: rootURI,
Environment: func() map[string]string {
env := make(map[string]string)
- maps.Copy(env, config.Env)
+ maps.Copy(env, lspCfg.Env)
return env
}(),
- Settings: config.Options,
- InitOptions: config.InitOptions,
+ Settings: lspCfg.Options,
+ InitOptions: lspCfg.InitOptions,
WorkspaceFolders: []protocol.WorkspaceFolder{
{
URI: rootURI,
@@ -81,12 +82,13 @@ func New(ctx context.Context, name string, config config.LSPConfig) (*Client, er
}
client := &Client{
+ cfg: cfg,
client: powernapClient,
name: name,
- fileTypes: config.FileTypes,
+ fileTypes: lspCfg.FileTypes,
diagnostics: csync.NewVersionedMap[protocol.DocumentURI, []protocol.Diagnostic](),
openFiles: csync.NewMap[string, *OpenFileInfo](),
- config: config,
+ config: lspCfg,
}
// Initialize server state
@@ -214,8 +216,6 @@ func (c *Client) SetDiagnosticsCallback(callback func(name string, count int)) {
// WaitForServerReady waits for the server to be ready
func (c *Client) WaitForServerReady(ctx context.Context) error {
- cfg := config.Get()
-
// Set initial state
c.SetServerState(StateStarting)
@@ -227,7 +227,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
ticker := time.NewTicker(500 * time.Millisecond)
defer ticker.Stop()
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.cfg != nil && c.cfg.Options.DebugLSP {
slog.Debug("Waiting for LSP server to be ready...")
}
@@ -241,7 +241,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
case <-ticker.C:
// Check if client is running
if !c.client.IsRunning() {
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.cfg != nil && c.cfg.Options.DebugLSP {
slog.Debug("LSP server not ready yet", "server", c.name)
}
continue
@@ -249,7 +249,7 @@ func (c *Client) WaitForServerReady(ctx context.Context) error {
// Server is ready
c.SetServerState(StateReady)
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.cfg != nil && c.cfg.Options.DebugLSP {
slog.Debug("LSP server is ready")
}
return nil
@@ -349,14 +349,13 @@ func (c *Client) NotifyChange(ctx context.Context, filepath string) error {
//
// NOTE: this is only ever called on LSP shutdown.
func (c *Client) CloseFile(ctx context.Context, filepath string) error {
- cfg := config.Get()
uri := string(protocol.URIFromPath(filepath))
if _, exists := c.openFiles.Get(uri); !exists {
return nil // Already closed
}
- if cfg.Options.DebugLSP {
+ if c.cfg.Options.DebugLSP {
slog.Debug("Closing file", "file", filepath)
}
@@ -378,7 +377,6 @@ func (c *Client) IsFileOpen(filepath string) bool {
// CloseAllFiles closes all currently open files.
func (c *Client) CloseAllFiles(ctx context.Context) {
- cfg := config.Get()
filesToClose := make([]string, 0, c.openFiles.Len())
// First collect all URIs that need to be closed
@@ -395,12 +393,12 @@ func (c *Client) CloseAllFiles(ctx context.Context) {
// Then close them all
for _, filePath := range filesToClose {
err := c.CloseFile(ctx, filePath)
- if err != nil && cfg != nil && cfg.Options.DebugLSP {
+ if err != nil && c.cfg != nil && c.cfg.Options.DebugLSP {
slog.Warn("Error closing file", "file", filePath, "error", err)
}
}
- if cfg != nil && cfg.Options.DebugLSP {
+ if c.cfg != nil && c.cfg.Options.DebugLSP {
slog.Debug("Closed all files", "files", filesToClose)
}
}
@@ -11,7 +11,8 @@ func TestPowernapClient(t *testing.T) {
ctx := context.Background()
// Create a simple config for testing
- cfg := config.LSPConfig{
+ var cfg config.Config
+ lspCfg := config.LSPConfig{
Command: "echo", // Use echo as a dummy command that won't fail
Args: []string{"hello"},
FileTypes: []string{"go"},
@@ -20,7 +21,7 @@ func TestPowernapClient(t *testing.T) {
// Test creating a powernap client - this will likely fail with echo
// but we can still test the basic structure
- client, err := New(ctx, "test", cfg)
+ client, err := New(ctx, &cfg, "test", lspCfg)
if err != nil {
// Expected to fail with echo command, skip the rest
t.Skipf("Powernap client creation failed as expected with dummy command: %v", err)
@@ -79,9 +79,9 @@ func notifyFileWatchRegistration(id string, watchers []protocol.FileSystemWatche
}
// HandleServerMessage handles server messages
-func HandleServerMessage(_ context.Context, method string, params json.RawMessage) {
- cfg := config.Get()
- if !cfg.Options.DebugLSP {
+func HandleServerMessage(ctx context.Context, method string, params json.RawMessage) {
+ cfg, ok := config.FromContext(ctx)
+ if !ok || !cfg.Options.DebugLSP {
return
}
@@ -1,47 +0,0 @@
-package message
-
-import (
- "encoding/base64"
- "encoding/json"
-)
-
-type Attachment struct {
- FilePath string `json:"file_path"`
- FileName string `json:"file_name"`
- MimeType string `json:"mime_type"`
- Content []byte `json:"content"`
-}
-
-// MarshalJSON implements the [json.Marshaler] interface.
-func (a Attachment) MarshalJSON() ([]byte, error) {
- // Encode the content as a base64 string
- type Alias Attachment
- return json.Marshal(&struct {
- Content string `json:"content"`
- *Alias
- }{
- Content: base64.StdEncoding.EncodeToString(a.Content),
- Alias: (*Alias)(&a),
- })
-}
-
-// UnmarshalJSON implements the [json.Unmarshaler] interface.
-func (a *Attachment) UnmarshalJSON(data []byte) error {
- // Decode the content from a base64 string
- type Alias Attachment
- aux := &struct {
- Content string `json:"content"`
- *Alias
- }{
- Alias: (*Alias)(a),
- }
- if err := json.Unmarshal(data, &aux); err != nil {
- return err
- }
- content, err := base64.StdEncoding.DecodeString(aux.Content)
- if err != nil {
- return err
- }
- a.Content = content
- return nil
-}
@@ -3,21 +3,42 @@ package message
import (
"context"
"database/sql"
- "encoding/json"
- "fmt"
"time"
"github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
-type CreateMessageParams struct {
- Role MessageRole `json:"role"`
- Parts []ContentPart `json:"parts"`
- Model string `json:"model"`
- Provider string `json:"provider,omitempty"`
-}
+type (
+ CreateMessageParams = proto.CreateMessageParams
+ Message = proto.Message
+ Attachment = proto.Attachment
+ ToolCall = proto.ToolCall
+ ToolResult = proto.ToolResult
+ ContentPart = proto.ContentPart
+ TextContent = proto.TextContent
+ BinaryContent = proto.BinaryContent
+ FinishReason = proto.FinishReason
+ Finish = proto.Finish
+)
+
+const (
+ Assistant = proto.Assistant
+ User = proto.User
+ System = proto.System
+ Tool = proto.Tool
+
+ FinishReasonEndTurn = proto.FinishReasonEndTurn
+ FinishReasonMaxTokens = proto.FinishReasonMaxTokens
+ FinishReasonToolUse = proto.FinishReasonToolUse
+ FinishReasonCanceled = proto.FinishReasonCanceled
+ FinishReasonError = proto.FinishReasonError
+ FinishReasonPermissionDenied = proto.FinishReasonPermissionDenied
+
+ FinishReasonUnknown = proto.FinishReasonUnknown
+)
type Service interface {
pubsub.Suscriber[Message]
@@ -55,12 +76,12 @@ func (s *service) Delete(ctx context.Context, id string) error {
}
func (s *service) Create(ctx context.Context, sessionID string, params CreateMessageParams) (Message, error) {
- if params.Role != Assistant {
- params.Parts = append(params.Parts, Finish{
+ if params.Role != proto.Assistant {
+ params.Parts = append(params.Parts, proto.Finish{
Reason: "stop",
})
}
- partsJSON, err := marshallParts(params.Parts)
+ partsJSON, err := proto.MarshallParts(params.Parts)
if err != nil {
return Message{}, err
}
@@ -100,7 +121,7 @@ func (s *service) DeleteSessionMessages(ctx context.Context, sessionID string) e
}
func (s *service) Update(ctx context.Context, message Message) error {
- parts, err := marshallParts(message.Parts)
+ parts, err := proto.MarshallParts(message.Parts)
if err != nil {
return err
}
@@ -146,14 +167,14 @@ func (s *service) List(ctx context.Context, sessionID string) ([]Message, error)
}
func (s *service) fromDBItem(item db.Message) (Message, error) {
- parts, err := unmarshallParts([]byte(item.Parts))
+ parts, err := proto.UnmarshallParts([]byte(item.Parts))
if err != nil {
return Message{}, err
}
return Message{
ID: item.ID,
SessionID: item.SessionID,
- Role: MessageRole(item.Role),
+ Role: proto.MessageRole(item.Role),
Parts: parts,
Model: item.Model.String,
Provider: item.Provider.String,
@@ -161,122 +182,3 @@ func (s *service) fromDBItem(item db.Message) (Message, error) {
UpdatedAt: item.UpdatedAt,
}, nil
}
-
-type partType string
-
-const (
- reasoningType partType = "reasoning"
- textType partType = "text"
- imageURLType partType = "image_url"
- binaryType partType = "binary"
- toolCallType partType = "tool_call"
- toolResultType partType = "tool_result"
- finishType partType = "finish"
-)
-
-type partWrapper struct {
- Type partType `json:"type"`
- Data ContentPart `json:"data"`
-}
-
-func marshallParts(parts []ContentPart) ([]byte, error) {
- wrappedParts := make([]partWrapper, len(parts))
-
- for i, part := range parts {
- var typ partType
-
- switch part.(type) {
- case ReasoningContent:
- typ = reasoningType
- case TextContent:
- typ = textType
- case ImageURLContent:
- typ = imageURLType
- case BinaryContent:
- typ = binaryType
- case ToolCall:
- typ = toolCallType
- case ToolResult:
- typ = toolResultType
- case Finish:
- typ = finishType
- default:
- return nil, fmt.Errorf("unknown part type: %T", part)
- }
-
- wrappedParts[i] = partWrapper{
- Type: typ,
- Data: part,
- }
- }
- return json.Marshal(wrappedParts)
-}
-
-func unmarshallParts(data []byte) ([]ContentPart, error) {
- temp := []json.RawMessage{}
-
- if err := json.Unmarshal(data, &temp); err != nil {
- return nil, err
- }
-
- parts := make([]ContentPart, 0)
-
- for _, rawPart := range temp {
- var wrapper struct {
- Type partType `json:"type"`
- Data json.RawMessage `json:"data"`
- }
-
- if err := json.Unmarshal(rawPart, &wrapper); err != nil {
- return nil, err
- }
-
- switch wrapper.Type {
- case reasoningType:
- part := ReasoningContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case textType:
- part := TextContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case imageURLType:
- part := ImageURLContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- case binaryType:
- part := BinaryContent{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case toolCallType:
- part := ToolCall{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case toolResultType:
- part := ToolResult{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- case finishType:
- part := Finish{}
- if err := json.Unmarshal(wrapper.Data, &part); err != nil {
- return nil, err
- }
- parts = append(parts, part)
- default:
- return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
- }
- }
-
- return parts, nil
-}
@@ -9,38 +9,18 @@ import (
"sync"
"github.com/charmbracelet/crush/internal/csync"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
var ErrorPermissionDenied = errors.New("permission denied")
-type CreatePermissionRequest struct {
- SessionID string `json:"session_id"`
- ToolCallID string `json:"tool_call_id"`
- ToolName string `json:"tool_name"`
- Description string `json:"description"`
- Action string `json:"action"`
- Params any `json:"params"`
- Path string `json:"path"`
-}
-
-type PermissionNotification struct {
- ToolCallID string `json:"tool_call_id"`
- Granted bool `json:"granted"`
- Denied bool `json:"denied"`
-}
-
-type PermissionRequest struct {
- ID string `json:"id"`
- SessionID string `json:"session_id"`
- ToolCallID string `json:"tool_call_id"`
- ToolName string `json:"tool_name"`
- Description string `json:"description"`
- Action string `json:"action"`
- Params any `json:"params"`
- Path string `json:"path"`
-}
+type (
+ PermissionRequest = proto.PermissionRequest
+ PermissionNotification = proto.PermissionNotification
+ CreatePermissionRequest = proto.CreatePermissionRequest
+)
type Service interface {
pubsub.Suscriber[PermissionRequest]
@@ -0,0 +1,70 @@
+package proto
+
+import (
+ "encoding/json"
+ "errors"
+)
+
+type AgentEventType string
+
+const (
+ AgentEventTypeError AgentEventType = "error"
+ AgentEventTypeResponse AgentEventType = "response"
+ AgentEventTypeSummarize AgentEventType = "summarize"
+)
+
+func (t AgentEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+func (t *AgentEventType) UnmarshalText(text []byte) error {
+ *t = AgentEventType(text)
+ return nil
+}
+
+type AgentEvent struct {
+ Type AgentEventType `json:"type"`
+ Message Message `json:"message"`
+ Error error `json:"error,omitempty"`
+
+ // When summarizing
+ SessionID string `json:"session_id,omitempty"`
+ Progress string `json:"progress,omitempty"`
+ Done bool `json:"done,omitempty"`
+}
+
+// MarshalJSON implements the [json.Marshaler] interface.
+func (e AgentEvent) MarshalJSON() ([]byte, error) {
+ type Alias AgentEvent
+ return json.Marshal(&struct {
+ Error string `json:"error,omitempty"`
+ Alias
+ }{
+ Error: func() string {
+ if e.Error != nil {
+ return e.Error.Error()
+ }
+ return ""
+ }(),
+ Alias: (Alias)(e),
+ })
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+func (e *AgentEvent) UnmarshalJSON(data []byte) error {
+ type Alias AgentEvent
+ aux := &struct {
+ Error string `json:"error,omitempty"`
+ Alias
+ }{
+ Alias: (Alias)(*e),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ *e = AgentEvent(aux.Alias)
+ if aux.Error != "" {
+ e.Error = errors.New(aux.Error)
+ }
+ return nil
+}
@@ -0,0 +1,11 @@
+package proto
+
+type File struct {
+ ID string `json:"id"`
+ SessionID string `json:"session_id"`
+ Path string `json:"path"`
+ Content string `json:"content"`
+ Version int64 `json:"version"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
@@ -0,0 +1,73 @@
+package proto
+
+import "fmt"
+
+// MCPState represents the current state of an MCP client
+type MCPState int
+
+const (
+ MCPStateDisabled MCPState = iota
+ MCPStateStarting
+ MCPStateConnected
+ MCPStateError
+)
+
+func (s MCPState) MarshalText() ([]byte, error) {
+ return []byte(s.String()), nil
+}
+
+func (s *MCPState) UnmarshalText(data []byte) error {
+ switch string(data) {
+ case "disabled":
+ *s = MCPStateDisabled
+ case "starting":
+ *s = MCPStateStarting
+ case "connected":
+ *s = MCPStateConnected
+ case "error":
+ *s = MCPStateError
+ default:
+ return fmt.Errorf("unknown mcp state: %s", data)
+ }
+ return nil
+}
+
+func (s MCPState) String() string {
+ switch s {
+ case MCPStateDisabled:
+ return "disabled"
+ case MCPStateStarting:
+ return "starting"
+ case MCPStateConnected:
+ return "connected"
+ case MCPStateError:
+ return "error"
+ default:
+ return "unknown"
+ }
+}
+
+// MCPEventType represents the type of MCP event
+type MCPEventType string
+
+const (
+ MCPEventStateChanged MCPEventType = "state_changed"
+)
+
+func (t MCPEventType) MarshalText() ([]byte, error) {
+ return []byte(t), nil
+}
+
+func (t *MCPEventType) UnmarshalText(data []byte) error {
+ *t = MCPEventType(data)
+ return nil
+}
+
+// MCPEvent represents an event in the MCP system
+type MCPEvent struct {
+ Type MCPEventType `json:"type"`
+ Name string `json:"name"`
+ State MCPState `json:"state"`
+ Error error `json:"error,omitempty"`
+ ToolCount int `json:"tool_count,omitempty"`
+}
@@ -1,14 +1,32 @@
-package message
+package proto
import (
"encoding/base64"
"encoding/json"
+ "fmt"
"slices"
"time"
"github.com/charmbracelet/catwalk/pkg/catwalk"
)
+type CreateMessageParams struct {
+ Role MessageRole `json:"role"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider,omitempty"`
+}
+
+type Message struct {
+ ID string `json:"id"`
+ Role MessageRole `json:"role"`
+ SessionID string `json:"session_id"`
+ Parts []ContentPart `json:"parts"`
+ Model string `json:"model"`
+ Provider string `json:"provider"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
type MessageRole string
const (
@@ -132,22 +150,11 @@ type Finish struct {
func (Finish) isPart() {}
-type Message struct {
- ID string `json:"id"`
- Role MessageRole `json:"role"`
- SessionID string `json:"session_id"`
- Parts []ContentPart `json:"parts"`
- Model string `json:"model"`
- Provider string `json:"provider"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
-}
-
// MarshalJSON implements the [json.Marshaler] interface.
func (m Message) MarshalJSON() ([]byte, error) {
// We need to handle the Parts specially since they're ContentPart interfaces
// which can't be directly marshaled by the standard JSON package.
- parts, err := marshallParts(m.Parts)
+ parts, err := MarshallParts(m.Parts)
if err != nil {
return nil, err
}
@@ -179,7 +186,7 @@ func (m *Message) UnmarshalJSON(data []byte) error {
}
// Unmarshal the parts using our custom function
- parts, err := unmarshallParts([]byte(aux.Parts))
+ parts, err := UnmarshallParts([]byte(aux.Parts))
if err != nil {
return err
}
@@ -448,3 +455,163 @@ func (m *Message) AddImageURL(url, detail string) {
func (m *Message) AddBinary(mimeType string, data []byte) {
m.Parts = append(m.Parts, BinaryContent{MIMEType: mimeType, Data: data})
}
+
+type partType string
+
+const (
+ reasoningType partType = "reasoning"
+ textType partType = "text"
+ imageURLType partType = "image_url"
+ binaryType partType = "binary"
+ toolCallType partType = "tool_call"
+ toolResultType partType = "tool_result"
+ finishType partType = "finish"
+)
+
+type partWrapper struct {
+ Type partType `json:"type"`
+ Data ContentPart `json:"data"`
+}
+
+func MarshallParts(parts []ContentPart) ([]byte, error) {
+ wrappedParts := make([]partWrapper, len(parts))
+
+ for i, part := range parts {
+ var typ partType
+
+ switch part.(type) {
+ case ReasoningContent:
+ typ = reasoningType
+ case TextContent:
+ typ = textType
+ case ImageURLContent:
+ typ = imageURLType
+ case BinaryContent:
+ typ = binaryType
+ case ToolCall:
+ typ = toolCallType
+ case ToolResult:
+ typ = toolResultType
+ case Finish:
+ typ = finishType
+ default:
+ return nil, fmt.Errorf("unknown part type: %T", part)
+ }
+
+ wrappedParts[i] = partWrapper{
+ Type: typ,
+ Data: part,
+ }
+ }
+ return json.Marshal(wrappedParts)
+}
+
+func UnmarshallParts(data []byte) ([]ContentPart, error) {
+ temp := []json.RawMessage{}
+
+ if err := json.Unmarshal(data, &temp); err != nil {
+ return nil, err
+ }
+
+ parts := make([]ContentPart, 0)
+
+ for _, rawPart := range temp {
+ var wrapper struct {
+ Type partType `json:"type"`
+ Data json.RawMessage `json:"data"`
+ }
+
+ if err := json.Unmarshal(rawPart, &wrapper); err != nil {
+ return nil, err
+ }
+
+ switch wrapper.Type {
+ case reasoningType:
+ part := ReasoningContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case textType:
+ part := TextContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case imageURLType:
+ part := ImageURLContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ case binaryType:
+ part := BinaryContent{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case toolCallType:
+ part := ToolCall{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case toolResultType:
+ part := ToolResult{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ case finishType:
+ part := Finish{}
+ if err := json.Unmarshal(wrapper.Data, &part); err != nil {
+ return nil, err
+ }
+ parts = append(parts, part)
+ default:
+ return nil, fmt.Errorf("unknown part type: %s", wrapper.Type)
+ }
+ }
+
+ return parts, nil
+}
+
+type Attachment struct {
+ FilePath string `json:"file_path"`
+ FileName string `json:"file_name"`
+ MimeType string `json:"mime_type"`
+ Content []byte `json:"content"`
+}
+
+// MarshalJSON implements the [json.Marshaler] interface.
+func (a Attachment) MarshalJSON() ([]byte, error) {
+ // Encode the content as a base64 string
+ type Alias Attachment
+ return json.Marshal(&struct {
+ Content string `json:"content"`
+ *Alias
+ }{
+ Content: base64.StdEncoding.EncodeToString(a.Content),
+ Alias: (*Alias)(&a),
+ })
+}
+
+// UnmarshalJSON implements the [json.Unmarshaler] interface.
+func (a *Attachment) UnmarshalJSON(data []byte) error {
+ // Decode the content from a base64 string
+ type Alias Attachment
+ aux := &struct {
+ Content string `json:"content"`
+ *Alias
+ }{
+ Alias: (*Alias)(a),
+ }
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+ content, err := base64.StdEncoding.DecodeString(aux.Content)
+ if err != nil {
+ return err
+ }
+ a.Content = content
+ return nil
+}
@@ -0,0 +1,28 @@
+package proto
+
+type CreatePermissionRequest struct {
+ SessionID string `json:"session_id"`
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Description string `json:"description"`
+ Action string `json:"action"`
+ Params any `json:"params"`
+ Path string `json:"path"`
+}
+
+type PermissionNotification struct {
+ ToolCallID string `json:"tool_call_id"`
+ Granted bool `json:"granted"`
+ Denied bool `json:"denied"`
+}
+
+type PermissionRequest struct {
+ ID string `json:"id"`
+ SessionID string `json:"session_id"`
+ ToolCallID string `json:"tool_call_id"`
+ ToolName string `json:"tool_name"`
+ Description string `json:"description"`
+ Action string `json:"action"`
+ Params any `json:"params"`
+ Path string `json:"path"`
+}
@@ -1,8 +1,10 @@
package proto
import (
+ "time"
+
"github.com/charmbracelet/catwalk/pkg/catwalk"
- "github.com/charmbracelet/crush/internal/message"
+ "github.com/charmbracelet/crush/internal/lsp"
)
// Instance represents a running app.App instance with its associated resources
@@ -33,7 +35,84 @@ func (a AgentInfo) IsZero() bool {
// AgentMessage represents a message sent to the agent.
type AgentMessage struct {
- SessionID string `json:"session_id"`
- Prompt string `json:"prompt"`
- Attachments []message.Attachment `json:"attachments,omitempty"`
+ SessionID string `json:"session_id"`
+ Prompt string `json:"prompt"`
+ Attachments []Attachment `json:"attachments,omitempty"`
+}
+
+// AgentSession represents a session with its busy status.
+type AgentSession struct {
+ Session
+ IsBusy bool `json:"is_busy"`
+}
+
+// IsZero checks if the AgentSession is zero-valued.
+func (a AgentSession) IsZero() bool {
+ return a == AgentSession{}
+}
+
+type PermissionAction string
+
+// Permission responses
+const (
+ PermissionAllow PermissionAction = "allow"
+ PermissionAllowForSession PermissionAction = "allow_session"
+ PermissionDeny PermissionAction = "deny"
+)
+
+// MarshalText implements the [encoding.TextMarshaler] interface.
+func (p PermissionAction) MarshalText() ([]byte, error) {
+ return []byte(p), nil
+}
+
+// UnmarshalText implements the [encoding.TextUnmarshaler] interface.
+func (p *PermissionAction) UnmarshalText(text []byte) error {
+ *p = PermissionAction(text)
+ return nil
+}
+
+// PermissionGrant represents a permission grant request.
+type PermissionGrant struct {
+ Permission PermissionRequest `json:"permission"`
+ Action PermissionAction `json:"action"`
+}
+
+// PermissionSkipRequest represents a request to skip permission prompts.
+type PermissionSkipRequest struct {
+ Skip bool `json:"skip"`
+}
+
+// LSPEventType represents the type of LSP event
+type LSPEventType string
+
+const (
+ LSPEventStateChanged LSPEventType = "state_changed"
+ LSPEventDiagnosticsChanged LSPEventType = "diagnostics_changed"
+)
+
+func (e LSPEventType) MarshalText() ([]byte, error) {
+ return []byte(e), nil
+}
+
+func (e *LSPEventType) UnmarshalText(data []byte) error {
+ *e = LSPEventType(data)
+ return nil
+}
+
+// LSPEvent represents an event in the LSP system
+type LSPEvent struct {
+ Type LSPEventType `json:"type"`
+ Name string `json:"name"`
+ State lsp.ServerState `json:"state"`
+ Error error `json:"error,omitempty"`
+ DiagnosticCount int `json:"diagnostic_count,omitempty"`
+}
+
+// LSPClientInfo holds information about an LSP client's state
+type LSPClientInfo struct {
+ Name string `json:"name"`
+ State lsp.ServerState `json:"state"`
+ Error error `json:"error,omitempty"`
+ DiagnosticCount int `json:"diagnostic_count,omitempty"`
+ ConnectedAt time.Time `json:"connected_at"`
}
@@ -0,0 +1,14 @@
+package proto
+
+type Session struct {
+ ID string `json:"id"`
+ ParentSessionID string `json:"parent_session_id"`
+ Title string `json:"title"`
+ MessageCount int64 `json:"message_count"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ SummaryMessageID string `json:"summary_message_id"`
+ Cost float64 `json:"cost"`
+ CreatedAt int64 `json:"created_at"`
+ UpdatedAt int64 `json:"updated_at"`
+}
@@ -2,6 +2,7 @@ package pubsub
import (
"context"
+ "log/slog"
"sync"
)
@@ -113,6 +114,7 @@ func (b *Broker[T]) Publish(t EventType, payload T) {
default:
// Channel is full, subscriber is slow - skip this event
// This prevents blocking the publisher
+ slog.Warn("Skipping event for slow subscriber", "event_type", t)
}
}
}
@@ -1,6 +1,13 @@
package pubsub
-import "context"
+import (
+ "context"
+ "encoding/json"
+ "fmt"
+ "log/slog"
+
+ "github.com/charmbracelet/crush/internal/proto"
+)
const (
CreatedEvent EventType = "created"
@@ -13,6 +20,13 @@ type Suscriber[T any] interface {
}
type (
+ PayloadType = string
+
+ Payload struct {
+ Type PayloadType `json:"type"`
+ Payload json.RawMessage `json:"payload"`
+ }
+
// EventType identifies the type of event
EventType string
@@ -27,6 +41,17 @@ type (
}
)
+const (
+ PayloadTypeLSPEvent PayloadType = "lsp_event"
+ PayloadTypeMCPEvent PayloadType = "mcp_event"
+ PayloadTypePermissionRequest PayloadType = "permission_request"
+ PayloadTypePermissionNotification PayloadType = "permission_notification"
+ PayloadTypeMessage PayloadType = "message"
+ PayloadTypeSession PayloadType = "session"
+ PayloadTypeFile PayloadType = "file"
+ PayloadTypeAgentEvent PayloadType = "agent_event"
+)
+
func (t EventType) MarshalText() ([]byte, error) {
return []byte(t), nil
}
@@ -35,3 +60,148 @@ func (t *EventType) UnmarshalText(data []byte) error {
*t = EventType(data)
return nil
}
+
+func (e Event[T]) MarshalJSON() ([]byte, error) {
+ type Alias Event[T]
+
+ var (
+ typ string
+ bts []byte
+ err error
+ )
+ switch any(e.Payload).(type) {
+ case proto.LSPEvent:
+ typ = "lsp_event"
+ bts, err = json.Marshal(e.Payload)
+ case proto.MCPEvent:
+ typ = "mcp_event"
+ bts, err = json.Marshal(e.Payload)
+ case proto.PermissionRequest:
+ typ = "permission_request"
+ bts, err = json.Marshal(e.Payload)
+ case proto.PermissionNotification:
+ typ = "permission_notification"
+ bts, err = json.Marshal(e.Payload)
+ case proto.Message:
+ typ = "message"
+ bts, err = json.Marshal(e.Payload)
+ case proto.Session:
+ typ = "session"
+ bts, err = json.Marshal(e.Payload)
+ case proto.File:
+ typ = "file"
+ bts, err = json.Marshal(e.Payload)
+ case proto.AgentEvent:
+ typ = "agent_event"
+ bts, err = json.Marshal(e.Payload)
+ default:
+ panic(fmt.Sprintf("unknown payload type: %T", e.Payload))
+ }
+
+ if err != nil {
+ return nil, err
+ }
+
+ p, err := json.Marshal(&Payload{
+ Type: typ,
+ Payload: bts,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ b, err := json.Marshal(&struct {
+ Payload json.RawMessage `json:"payload"`
+ *Alias
+ }{
+ Payload: json.RawMessage(p),
+ Alias: (*Alias)(&e),
+ })
+
+ // slog.Info("marshalled event", "event", fmt.Sprintf("%q", string(b)))
+
+ return b, err
+}
+
+func (e *Event[T]) UnmarshalJSON(data []byte) error {
+ // slog.Info("unmarshalling event", "data", fmt.Sprintf("%q", string(data)))
+
+ type Alias Event[T]
+ aux := &struct {
+ Payload json.RawMessage `json:"payload"`
+ *Alias
+ }{
+ Alias: (*Alias)(e),
+ }
+
+ if err := json.Unmarshal(data, &aux); err != nil {
+ return err
+ }
+
+ e.Type = aux.Type
+
+ slog.Info("unmarshalled event payload", "aux", fmt.Sprintf("%q", aux.Payload))
+
+ var wp Payload
+ if err := json.Unmarshal(aux.Payload, &wp); err != nil {
+ return err
+ }
+
+ var pl any
+ switch wp.Type {
+ case "lsp_event":
+ var p proto.LSPEvent
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "mcp_event":
+ var p proto.MCPEvent
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "permission_request":
+ var p proto.PermissionRequest
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "permission_notification":
+ var p proto.PermissionNotification
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "message":
+ var p proto.Message
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "session":
+ var p proto.Session
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "file":
+ var p proto.File
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ case "agent_event":
+ var p proto.AgentEvent
+ if err := json.Unmarshal(wp.Payload, &p); err != nil {
+ return err
+ }
+ pl = p
+ default:
+ panic(fmt.Sprintf("unknown payload type: %q", wp.Type))
+ }
+
+ e.Payload = T(pl.(T))
+
+ return nil
+}
@@ -5,22 +5,12 @@ import (
"database/sql"
"github.com/charmbracelet/crush/internal/db"
+ "github.com/charmbracelet/crush/internal/proto"
"github.com/charmbracelet/crush/internal/pubsub"
"github.com/google/uuid"
)
-type Session struct {
- ID string `json:"id"`
- ParentSessionID string `json:"parent_session_id"`
- Title string `json:"title"`
- MessageCount int64 `json:"message_count"`
- PromptTokens int64 `json:"prompt_tokens"`
- CompletionTokens int64 `json:"completion_tokens"`
- SummaryMessageID string `json:"summary_message_id"`
- Cost float64 `json:"cost"`
- CreatedAt int64 `json:"created_at"`
- UpdatedAt int64 `json:"updated_at"`
-}
+type Session = proto.Session
type Service interface {
pubsub.Suscriber[Session]