fix: race condition (#1649)

Kujtim Hoxha created

Change summary

internal/agent/agent.go      | 33 +++++++++---
internal/db/db.go            | 98 ++++++++++++++++++++-----------------
internal/db/querier.go       |  1 +
internal/db/sessions.sql.go  | 29 +++++++++++
internal/db/sql/sessions.sql |  9 +++
internal/message/content.go  |  9 +++
internal/message/message.go  | 12 +++-
internal/pubsub/broker.go    | 11 +---
internal/session/session.go  | 13 +++++
9 files changed, 152 insertions(+), 63 deletions(-)

Detailed changes

internal/agent/agent.go 🔗

@@ -171,10 +171,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
 	var wg sync.WaitGroup
 	// Generate title if first message.
 	if len(msgs) == 0 {
+		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
 		wg.Go(func() {
-			sessionLock.Lock()
-			a.generateTitle(ctx, &currentSession, call.Prompt)
-			sessionLock.Unlock()
+			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
 		})
 	}
 
@@ -723,7 +722,7 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
 	return msgs, nil
 }
 
-func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
 	if prompt == "" {
 		return
 	}
@@ -768,8 +767,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
 		return
 	}
 
-	session.Title = title
-
+	// Calculate usage and cost.
 	var openrouterCost *float64
 	for _, step := range resp.Steps {
 		stepCost := a.openrouterCost(step.ProviderMetadata)
@@ -782,8 +780,27 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
 		}
 	}
 
-	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
-	_, saveErr := a.sessions.Save(ctx, *session)
+	modelConfig := a.smallModel.CatwalkCfg
+	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
+		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
+		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
+		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
+
+	if a.isClaudeCode() {
+		cost = 0
+	}
+
+	// Use override cost if available (e.g., from OpenRouter).
+	if openrouterCost != nil {
+		cost = *openrouterCost
+	}
+
+	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
+	completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
+
+	// Atomically update only title and usage fields to avoid overriding other
+	// concurrent session updates.
+	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
 	if saveErr != nil {
 		slog.Error("failed to save session title & usage", "error", saveErr)
 		return

internal/db/db.go 🔗

@@ -84,6 +84,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
 	if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
 		return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
 	}
+	if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil {
+		return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err)
+	}
 	return &q, nil
 }
 
@@ -189,6 +192,11 @@ func (q *Queries) Close() error {
 			err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
 		}
 	}
+	if q.updateSessionTitleAndUsageStmt != nil {
+		if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil {
+			err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr)
+		}
+	}
 	return err
 }
 
@@ -226,53 +234,55 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
 }
 
 type Queries struct {
-	db                          DBTX
-	tx                          *sql.Tx
-	createFileStmt              *sql.Stmt
-	createMessageStmt           *sql.Stmt
-	createSessionStmt           *sql.Stmt
-	deleteFileStmt              *sql.Stmt
-	deleteMessageStmt           *sql.Stmt
-	deleteSessionStmt           *sql.Stmt
-	deleteSessionFilesStmt      *sql.Stmt
-	deleteSessionMessagesStmt   *sql.Stmt
-	getFileStmt                 *sql.Stmt
-	getFileByPathAndSessionStmt *sql.Stmt
-	getMessageStmt              *sql.Stmt
-	getSessionByIDStmt          *sql.Stmt
-	listFilesByPathStmt         *sql.Stmt
-	listFilesBySessionStmt      *sql.Stmt
-	listLatestSessionFilesStmt  *sql.Stmt
-	listMessagesBySessionStmt   *sql.Stmt
-	listNewFilesStmt            *sql.Stmt
-	listSessionsStmt            *sql.Stmt
-	updateMessageStmt           *sql.Stmt
-	updateSessionStmt           *sql.Stmt
+	db                             DBTX
+	tx                             *sql.Tx
+	createFileStmt                 *sql.Stmt
+	createMessageStmt              *sql.Stmt
+	createSessionStmt              *sql.Stmt
+	deleteFileStmt                 *sql.Stmt
+	deleteMessageStmt              *sql.Stmt
+	deleteSessionStmt              *sql.Stmt
+	deleteSessionFilesStmt         *sql.Stmt
+	deleteSessionMessagesStmt      *sql.Stmt
+	getFileStmt                    *sql.Stmt
+	getFileByPathAndSessionStmt    *sql.Stmt
+	getMessageStmt                 *sql.Stmt
+	getSessionByIDStmt             *sql.Stmt
+	listFilesByPathStmt            *sql.Stmt
+	listFilesBySessionStmt         *sql.Stmt
+	listLatestSessionFilesStmt     *sql.Stmt
+	listMessagesBySessionStmt      *sql.Stmt
+	listNewFilesStmt               *sql.Stmt
+	listSessionsStmt               *sql.Stmt
+	updateMessageStmt              *sql.Stmt
+	updateSessionStmt              *sql.Stmt
+	updateSessionTitleAndUsageStmt *sql.Stmt
 }
 
 func (q *Queries) WithTx(tx *sql.Tx) *Queries {
 	return &Queries{
-		db:                          tx,
-		tx:                          tx,
-		createFileStmt:              q.createFileStmt,
-		createMessageStmt:           q.createMessageStmt,
-		createSessionStmt:           q.createSessionStmt,
-		deleteFileStmt:              q.deleteFileStmt,
-		deleteMessageStmt:           q.deleteMessageStmt,
-		deleteSessionStmt:           q.deleteSessionStmt,
-		deleteSessionFilesStmt:      q.deleteSessionFilesStmt,
-		deleteSessionMessagesStmt:   q.deleteSessionMessagesStmt,
-		getFileStmt:                 q.getFileStmt,
-		getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
-		getMessageStmt:              q.getMessageStmt,
-		getSessionByIDStmt:          q.getSessionByIDStmt,
-		listFilesByPathStmt:         q.listFilesByPathStmt,
-		listFilesBySessionStmt:      q.listFilesBySessionStmt,
-		listLatestSessionFilesStmt:  q.listLatestSessionFilesStmt,
-		listMessagesBySessionStmt:   q.listMessagesBySessionStmt,
-		listNewFilesStmt:            q.listNewFilesStmt,
-		listSessionsStmt:            q.listSessionsStmt,
-		updateMessageStmt:           q.updateMessageStmt,
-		updateSessionStmt:           q.updateSessionStmt,
+		db:                             tx,
+		tx:                             tx,
+		createFileStmt:                 q.createFileStmt,
+		createMessageStmt:              q.createMessageStmt,
+		createSessionStmt:              q.createSessionStmt,
+		deleteFileStmt:                 q.deleteFileStmt,
+		deleteMessageStmt:              q.deleteMessageStmt,
+		deleteSessionStmt:              q.deleteSessionStmt,
+		deleteSessionFilesStmt:         q.deleteSessionFilesStmt,
+		deleteSessionMessagesStmt:      q.deleteSessionMessagesStmt,
+		getFileStmt:                    q.getFileStmt,
+		getFileByPathAndSessionStmt:    q.getFileByPathAndSessionStmt,
+		getMessageStmt:                 q.getMessageStmt,
+		getSessionByIDStmt:             q.getSessionByIDStmt,
+		listFilesByPathStmt:            q.listFilesByPathStmt,
+		listFilesBySessionStmt:         q.listFilesBySessionStmt,
+		listLatestSessionFilesStmt:     q.listLatestSessionFilesStmt,
+		listMessagesBySessionStmt:      q.listMessagesBySessionStmt,
+		listNewFilesStmt:               q.listNewFilesStmt,
+		listSessionsStmt:               q.listSessionsStmt,
+		updateMessageStmt:              q.updateMessageStmt,
+		updateSessionStmt:              q.updateSessionStmt,
+		updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
 	}
 }

internal/db/querier.go 🔗

@@ -29,6 +29,7 @@ type Querier interface {
 	ListSessions(ctx context.Context) ([]Session, error)
 	UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
 	UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
+	UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
 }
 
 var _ Querier = (*Queries)(nil)

internal/db/sessions.sql.go 🔗

@@ -199,3 +199,32 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
 	)
 	return i, err
 }
+
+const updateSessionTitleAndUsage = `-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+    title = ?,
+    prompt_tokens = prompt_tokens + ?,
+    completion_tokens = completion_tokens + ?,
+    cost = cost + ?
+WHERE id = ?
+`
+
+type UpdateSessionTitleAndUsageParams struct {
+	Title            string  `json:"title"`
+	PromptTokens     int64   `json:"prompt_tokens"`
+	CompletionTokens int64   `json:"completion_tokens"`
+	Cost             float64 `json:"cost"`
+	ID               string  `json:"id"`
+}
+
+func (q *Queries) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error {
+	_, err := q.exec(ctx, q.updateSessionTitleAndUsageStmt, updateSessionTitleAndUsage,
+		arg.Title,
+		arg.PromptTokens,
+		arg.CompletionTokens,
+		arg.Cost,
+		arg.ID,
+	)
+	return err
+}

internal/db/sql/sessions.sql 🔗

@@ -46,6 +46,15 @@ SET
 WHERE id = ?
 RETURNING *;
 
+-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+    title = ?,
+    prompt_tokens = prompt_tokens + ?,
+    completion_tokens = completion_tokens + ?,
+    cost = cost + ?
+WHERE id = ?;
+
 
 -- name: DeleteSession :exec
 DELETE FROM sessions

internal/message/content.go 🔗

@@ -407,6 +407,15 @@ func (m *Message) SetToolResults(tr []ToolResult) {
 	}
 }
 
+// Clone returns a deep copy of the message with an independent Parts slice.
+// This prevents race conditions when the message is modified concurrently.
+func (m *Message) Clone() Message {
+	clone := *m
+	clone.Parts = make([]ContentPart, len(m.Parts))
+	copy(clone.Parts, m.Parts)
+	return clone
+}
+
 func (m *Message) AddFinish(reason FinishReason, message, details string) {
 	// remove any existing finish part
 	for i, part := range m.Parts {

internal/message/message.go 🔗

@@ -51,7 +51,9 @@ func (s *service) Delete(ctx context.Context, id string) error {
 	if err != nil {
 		return err
 	}
-	s.Publish(pubsub.DeletedEvent, message)
+	// Clone the message before publishing to avoid race conditions with
+	// concurrent modifications to the Parts slice.
+	s.Publish(pubsub.DeletedEvent, message.Clone())
 	return nil
 }
 
@@ -85,7 +87,9 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
 	if err != nil {
 		return Message{}, err
 	}
-	s.Publish(pubsub.CreatedEvent, message)
+	// Clone the message before publishing to avoid race conditions with
+	// concurrent modifications to the Parts slice.
+	s.Publish(pubsub.CreatedEvent, message.Clone())
 	return message, nil
 }
 
@@ -124,7 +128,9 @@ func (s *service) Update(ctx context.Context, message Message) error {
 		return err
 	}
 	message.UpdatedAt = time.Now().Unix()
-	s.Publish(pubsub.UpdatedEvent, message)
+	// Clone the message before publishing to avoid race conditions with
+	// concurrent modifications to the Parts slice.
+	s.Publish(pubsub.UpdatedEvent, message.Clone())
 	return nil
 }
 

internal/pubsub/broker.go 🔗

@@ -92,22 +92,17 @@ func (b *Broker[T]) GetSubscriberCount() int {
 
 func (b *Broker[T]) Publish(t EventType, payload T) {
 	b.mu.RLock()
+	defer b.mu.RUnlock()
+
 	select {
 	case <-b.done:
-		b.mu.RUnlock()
 		return
 	default:
 	}
 
-	subscribers := make([]chan Event[T], 0, len(b.subs))
-	for sub := range b.subs {
-		subscribers = append(subscribers, sub)
-	}
-	b.mu.RUnlock()
-
 	event := Event[T]{Type: t, Payload: payload}
 
-	for _, sub := range subscribers {
+	for sub := range b.subs {
 		select {
 		case sub <- event:
 		default:

internal/session/session.go 🔗

@@ -50,6 +50,7 @@ type Service interface {
 	Get(ctx context.Context, id string) (Session, error)
 	List(ctx context.Context) ([]Session, error)
 	Save(ctx context.Context, session Session) (Session, error)
+	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 	Delete(ctx context.Context, id string) error
 
 	// Agent tool session management
@@ -156,6 +157,18 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 	return session, nil
 }
 
+// UpdateTitleAndUsage updates only the title and usage fields atomically.
+// This is safer than fetching, modifying, and saving the entire session.
+func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
+	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
+		ID:               sessionID,
+		Title:            title,
+		PromptTokens:     promptTokens,
+		CompletionTokens: completionTokens,
+		Cost:             cost,
+	})
+}
+
 func (s *service) List(ctx context.Context) ([]Session, error) {
 	dbSessions, err := s.q.ListSessions(ctx)
 	if err != nil {