Improve summary to keep context (#159)

Kujtim Hoxha created

* improve summary to keep context

* improve loop

* remove debug msg

Change summary

internal/db/db.go                                                |  2 
internal/db/files.sql.go                                         |  2 
internal/db/messages.sql.go                                      |  2 
internal/db/migrations/20250515105448_add_summary_message_id.sql |  9 
internal/db/models.go                                            |  3 
internal/db/querier.go                                           |  2 
internal/db/sessions.sql.go                                      | 29 
internal/db/sql/sessions.sql                                     |  3 
internal/llm/agent/agent.go                                      | 60 +
internal/session/session.go                                      |  8 
internal/tui/components/chat/list.go                             | 11 
internal/tui/components/chat/message.go                          |  4 
internal/tui/tui.go                                              | 24 
13 files changed, 105 insertions(+), 54 deletions(-)

Detailed changes

internal/db/db.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 
 package db
 

internal/db/files.sql.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 // source: files.sql
 
 package db

internal/db/messages.sql.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 // source: messages.sql
 
 package db

internal/db/models.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 
 package db
 
@@ -39,4 +39,5 @@ type Session struct {
 	Cost             float64        `json:"cost"`
 	UpdatedAt        int64          `json:"updated_at"`
 	CreatedAt        int64          `json:"created_at"`
+	SummaryMessageID sql.NullString `json:"summary_message_id"`
 }

internal/db/querier.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 
 package db
 

internal/db/sessions.sql.go 🔗

@@ -1,6 +1,6 @@
 // Code generated by sqlc. DO NOT EDIT.
 // versions:
-//   sqlc v1.27.0
+//   sqlc v1.29.0
 // source: sessions.sql
 
 package db
@@ -19,6 +19,7 @@ INSERT INTO sessions (
     prompt_tokens,
     completion_tokens,
     cost,
+    summary_message_id,
     updated_at,
     created_at
 ) VALUES (
@@ -29,9 +30,10 @@ INSERT INTO sessions (
     ?,
     ?,
     ?,
+    null,
     strftime('%s', 'now'),
     strftime('%s', 'now')
-) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
+) RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
 `
 
 type CreateSessionParams struct {
@@ -65,6 +67,7 @@ func (q *Queries) CreateSession(ctx context.Context, arg CreateSessionParams) (S
 		&i.Cost,
 		&i.UpdatedAt,
 		&i.CreatedAt,
+		&i.SummaryMessageID,
 	)
 	return i, err
 }
@@ -80,7 +83,7 @@ func (q *Queries) DeleteSession(ctx context.Context, id string) error {
 }
 
 const getSessionByID = `-- name: GetSessionByID :one
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
 FROM sessions
 WHERE id = ? LIMIT 1
 `
@@ -98,12 +101,13 @@ func (q *Queries) GetSessionByID(ctx context.Context, id string) (Session, error
 		&i.Cost,
 		&i.UpdatedAt,
 		&i.CreatedAt,
+		&i.SummaryMessageID,
 	)
 	return i, err
 }
 
 const listSessions = `-- name: ListSessions :many
-SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
+SELECT id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
 FROM sessions
 WHERE parent_session_id is NULL
 ORDER BY created_at DESC
@@ -128,6 +132,7 @@ func (q *Queries) ListSessions(ctx context.Context) ([]Session, error) {
 			&i.Cost,
 			&i.UpdatedAt,
 			&i.CreatedAt,
+			&i.SummaryMessageID,
 		); err != nil {
 			return nil, err
 		}
@@ -148,17 +153,19 @@ SET
     title = ?,
     prompt_tokens = ?,
     completion_tokens = ?,
+    summary_message_id = ?,
     cost = ?
 WHERE id = ?
-RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at
+RETURNING id, parent_session_id, title, message_count, prompt_tokens, completion_tokens, cost, updated_at, created_at, summary_message_id
 `
 
 type UpdateSessionParams struct {
-	Title            string  `json:"title"`
-	PromptTokens     int64   `json:"prompt_tokens"`
-	CompletionTokens int64   `json:"completion_tokens"`
-	Cost             float64 `json:"cost"`
-	ID               string  `json:"id"`
+	Title            string         `json:"title"`
+	PromptTokens     int64          `json:"prompt_tokens"`
+	CompletionTokens int64          `json:"completion_tokens"`
+	SummaryMessageID sql.NullString `json:"summary_message_id"`
+	Cost             float64        `json:"cost"`
+	ID               string         `json:"id"`
 }
 
 func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) {
@@ -166,6 +173,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
 		arg.Title,
 		arg.PromptTokens,
 		arg.CompletionTokens,
+		arg.SummaryMessageID,
 		arg.Cost,
 		arg.ID,
 	)
@@ -180,6 +188,7 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
 		&i.Cost,
 		&i.UpdatedAt,
 		&i.CreatedAt,
+		&i.SummaryMessageID,
 	)
 	return i, err
 }

internal/db/sql/sessions.sql 🔗

@@ -7,6 +7,7 @@ INSERT INTO sessions (
     prompt_tokens,
     completion_tokens,
     cost,
+    summary_message_id,
     updated_at,
     created_at
 ) VALUES (
@@ -17,6 +18,7 @@ INSERT INTO sessions (
     ?,
     ?,
     ?,
+    null,
     strftime('%s', 'now'),
     strftime('%s', 'now')
 ) RETURNING *;
@@ -38,6 +40,7 @@ SET
     title = ?,
     prompt_tokens = ?,
     completion_tokens = ?,
+    summary_message_id = ?,
     cost = ?
 WHERE id = ?
 RETURNING *;

internal/llm/agent/agent.go 🔗

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"strings"
 	"sync"
+	"time"
 
 	"github.com/opencode-ai/opencode/internal/config"
 	"github.com/opencode-ai/opencode/internal/llm/models"
@@ -245,6 +246,23 @@ func (a *agent) processGeneration(ctx context.Context, sessionID, content string
 			}
 		}()
 	}
+	session, err := a.sessions.Get(ctx, sessionID)
+	if err != nil {
+		return a.err(fmt.Errorf("failed to get session: %w", err))
+	}
+	if session.SummaryMessageID != "" {
+		summaryMsgInex := -1
+		for i, msg := range msgs {
+			if msg.ID == session.SummaryMessageID {
+				summaryMsgInex = i
+				break
+			}
+		}
+		if summaryMsgInex != -1 {
+			msgs = msgs[summaryMsgInex:]
+			msgs[0].Role = message.User
+		}
+	}
 
 	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 	if err != nil {
@@ -614,37 +632,51 @@ func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 			a.Publish(pubsub.CreatedEvent, event)
 			return
 		}
-		// Create a new session with the summary
-		newSession, err := a.sessions.Create(summarizeCtx, oldSession.Title+" - Continuation")
+		// Create the summary message in the existing session
+		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
+			Role: message.Assistant,
+			Parts: []message.ContentPart{
+				message.TextContent{Text: summary},
+				message.Finish{
+					Reason: message.FinishReasonEndTurn,
+					Time:   time.Now().Unix(),
+				},
+			},
+			Model: a.summarizeProvider.Model().ID,
+		})
 		if err != nil {
 			event = AgentEvent{
 				Type:  AgentEventTypeError,
-				Error: fmt.Errorf("failed to create new session: %w", err),
+				Error: fmt.Errorf("failed to create summary message: %w", err),
 				Done:  true,
 			}
+
 			a.Publish(pubsub.CreatedEvent, event)
 			return
 		}
-
-		// Create a message in the new session with the summary
-		_, err = a.messages.Create(summarizeCtx, newSession.ID, message.CreateMessageParams{
-			Role:  message.Assistant,
-			Parts: []message.ContentPart{message.TextContent{Text: summary}},
-			Model: a.summarizeProvider.Model().ID,
-		})
+		oldSession.SummaryMessageID = msg.ID
+		oldSession.CompletionTokens = response.Usage.OutputTokens
+		oldSession.PromptTokens = 0
+		model := a.summarizeProvider.Model()
+		usage := response.Usage
+		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
+			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
+			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
+			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
+		oldSession.Cost += cost
+		_, err = a.sessions.Save(summarizeCtx, oldSession)
 		if err != nil {
 			event = AgentEvent{
 				Type:  AgentEventTypeError,
-				Error: fmt.Errorf("failed to create summary message: %w", err),
+				Error: fmt.Errorf("failed to save session: %w", err),
 				Done:  true,
 			}
-
 			a.Publish(pubsub.CreatedEvent, event)
-			return
 		}
+
 		event = AgentEvent{
 			Type:      AgentEventTypeSummarize,
-			SessionID: newSession.ID,
+			SessionID: oldSession.ID,
 			Progress:  "Summary complete",
 			Done:      true,
 		}

internal/session/session.go 🔗

@@ -16,6 +16,7 @@ type Session struct {
 	MessageCount     int64
 	PromptTokens     int64
 	CompletionTokens int64
+	SummaryMessageID string
 	Cost             float64
 	CreatedAt        int64
 	UpdatedAt        int64
@@ -105,7 +106,11 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 		Title:            session.Title,
 		PromptTokens:     session.PromptTokens,
 		CompletionTokens: session.CompletionTokens,
-		Cost:             session.Cost,
+		SummaryMessageID: sql.NullString{
+			String: session.SummaryMessageID,
+			Valid:  session.SummaryMessageID != "",
+		},
+		Cost: session.Cost,
 	})
 	if err != nil {
 		return Session{}, err
@@ -135,6 +140,7 @@ func (s service) fromDBItem(item db.Session) Session {
 		MessageCount:     item.MessageCount,
 		PromptTokens:     item.PromptTokens,
 		CompletionTokens: item.CompletionTokens,
+		SummaryMessageID: item.SummaryMessageID.String,
 		Cost:             item.Cost,
 		CreatedAt:        item.CreatedAt,
 		UpdatedAt:        item.UpdatedAt,

internal/tui/components/chat/list.go 🔗

@@ -99,6 +99,14 @@ func (m *messagesCmp) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 	case renderFinishedMsg:
 		m.rendering = false
 		m.viewport.GotoBottom()
+	case pubsub.Event[session.Session]:
+		if msg.Type == pubsub.UpdatedEvent && msg.Payload.ID == m.session.ID {
+			m.session = msg.Payload
+			if m.session.SummaryMessageID == m.currentMsgID {
+				delete(m.cachedContent, m.currentMsgID)
+				m.renderView()
+			}
+		}
 	case pubsub.Event[message.Message]:
 		needsRerender := false
 		if msg.Type == pubsub.CreatedEvent {
@@ -208,12 +216,15 @@ func (m *messagesCmp) renderView() {
 				m.uiMessages = append(m.uiMessages, cache.content...)
 				continue
 			}
+			isSummary := m.session.SummaryMessageID == msg.ID
+
 			assistantMessages := renderAssistantMessage(
 				msg,
 				inx,
 				m.messages,
 				m.app.Messages,
 				m.currentMsgID,
+				isSummary,
 				m.width,
 				pos,
 			)

internal/tui/components/chat/message.go 🔗

@@ -120,6 +120,7 @@ func renderAssistantMessage(
 	allMessages []message.Message, // we need this to get tool results and the user message
 	messagesService message.Service, // We need this to get the task tool messages
 	focusedUIMessageId string,
+	isSummary bool,
 	width int,
 	position int,
 ) []uiMessage {
@@ -168,6 +169,9 @@ func renderAssistantMessage(
 		if content == "" {
 			content = "*Finished without output*"
 		}
+		if isSummary {
+			info = append(info, baseStyle.Width(width-1).Foreground(t.TextMuted()).Render(" (summary)"))
+		}
 
 		content = renderMessage(content, false, true, width, info...)
 		messages = append(messages, uiMessage{

internal/tui/tui.go 🔗

@@ -331,30 +331,6 @@ func (a appModel) Update(msg tea.Msg) (tea.Model, tea.Cmd) {
 
 		if payload.Done && payload.Type == agent.AgentEventTypeSummarize {
 			a.isCompacting = false
-
-			if payload.SessionID != "" {
-				// Switch to the new session
-				return a, func() tea.Msg {
-					sessions, err := a.app.Sessions.List(context.Background())
-					if err != nil {
-						return util.InfoMsg{
-							Type: util.InfoTypeError,
-							Msg:  "Failed to list sessions: " + err.Error(),
-						}
-					}
-
-					for _, s := range sessions {
-						if s.ID == payload.SessionID {
-							return dialog.SessionSelectedMsg{Session: s}
-						}
-					}
-
-					return util.InfoMsg{
-						Type: util.InfoTypeError,
-						Msg:  "Failed to find new session",
-					}
-				}
-			}
 			return a, util.ReportInfo("Session summarization complete")
 		} else if payload.Done && payload.Type == agent.AgentEventTypeResponse && a.selectedSession.ID != "" {
 			model := a.app.CoderAgent.Model()