From 3e424754b48862fdd941f5d6434abda989caaa21 Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Thu, 15 May 2025 15:59:18 +0200 Subject: [PATCH] Improve summary to keep context (#159) * improve summary to keep context * improve loop * remove debug msg --- internal/db/db.go | 2 +- internal/db/files.sql.go | 2 +- internal/db/messages.sql.go | 2 +- .../20250515105448_add_summary_message_id.sql | 9 +++ internal/db/models.go | 3 +- internal/db/querier.go | 2 +- internal/db/sessions.sql.go | 29 +++++---- internal/db/sql/sessions.sql | 3 + internal/llm/agent/agent.go | 60 ++++++++++++++----- internal/session/session.go | 8 ++- internal/tui/components/chat/list.go | 11 ++++ internal/tui/components/chat/message.go | 4 ++ internal/tui/tui.go | 24 -------- 13 files changed, 105 insertions(+), 54 deletions(-) create mode 100644 internal/db/migrations/20250515105448_add_summary_message_id.sql diff --git a/internal/db/db.go b/internal/db/db.go index 16e66380405615910bd1ebf1449cb40e0fca5756..5badad3a280eb9e11ae0b6a9d068f8f9efb937b6 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package db diff --git a/internal/db/files.sql.go b/internal/db/files.sql.go index 39def271f104addc2fa0057de503c17c2cdfecf7..28abaa55d736b6eeefb721b69f4bcc7fceb4af37 100644 --- a/internal/db/files.sql.go +++ b/internal/db/files.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: files.sql package db diff --git a/internal/db/messages.sql.go b/internal/db/messages.sql.go index 0555b4330d79089c0d5a7127c311f55af567e604..2acfe18fdbc63312c49d65e9e3acb1bd24cf4d7e 100644 --- a/internal/db/messages.sql.go +++ b/internal/db/messages.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. 
// versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: messages.sql package db diff --git a/internal/db/migrations/20250515105448_add_summary_message_id.sql b/internal/db/migrations/20250515105448_add_summary_message_id.sql new file mode 100644 index 0000000000000000000000000000000000000000..138a0af21a2c4dec72d47eef40d3c9491b4e5314 --- /dev/null +++ b/internal/db/migrations/20250515105448_add_summary_message_id.sql @@ -0,0 +1,9 @@ +-- +goose Up +-- +goose StatementBegin +ALTER TABLE sessions ADD COLUMN summary_message_id TEXT; +-- +goose StatementEnd + +-- +goose Down +-- +goose StatementBegin +ALTER TABLE sessions DROP COLUMN summary_message_id; +-- +goose StatementEnd diff --git a/internal/db/models.go b/internal/db/models.go index f00cb6ad17ec5f9426502bb7612191dd6065f255..07549024a230dc357a7f57d69c42440336065a9a 100644 --- a/internal/db/models.go +++ b/internal/db/models.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package db @@ -39,4 +39,5 @@ type Session struct { Cost float64 `json:"cost"` UpdatedAt int64 `json:"updated_at"` CreatedAt int64 `json:"created_at"` + SummaryMessageID sql.NullString `json:"summary_message_id"` } diff --git a/internal/db/querier.go b/internal/db/querier.go index 704a97da26c7feaf022ff3d8fa228b918ab298b6..257012526e54fd08065df410e207bee2a126b9c0 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.27.0 +// sqlc v1.29.0 package db diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 18d70c3dbdba50a5684f35126742a9095dff221f..76ef6480b8e435cff66f29f7a1912aa5db5b9e9d 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. 
// versions: -// sqlc v1.27.0 +// sqlc v1.29.0 // source: sessions.sql package db @@ -19,6 +19,7 @@ INSERT INTO sessions ( prompt_tokens, completion_tokens, cost, + summary_message_id, updated_at, created_at ) VALUES ( @@ -29,9 +30,10 @@ INSERT INTO sessions ( ?, ?, ?, + null, strftime('%s', 'now'), strftime('%s', 'now') -) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at +) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id ` type CreateSessionParams struct { @@ -65,6 +67,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S &i.Cost, &i.UpdatedAt, &i.CreatedAt, + &i.SummaryMessageID, ) return i, err } @@ -80,7 +83,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error { } const getSessionByID = `-- name: GetSessionByID :one -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id FROM sessions WHERE id = ? 
LIMIT 1 ` @@ -98,12 +101,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error &i.Cost, &i.UpdatedAt, &i.CreatedAt, + &i.SummaryMessageID, ) return i, err } const listSessions = `-- name: ListSessions :many -SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at +SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id FROM sessions WHERE parent_session_id is NULL ORDER BY created_at DESC @@ -128,6 +132,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) { &i.Cost, &i.UpdatedAt, &i.CreatedAt, + &i.SummaryMessageID, ); err != nil { return nil, err } @@ -148,17 +153,19 @@ SET title = ?, prompt_tokens = ?, completion_tokens = ?, + summary_message_id = ?, cost = ? WHERE id = ? -RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at +RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id ` type UpdateSessionParams struct { - Title string `json:"title"` - PromptTokens int64 `json:"prompt_tokens"` - CompletionTokens int64 `json:"completion_tokens"` - Cost float64 `json:"cost"` - ID string `json:"id"` + Title string `json:"title"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + SummaryMessageID sql.NullString `json:"summary_message_id"` + Cost float64 `json:"cost"` + ID string `json:"id"` } func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) { @@ -166,6 +173,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S arg.Title, arg.PromptTokens, arg.CompletionTokens, + arg.SummaryMessageID, arg.Cost, arg.ID, ) @@ -180,6 +188,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S &i.Cost, 
&i.UpdatedAt, &i.CreatedAt, + &i.SummaryMessageID, ) return i, err } diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index f065b5f5614563c1479a0636da25c583a29b2310..ebeab90d39f641c0aee72152c1f60ef455d5dff4 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -7,6 +7,7 @@ INSERT INTO sessions ( prompt_tokens, completion_tokens, cost, + summary_message_id, updated_at, created_at ) VALUES ( @@ -17,6 +18,7 @@ INSERT INTO sessions ( ?, ?, ?, + null, strftime('%s', 'now'), strftime('%s', 'now') ) RETURNING *; @@ -38,6 +40,7 @@ SET title = ?, prompt_tokens = ?, completion_tokens = ?, + summary_message_id = ?, cost = ? WHERE id = ? RETURNING *; diff --git a/internal/llm/agent/agent.go b/internal/llm/agent/agent.go index 03b2d59dd4359e91aa8b0fba2e99c02340f0b9f5..0ac7f65ff37f2cbab3fd45e5ea963542f431e75c 100644 --- a/internal/llm/agent/agent.go +++ b/internal/llm/agent/agent.go @@ -6,6 +6,7 @@ import ( "fmt" "strings" "sync" + "time" "github.com/opencode-ai/opencode/internal/config" "github.com/opencode-ai/opencode/internal/llm/models" @@ -245,6 +246,23 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string } }() } + session, err := a.sessions.Get(ctx, sessionID) + if err != nil { + return a.err(fmt.Errorf("failed to get session: %w", err)) + } + if session.SummaryMessageID != "" { + summaryMsgInex := -1 + for i, msg := range msgs { + if msg.ID == session.SummaryMessageID { + summaryMsgInex = i + break + } + } + if summaryMsgInex != -1 { + msgs = msgs[summaryMsgInex:] + msgs[0].Role = message.User + } + } userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts) if err != nil { @@ -614,37 +632,51 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error { a.Publish(pubsub.CreatedEvent, event) return } - // Create a new session with the summary - newSession, err := a.sessions.Create(summarizeCtx, oldSession.Title+" - Continuation") + // Create a message in the 
current session with the summary + msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{ + Role: message.Assistant, + Parts: []message.ContentPart{ + message.TextContent{Text: summary}, + message.Finish{ + Reason: message.FinishReasonEndTurn, + Time: time.Now().Unix(), + }, + }, + Model: a.summarizeProvider.Model().ID, + }) if err != nil { event = AgentEvent{ Type: AgentEventTypeError, - Error: fmt.Errorf("failed to create new session: %w", err), + Error: fmt.Errorf("failed to create summary message: %w", err), Done: true, } + a.Publish(pubsub.CreatedEvent, event) return } - - // Create a message in the new session with the summary - _, err = a.messages.Create(summarizeCtx, newSession.ID, message.CreateMessageParams{ - Role: message.Assistant, - Parts: []message.ContentPart{message.TextContent{Text: summary}}, - Model: a.summarizeProvider.Model().ID, - }) + oldSession.SummaryMessageID = msg.ID + oldSession.CompletionTokens = response.Usage.OutputTokens + oldSession.PromptTokens = 0 + model := a.summarizeProvider.Model() + usage := response.Usage + cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) + + model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) + + model.CostPer1MIn/1e6*float64(usage.InputTokens) + + model.CostPer1MOut/1e6*float64(usage.OutputTokens) + oldSession.Cost += cost + _, err = a.sessions.Save(summarizeCtx, oldSession) if err != nil { event = AgentEvent{ Type: AgentEventTypeError, - Error: fmt.Errorf("failed to create summary message: %w", err), + Error: fmt.Errorf("failed to save session: %w", err), Done: true, } - a.Publish(pubsub.CreatedEvent, event) - return } + event = AgentEvent{ Type: AgentEventTypeSummarize, - SessionID: newSession.ID, + SessionID: oldSession.ID, Progress: "Summary complete", Done: true, } diff --git a/internal/session/session.go b/internal/session/session.go index 682ea7768d6494afbd06367335a54ccedab3efbf..c6e7f60bfbfe52e54071183b0cc9f399363904d6 100644 ---
a/internal/session/session.go +++ b/internal/session/session.go @@ -16,6 +16,7 @@ type Session struct { MessageCount int64 PromptTokens int64 CompletionTokens int64 + SummaryMessageID string Cost float64 CreatedAt int64 UpdatedAt int64 @@ -105,7 +106,11 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) { Title: session.Title, PromptTokens: session.PromptTokens, CompletionTokens: session.CompletionTokens, - Cost: session.Cost, + SummaryMessageID: sql.NullString{ + String: session.SummaryMessageID, + Valid: session.SummaryMessageID != "", + }, + Cost: session.Cost, }) if err != nil { return Session{}, err @@ -135,6 +140,7 @@ func (s service) fromDBItem(item db.Session) Session { MessageCount: item.MessageCount, PromptTokens: item.PromptTokens, CompletionTokens: item.CompletionTokens, + SummaryMessageID: item.SummaryMessageID.String, Cost: item.Cost, CreatedAt: item.CreatedAt, UpdatedAt: item.UpdatedAt, diff --git a/internal/tui/components/chat/list.go b/internal/tui/components/chat/list.go index df642907352b561c25d31a2eaa10dec44a516b2d..40d5b962876f09f60f44092f0c21b1f1ec9e4bb1 100644 --- a/internal/tui/components/chat/list.go +++ b/internal/tui/components/chat/list.go @@ -99,6 +99,14 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case renderFinishedMsg: m.rendering = false m.viewport.GotoBottom() + case pubsub.Event[session.Session]: + if msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.session.ID { + m.session = msg.Payload + if m.session.SummaryMessageID == m.currentMsgID { + delete(m.cachedContent, m.currentMsgID) + m.renderView() + } + } case pubsub.Event[message.Message]: needsRerender := false if msg.Type == pubsub.CreatedEvent { @@ -208,12 +216,15 @@ func (m *messagesCmp) renderView() { m.uiMessages = append(m.uiMessages, cache.content...) 
continue } + isSummary := m.session.SummaryMessageID == msg.ID + assistantMessages := renderAssistantMessage( msg, inx, m.messages, m.app.Messages, m.currentMsgID, + isSummary, m.width, pos, ) diff --git a/internal/tui/components/chat/message.go b/internal/tui/components/chat/message.go index 4acbbef9ee597bbd30bad2d783ff3c37769a5f1e..0732366d94c01dc8d183f97dcebbe8a220554f74 100644 --- a/internal/tui/components/chat/message.go +++ b/internal/tui/components/chat/message.go @@ -120,6 +120,7 @@ func renderAssistantMessage( allMessages []message.Message, // we need this to get tool results and the user message messagesService message.Service, // We need this to get the task tool messages focusedUIMessageId string, + isSummary bool, width int, position int, ) []uiMessage { @@ -168,6 +169,9 @@ func renderAssistantMessage( if content == "" { content = "*Finished without output*" } + if isSummary { + info = append(info, baseStyle.Width(width-1).Foreground(t.TextMuted()).Render(" (summary)")) + } content = renderMessage(content, false, true, width, info...) 
messages = append(messages, uiMessage{ diff --git a/internal/tui/tui.go b/internal/tui/tui.go index 700dc04e83d81a92a409cf89d53f839f73b83d22..060b8c79c8572a0508ebcd95a148ff0743bc7009 100644 --- a/internal/tui/tui.go +++ b/internal/tui/tui.go @@ -331,30 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) { if payload.Done && payload.Type == agent.AgentEventTypeSummarize { a.isCompacting = false - - if payload.SessionID != "" { - // Switch to the new session - return a, func() tea.Msg { - sessions, err := a.app.Sessions.List(context.Background()) - if err != nil { - return util.InfoMsg{ - Type: util.InfoTypeError, - Msg: "Failed to list sessions: " + err.Error(), - } - } - - for _, s := range sessions { - if s.ID == payload.SessionID { - return dialog.SessionSelectedMsg{Session: s} - } - } - - return util.InfoMsg{ - Type: util.InfoTypeError, - Msg: "Failed to find new session", - } - } - } return a, util.ReportInfo("Session summarization complete") } else if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSession.ID != "" { model := a.app.CoderAgent.Model()