diff --git a/CRUSH.md b/CRUSH.md index 102ad43ca5758beee6515ab9da4054ddc92b9a9f..cc1f61bb366ad31c9a7642e19f63cd27c0fe3df0 100644 --- a/CRUSH.md +++ b/CRUSH.md @@ -6,7 +6,7 @@ - **Test**: `task test` or `go test ./...` (run single test: `go test ./internal/llm/prompt -run TestGetContextFromPaths`) - **Update Golden Files**: `go test ./... -update` (regenerates .golden files when test output changes) - Update specific package: `go test ./internal/tui/components/core -update` (in this case, we're updating "core") -- **Lint**: `task lint-fix` +- **Lint**: `task lint:fix` - **Format**: `task fmt` (gofumpt -w .) - **Dev**: `task dev` (runs with profiling enabled) diff --git a/internal/agent/agent.go b/internal/agent/agent.go index a62254675221680a20aef4089485db803eb289de..3188d53097a5ac0089c19a94a2387bec1d0437f9 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -62,36 +62,41 @@ type Model struct { } type sessionAgent struct { - largeModel Model - smallModel Model - systemPrompt string - tools []ai.AgentTool - sessions session.Service - messages message.Service + largeModel Model + smallModel Model + systemPrompt string + tools []ai.AgentTool + sessions session.Service + messages message.Service + disableAutoSummarize bool messageQueue *csync.Map[string, []SessionAgentCall] activeRequests *csync.Map[string, context.CancelFunc] } -type SessionAgentOption func(*sessionAgent) +type SessionAgentOptions struct { + LargeModel Model + SmallModel Model + SystemPrompt string + DisableAutoSummarize bool + Sessions session.Service + Messages message.Service + Tools []ai.AgentTool +} func NewSessionAgent( - largeModel Model, - smallModel Model, - systemPrompt string, - sessions session.Service, - messages message.Service, - tools ...ai.AgentTool, + opts SessionAgentOptions, ) SessionAgent { return &sessionAgent{ - largeModel: largeModel, - smallModel: smallModel, - systemPrompt: systemPrompt, - sessions: sessions, - messages: messages, - tools: tools, - messageQueue: csync.NewMap[string, []SessionAgentCall](), - activeRequests: csync.NewMap[string, context.CancelFunc](), + largeModel: opts.LargeModel, + smallModel: opts.SmallModel, + systemPrompt: opts.SystemPrompt, + sessions: opts.Sessions, + messages: opts.Messages, + disableAutoSummarize: opts.DisableAutoSummarize, + tools: opts.Tools, + messageQueue: csync.NewMap[string, []SessionAgentCall](), + activeRequests: csync.NewMap[string, context.CancelFunc](), } } @@ -164,6 +169,7 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen history, files := a.preparePrompt(msgs, call.Attachments...) var currentAssistant *message.Message + var shouldSummarize bool result, err := agent.Stream(genCtx, ai.AgentStreamCall{ Prompt: call.Prompt, Files: files, @@ -292,16 +298,16 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen IsError: isError, Metadata: result.ClientMetadata, } - _, err := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{ + _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{ Role: message.Tool, Parts: []message.ContentPart{ toolResult, }, }) - if err != nil { - return err + if createMsgErr != nil { + return createMsgErr } - return a.messages.Update(genCtx, *currentAssistant) + return nil }, OnStepFinish: func(stepResult ai.StepResult) error { finishReason := message.FinishReasonUnknown @@ -323,6 +329,18 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen } return a.messages.Update(genCtx, *currentAssistant) }, + StopWhen: []ai.StopCondition{ + func(_ []ai.StepResult) bool { + contextWindow := a.largeModel.CatwalkCfg.ContextWindow + tokens := currentSession.CompletionTokens + currentSession.PromptTokens + percentage := (float64(tokens) / float64(contextWindow)) * 100 + if (percentage > 80) && !a.disableAutoSummarize { + shouldSummarize = true + return true + } + return false + }, + }, }) if err != nil { isCancelErr := errors.Is(err, context.Canceled) @@ -358,28 +376,29 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen break } } - if !found { - content := "There was an error while executing the tool" - if isCancelErr { - content = "Tool execution canceled by user" - } else if isPermissionErr { - content = "Permission denied" - } - toolResult := message.ToolResult{ - ToolCallID: tc.ID, - Name: tc.Name, - Content: content, - IsError: true, - } - _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{ - Role: message.Tool, - Parts: []message.ContentPart{ - toolResult, - }, - }) - if createErr != nil { - return nil, createErr - } + if found { + continue + } + content := "There was an error while executing the tool" + if isCancelErr { + content = "Tool execution canceled by user" + } else if isPermissionErr { + content = "Permission denied" + } + toolResult := message.ToolResult{ + ToolCallID: tc.ID, + Name: tc.Name, + Content: content, + IsError: true, + } + _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{ + Role: message.Tool, + Parts: []message.ContentPart{ + toolResult, + }, + }) + if createErr != nil { + return nil, createErr } } if isCancelErr { @@ -398,6 +417,13 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.Agen } wg.Wait() + if shouldSummarize { + a.activeRequests.Del(call.SessionID) + if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil { + return nil, summarizeErr + } + } + queuedMessages, ok := a.messageQueue.Get(call.SessionID) if !ok || len(queuedMessages) == 0 { return result, err @@ -437,20 +463,21 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error { ai.WithSystemPrompt(string(summaryPrompt)), ) summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{ - Role: message.Assistant, - Model: a.largeModel.Model.Model(), - Provider: a.largeModel.Model.Provider(), + Role: message.Assistant, + Model: a.largeModel.Model.Model(), + Provider: a.largeModel.Model.Provider(), + IsSummaryMessage: true, }) if err != nil { return err } - resp, err := agent.Stream(ctx, ai.AgentStreamCall{ + resp, err := agent.Stream(genCtx, ai.AgentStreamCall{ Prompt: "Provide a detailed summary of our conversation above.", Messages: aiMsgs, OnReasoningDelta: func(id string, text string) error { summaryMessage.AppendReasoningContent(text) - return a.messages.Update(ctx, summaryMessage) + return a.messages.Update(genCtx, summaryMessage) }, OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error { // handle anthropic signature @@ -460,14 +487,20 @@ func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error { } } summaryMessage.FinishThinking() - return a.messages.Update(ctx, summaryMessage) + return a.messages.Update(genCtx, summaryMessage) }, OnTextDelta: func(id, text string) error { summaryMessage.AppendContent(text) - return a.messages.Update(ctx, summaryMessage) + return a.messages.Update(genCtx, summaryMessage) }, }) if err != nil { + isCancelErr := errors.Is(err, context.Canceled) + if isCancelErr { + // User cancelled summarize we need to remove the summary message + deleteErr := a.messages.Delete(ctx, summaryMessage.ID) + return deleteErr + } return err } diff --git a/internal/agent/agent_tool.go b/internal/agent/agent_tool.go index 23defbffb1606b3fa0c1fe3020ca84f6216b8183..8601170f53ea623e4ce82118c9558aff4e5bf53c 100644 --- a/internal/agent/agent_tool.go +++ b/internal/agent/agent_tool.go @@ -82,7 +82,7 @@ func (c *coordinator) agentTool() (ai.AgentTool, error) { PresencePenalty: model.ModelCfg.PresencePenalty, }) if err != nil { - return ai.ToolResponse{}, fmt.Errorf("error generating agent: %s", err) + return ai.NewTextErrorResponse("error generating response"), nil } updatedSession, err := c.sessions.Get(ctx, session.ID) if err != nil { diff --git a/internal/agent/common_test.go b/internal/agent/common_test.go index c81939a17ee825f82afdf8a6e99f1416c9facb48..d5007129a14860141fc322780afec74800fcb27c 100644 --- a/internal/agent/common_test.go +++ b/internal/agent/common_test.go @@ -127,7 +127,7 @@ func testSessionAgent(env env, large, small ai.LanguageModel, systemPrompt strin // todo: add values }, } - agent := NewSessionAgent(largeModel, smallModel, systemPrompt, env.sessions, env.messages, tools...) + agent := NewSessionAgent(SessionAgentOptions{largeModel, smallModel, systemPrompt, false, env.sessions, env.messages, tools}) return agent } diff --git a/internal/agent/coordinator.go b/internal/agent/coordinator.go index 78b153bc341f293beba8cd33d83470f0a8322d9a..c9b886f0746a3cb6c25561bd31ac4e6923e6fc05 100644 --- a/internal/agent/coordinator.go +++ b/internal/agent/coordinator.go @@ -165,7 +165,7 @@ func (c *coordinator) buildAgent(prompt *prompt.Prompt, agent config.Agent) (Ses if err != nil { return nil, err } - return NewSessionAgent(large, small, systemPrompt, c.sessions, c.messages, tools...), nil + return NewSessionAgent(SessionAgentOptions{large, small, systemPrompt, c.cfg.Options.DisableAutoSummarize, c.sessions, c.messages, tools}), nil } func (c *coordinator) buildTools(agent config.Agent) ([]ai.AgentTool, error) { diff --git a/internal/agent/tools/grep.go b/internal/agent/tools/grep.go index e7181f87924d6a1b23bfad9cf4101997345cabf0..e299d7d71a3046f33fb27d14aa7870d20ef9670c 100644 --- a/internal/agent/tools/grep.go +++ b/internal/agent/tools/grep.go @@ -89,7 +89,10 @@ type GrepResponseMetadata struct { Truncated bool `json:"truncated"` } -const GrepToolName = "grep" +const ( + GrepToolName = "grep" + maxGrepContentWidth = 500 +) //go:embed grep.md var grepDescription []byte @@ -135,7 +138,11 @@ func NewGrepTool(workingDir string) ai.AgentTool { fmt.Fprintf(&output, "%s:\n", match.path) } if match.lineNum > 0 { - fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, match.lineText) + lineText := match.lineText + if len(lineText) > maxGrepContentWidth { + lineText = lineText[:maxGrepContentWidth] + "..." + } + fmt.Fprintf(&output, " Line %d: %s\n", match.lineNum, lineText) } else { fmt.Fprintf(&output, " %s\n", match.path) } diff --git a/internal/db/db.go b/internal/db/db.go index 62ebe0134c683f2a3f69d26ea3f826c9bbf02d14..6f57f2c2c6c7c2854e93fa6246cad6dbfcfa569c 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 package db diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index a52516d20edb189e476ad41bbc7486b2ea8cc18b..ec8dfefc734f35d76ce488b19678872038721d71 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 // source: files.sql package db diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 81f322921db87dde7ade48ce64322aa01004d255..f10b9d5e2c47ec90aec9dc0f206d4a157fa7f6b0 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 // source: messages.sql package db @@ -18,21 +18,23 @@ INSERT INTO messages ( parts, model, provider, + is_summary_message, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) -RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider +RETURNING id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message ` type CreateMessageParams struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Role string `json:"role"` - Parts string `json:"parts"` - Model sql.NullString `json:"model"` - Provider sql.NullString `json:"provider"` + ID string `json:"id"` + SessionID string `json:"session_id"` + Role string `json:"role"` + Parts string `json:"parts"` + Model sql.NullString `json:"model"` + Provider sql.NullString `json:"provider"` + IsSummaryMessage int64 `json:"is_summary_message"` } func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (Message, error) { @@ -43,6 +45,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M arg.Parts, arg.Model, arg.Provider, + arg.IsSummaryMessage, ) var i Message err := row.Scan( @@ -55,6 +58,7 @@ func (q *Queries) CreateMessage(ctx context.Context, arg CreateMessageParams) (M &i.UpdatedAt, &i.FinishedAt, &i.Provider, + &i.IsSummaryMessage, ) return i, err } @@ -80,7 +84,7 @@ func (q *Queries) DeleteSessionMessages(ctx context.Context, sessionID string) e } const getMessage = `-- name: GetMessage :one -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message FROM messages WHERE id = ? LIMIT 1 ` @@ -98,12 +102,13 @@ func (q *Queries) GetMessage(ctx context.Context, id string) (Message, error) { &i.UpdatedAt, &i.FinishedAt, &i.Provider, + &i.IsSummaryMessage, ) return i, err } const listMessagesBySession = `-- name: ListMessagesBySession :many -SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider +SELECT id, session_id, role, parts, model, created_at, updated_at, finished_at, provider, is_summary_message FROM messages WHERE session_id = ? ORDER BY created_at ASC @@ -128,6 +133,7 @@ func (q *Queries) ListMessagesBySession(ctx context.Context, sessionID string) ( &i.UpdatedAt, &i.FinishedAt, &i.Provider, + &i.IsSummaryMessage, ); err != nil { return nil, err } diff --git a/internal/db/migrations/20250810000000_add_is_summary_message.sql b/internal/db/migrations/20250810000000_add_is_summary_message.sql new file mode 100644 index 0000000000000000000000000000000000000000..0c400b5e574ed73f60562ee27ea4b4e09c9d8699 --- /dev/null +++ b/internal/db/migrations/20250810000000_add_is_summary_message.sql @@ -0,0 +1,5 @@ +-- +goose Up +ALTER TABLE messages ADD COLUMN is_summary_message INTEGER DEFAULT 0 NOT NULL; + +-- +goose Down +ALTER TABLE messages DROP COLUMN is_summary_message; diff --git a/internal/db/models.go b/internal/db/models.go index ec3e6e10ad990d0f1a3d03a7533c8b1aed184447..ddced85da6628097d981b219ef8c768f50474c85 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 package db @@ -19,15 +19,16 @@ type File struct { } type Message struct { - ID string `json:"id"` - SessionID string `json:"session_id"` - Role string `json:"role"` - Parts string `json:"parts"` - Model sql.NullString `json:"model"` - CreatedAt int64 `json:"created_at"` - UpdatedAt int64 `json:"updated_at"` - FinishedAt sql.NullInt64 `json:"finished_at"` - Provider sql.NullString `json:"provider"` + ID string `json:"id"` + SessionID string `json:"session_id"` + Role string `json:"role"` + Parts string `json:"parts"` + Model sql.NullString `json:"model"` + CreatedAt int64 `json:"created_at"` + UpdatedAt int64 `json:"updated_at"` + FinishedAt sql.NullInt64 `json:"finished_at"` + Provider sql.NullString `json:"provider"` + IsSummaryMessage int64 `json:"is_summary_message"` } type Session struct { diff --git a/internal/db/querier.go b/internal/db/querier.go index 472137273387d85a83a27260037321adccc9230f..0978eb2c6e4c7b1aa80888530bb5169a1d2bcec3 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 package db diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 76ef6480b8e435cff66f29f7a1912aa5db5b9e9d..99d31fa26d771255c3cc0ae35097e322471ab394 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.29.0 +// sqlc v1.30.0 // source: sessions.sql package db diff --git a/internal/db/sql/messages.sql b/internal/db/sql/messages.sql index ea946177591d1e145a59475a1ca9272f3191d4d6..fc66b78c08b85c8fe1f7ec79985fb2edd4a03668 100644 --- a/internal/db/sql/messages.sql +++ b/internal/db/sql/messages.sql @@ -17,10 +17,11 @@ INSERT INTO messages ( parts, model, provider, + is_summary_message, created_at, updated_at ) VALUES ( - ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') + ?, ?, ?, ?, ?, ?, ?, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; diff --git a/internal/message/content.go b/internal/message/content.go index 55e036ba12948bfd025457fe471c6743073f0df9..e397946ab6a04439a55639d2934b6da761796e0d 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -120,14 +120,15 @@ type Finish struct { func (Finish) isPart() {} type Message struct { - ID string - Role MessageRole - SessionID string - Parts []ContentPart - Model string - Provider string - CreatedAt int64 - UpdatedAt int64 + ID string + Role MessageRole + SessionID string + Parts []ContentPart + Model string + Provider string + CreatedAt int64 + UpdatedAt int64 + IsSummaryMessage bool } func (m *Message) Content() TextContent { diff --git a/internal/message/message.go b/internal/message/message.go index 7cd823bc3129df5f807ec478d9d6c02364c6cfec..663a8a3ea3599c49ea1e82f343c564a62efebd84 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -13,10 +13,11 @@ import ( ) type CreateMessageParams struct { - Role MessageRole - Parts []ContentPart - Model string - Provider string + Role MessageRole + Parts []ContentPart + Model string + Provider string + IsSummaryMessage bool } type Service interface { @@ -64,13 +65,18 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes if err != nil { return Message{}, err } + isSummary := int64(0) + if params.IsSummaryMessage { + isSummary = 1 + } dbMessage, err := s.q.CreateMessage(ctx, db.CreateMessageParams{ - ID: uuid.New().String(), - SessionID: sessionID, - Role: string(params.Role), - Parts: string(partsJSON), - Model: sql.NullString{String: string(params.Model), Valid: true}, - Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""}, + ID: uuid.New().String(), + SessionID: sessionID, + Role: string(params.Role), + Parts: string(partsJSON), + Model: sql.NullString{String: string(params.Model), Valid: true}, + Provider: sql.NullString{String: params.Provider, Valid: params.Provider != ""}, + IsSummaryMessage: isSummary, }) if err != nil { return Message{}, err @@ -151,14 +157,15 @@ func (s *service) fromDBItem(item db.Message) (Message, error) { return Message{}, err } return Message{ - ID: item.ID, - SessionID: item.SessionID, - Role: MessageRole(item.Role), - Parts: parts, - Model: item.Model.String, - Provider: item.Provider.String, - CreatedAt: item.CreatedAt, - UpdatedAt: item.UpdatedAt, + ID: item.ID, + SessionID: item.SessionID, + Role: MessageRole(item.Role), + Parts: parts, + Model: item.Model.String, + Provider: item.Provider.String, + CreatedAt: item.CreatedAt, + UpdatedAt: item.UpdatedAt, + IsSummaryMessage: item.IsSummaryMessage != 0, }, nil } diff --git a/internal/tui/components/chat/chat.go b/internal/tui/components/chat/chat.go index 1643af9549a6ff41c0940dd067033d4e776a9eae..a9c75950d94b469cae62495a7251e145d8974fa0 100644 --- a/internal/tui/components/chat/chat.go +++ b/internal/tui/components/chat/chat.go @@ -334,6 +334,11 @@ func (m *messageListCmp) handleMessageEvent(event pubsub.Event[message.Message]) return nil } return m.handleNewMessage(event.Payload) + case pubsub.DeletedEvent: + if event.Payload.SessionID != m.session.ID { + return nil + } + return m.handleDeleteMessage(event.Payload) case pubsub.UpdatedEvent: if event.Payload.SessionID != m.session.ID { return m.handleChildSession(event) @@ -360,6 +365,18 @@ func (m *messageListCmp) messageExists(messageID string) bool { return false } +// handleDeleteMessage removes a message from the list. +func (m *messageListCmp) handleDeleteMessage(msg message.Message) tea.Cmd { + items := m.listCmp.Items() + for i := len(items) - 1; i >= 0; i-- { + if msgCmp, ok := items[i].(messages.MessageCmp); ok && msgCmp.GetMessage().ID == msg.ID { + m.listCmp.DeleteItem(items[i].ID()) + return nil + } + } + return nil +} + // handleNewMessage routes new messages to appropriate handlers based on role. func (m *messageListCmp) handleNewMessage(msg message.Message) tea.Cmd { switch msg.Role { diff --git a/internal/tui/components/chat/messages/messages.go b/internal/tui/components/chat/messages/messages.go index 296b02478a7d0738fef2f60ae6b2211d44424a2f..fb2c5e1087514b1be605f8742cf6560ea3919a18 100644 --- a/internal/tui/components/chat/messages/messages.go +++ b/internal/tui/components/chat/messages/messages.go @@ -121,6 +121,9 @@ func (m *messageCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { // Returns different views for spinning, user, and assistant messages. func (m *messageCmp) View() string { if m.spinning && m.message.ReasoningContent().Thinking == "" { + if m.message.IsSummaryMessage { + m.anim.SetLabel("Summarizing") + } return m.style().PaddingLeft(1).Render(m.anim.View()) } if m.message.ID != "" { diff --git a/internal/tui/components/dialogs/compact/compact.go b/internal/tui/components/dialogs/compact/compact.go deleted file mode 100644 index 6321bb8e53e183feb4280fad4bd451e1ae37d8ba..0000000000000000000000000000000000000000 --- a/internal/tui/components/dialogs/compact/compact.go +++ /dev/null @@ -1,168 +0,0 @@ -package compact - -import ( - "context" - - "github.com/charmbracelet/bubbles/v2/key" - tea "github.com/charmbracelet/bubbletea/v2" - "github.com/charmbracelet/lipgloss/v2" - - "github.com/charmbracelet/crush/internal/agent" - "github.com/charmbracelet/crush/internal/tui/components/core" - "github.com/charmbracelet/crush/internal/tui/components/dialogs" - "github.com/charmbracelet/crush/internal/tui/styles" - "github.com/charmbracelet/crush/internal/tui/util" -) - -const CompactDialogID dialogs.DialogID = "compact" - -// CompactDialog interface for the session compact dialog -type CompactDialog interface { - dialogs.DialogModel -} - -type compactDialogCmp struct { - wWidth, wHeight int - width, height int - selected int - keyMap KeyMap - sessionID string - progress string - agent agent.Coordinator - noAsk bool // If true, skip confirmation dialog -} - -// NewCompactDialogCmp creates a new session compact dialog -func NewCompactDialogCmp(agent agent.Coordinator, sessionID string, noAsk bool) CompactDialog { - return &compactDialogCmp{ - sessionID: sessionID, - keyMap: DefaultKeyMap(), - selected: 0, - agent: agent, - noAsk: noAsk, - } -} - -func (c *compactDialogCmp) Init() tea.Cmd { - if c.noAsk { - // If noAsk is true, skip confirmation and start compaction immediately - c.agent.Summarize(context.Background(), c.sessionID) - return util.CmdHandler(dialogs.CloseDialogMsg{}) - } - return nil -} - -func (c *compactDialogCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { - switch msg := msg.(type) { - case tea.WindowSizeMsg: - c.wWidth = msg.Width - c.wHeight = msg.Height - cmd := c.SetSize() - return c, cmd - case tea.KeyPressMsg: - switch { - case key.Matches(msg, c.keyMap.ChangeSelection): - c.selected = (c.selected + 1) % 2 - return c, nil - case key.Matches(msg, c.keyMap.Select): - if c.selected == 0 { - c.agent.Summarize(context.Background(), c.sessionID) - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } else { - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - case key.Matches(msg, c.keyMap.Y): - c.agent.Summarize(context.Background(), c.sessionID) - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - case key.Matches(msg, c.keyMap.N): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - case key.Matches(msg, c.keyMap.Close): - return c, util.CmdHandler(dialogs.CloseDialogMsg{}) - } - } - return c, nil -} - -func (c *compactDialogCmp) renderButtons() string { - t := styles.CurrentTheme() - baseStyle := t.S().Base - - buttons := []core.ButtonOpts{ - { - Text: "Yes", - UnderlineIndex: 0, // "Y" - Selected: c.selected == 0, - }, - { - Text: "No", - UnderlineIndex: 0, // "N" - Selected: c.selected == 1, - }, - } - - content := core.SelectableButtons(buttons, " ") - - return baseStyle.AlignHorizontal(lipgloss.Right).Width(c.width - 4).Render(content) -} - -func (c *compactDialogCmp) render() string { - t := styles.CurrentTheme() - baseStyle := t.S().Base - - title := "Compact Session" - titleView := core.Title(title, c.width-4) - explanation := t.S().Text. - Width(c.width - 4). - Render("This will summarize the current session and reset the context. The conversation history will be condensed into a summary to free up context space while preserving important information.") - - question := t.S().Text. - Width(c.width - 4). - Render("Do you want to continue?") - - content := baseStyle.Render(lipgloss.JoinVertical( - lipgloss.Left, - explanation, - "", - question, - )) - - buttons := c.renderButtons() - dialogContent := lipgloss.JoinVertical( - lipgloss.Top, - titleView, - "", - content, - "", - buttons, - "", - ) - - return baseStyle. - Padding(0, 1). - Border(lipgloss.RoundedBorder()). - BorderForeground(t.BorderFocus). - Width(c.width). - Render(dialogContent) -} - -func (c *compactDialogCmp) View() string { - return c.render() -} - -// SetSize sets the size of the component. -func (c *compactDialogCmp) SetSize() tea.Cmd { - c.width = min(90, c.wWidth) - c.height = min(15, c.wHeight) - return nil -} - -func (c *compactDialogCmp) Position() (int, int) { - row := (c.wHeight / 2) - (c.height / 2) - col := (c.wWidth / 2) - (c.width / 2) - return row, col -} - -// ID implements CompactDialog. -func (c *compactDialogCmp) ID() dialogs.DialogID { - return CompactDialogID -} diff --git a/internal/tui/components/dialogs/compact/keys.go b/internal/tui/components/dialogs/compact/keys.go deleted file mode 100644 index cec1486491e342c28f148a50d37f1129944c002e..0000000000000000000000000000000000000000 --- a/internal/tui/components/dialogs/compact/keys.go +++ /dev/null @@ -1,71 +0,0 @@ -package compact - -import ( - "github.com/charmbracelet/bubbles/v2/key" -) - -// KeyMap defines the key bindings for the compact dialog. -type KeyMap struct { - ChangeSelection key.Binding - Select key.Binding - Y key.Binding - N key.Binding - Close key.Binding -} - -// DefaultKeyMap returns the default key bindings for the compact dialog. -func DefaultKeyMap() KeyMap { - return KeyMap{ - ChangeSelection: key.NewBinding( - key.WithKeys("tab", "left", "right", "h", "l"), - key.WithHelp("tab/←/→", "toggle selection"), - ), - Select: key.NewBinding( - key.WithKeys("enter"), - key.WithHelp("enter", "confirm"), - ), - Y: key.NewBinding( - key.WithKeys("y"), - key.WithHelp("y", "yes"), - ), - N: key.NewBinding( - key.WithKeys("n"), - key.WithHelp("n", "no"), - ), - Close: key.NewBinding( - key.WithKeys("esc", "alt+esc"), - key.WithHelp("esc", "cancel"), - ), - } -} - -// KeyBindings implements layout.KeyMapProvider -func (k KeyMap) KeyBindings() []key.Binding { - return []key.Binding{ - k.ChangeSelection, - k.Select, - k.Y, - k.N, - k.Close, - } -} - -// FullHelp implements help.KeyMap. -func (k KeyMap) FullHelp() [][]key.Binding { - m := [][]key.Binding{} - slice := k.KeyBindings() - for i := 0; i < len(slice); i += 4 { - end := min(i+4, len(slice)) - m = append(m, slice[i:end]) - } - return m -} - -// ShortHelp implements help.KeyMap. -func (k KeyMap) ShortHelp() []key.Binding { - return []key.Binding{ - k.ChangeSelection, - k.Select, - k.Close, - } -} diff --git a/internal/tui/tui.go b/internal/tui/tui.go index dbe3467adeab2619e884b8f0078471538d1af902..7a24d21aacdac7b473fc4e9c454a8a95ed800ea2 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -21,7 +21,6 @@ import ( "github.com/charmbracelet/crush/internal/tui/components/core/status" "github.com/charmbracelet/crush/internal/tui/components/dialogs" "github.com/charmbracelet/crush/internal/tui/components/dialogs/commands" - "github.com/charmbracelet/crush/internal/tui/components/dialogs/compact" "github.com/charmbracelet/crush/internal/tui/components/dialogs/filepicker" "github.com/charmbracelet/crush/internal/tui/components/dialogs/models" "github.com/charmbracelet/crush/internal/tui/components/dialogs/permissions" @@ -178,9 +177,13 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { ) // Compact case commands.CompactMsg: - return a, util.CmdHandler(dialogs.OpenDialogMsg{ - Model: compact.NewCompactDialogCmp(a.app.AgentCoordinator, msg.SessionID, true), - }) + return a, func() tea.Msg { + err := a.app.AgentCoordinator.Summarize(context.Background(), msg.SessionID) + if err != nil { + return util.ReportError(err)() + } + return nil + } case commands.QuitMsg: return a, util.CmdHandler(dialogs.OpenDialogMsg{ Model: quit.NewQuitDialog(), @@ -251,38 +254,6 @@ func (a *appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { a.app.Permissions.Deny(msg.Permission) } return a, nil - // Agent Events - // TODO: HANDLE AUTO COMPACT - // case pubsub.Event[agent.AgentEvent]: - // payload := msg.Payload - // - // // Forward agent events to dialogs - // if a.dialog.HasDialogs() && a.dialog.ActiveDialogID() == compact.CompactDialogID { - // u, dialogCmd := a.dialog.Update(payload) - // if model, ok := u.(dialogs.DialogCmp); ok { - // a.dialog = model - // } - // - // cmds = append(cmds, dialogCmd) - // } - // - // // Handle auto-compact logic - // if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSessionID != "" { - // // Get current session to check token usage - // session, err := a.app.Sessions.Get(context.Background(), a.selectedSessionID) - // if err == nil { - // model := a.app.AgentCoordinator.Model() - // contextWindow := model.CatwalkCfg.ContextWindow - // tokens := session.CompletionTokens + session.PromptTokens - // if (tokens >= int64(float64(contextWindow)*0.95)) && !config.Get().Options.DisableAutoSummarize { // Show compact confirmation dialog - // cmds = append(cmds, util.CmdHandler(dialogs.OpenDialogMsg{ - // Model: compact.NewCompactDialogCmp(a.app.AgentCoordinator, a.selectedSessionID, false), - // })) - // } - // } - // } - // - // return a, tea.Batch(cmds...) case splash.OnboardingCompleteMsg: item, ok := a.pages[a.currentPage] if !ok {