Detailed changes
@@ -171,10 +171,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy
var wg sync.WaitGroup
// Generate title if first message.
if len(msgs) == 0 {
+ titleCtx := ctx // Copy to avoid race with ctx reassignment below.
wg.Go(func() {
- sessionLock.Lock()
- a.generateTitle(ctx, ¤tSession, call.Prompt)
- sessionLock.Unlock()
+ a.generateTitle(titleCtx, call.SessionID, call.Prompt)
})
}
@@ -723,7 +722,7 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S
return msgs, nil
}
-func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
+func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) {
if prompt == "" {
return
}
@@ -768,8 +767,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
return
}
- session.Title = title
-
+ // Calculate usage and cost.
var openrouterCost *float64
for _, step := range resp.Steps {
stepCost := a.openrouterCost(step.ProviderMetadata)
@@ -782,8 +780,27 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi
}
}
- a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
- _, saveErr := a.sessions.Save(ctx, *session)
+ modelConfig := a.smallModel.CatwalkCfg
+ cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
+ modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
+ modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
+ modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
+
+ if a.isClaudeCode() {
+ cost = 0
+ }
+
+ // Use override cost if available (e.g., from OpenRouter).
+ if openrouterCost != nil {
+ cost = *openrouterCost
+ }
+
+ promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
+ completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens
+
+ // Atomically update only title and usage fields to avoid overriding other
+ // concurrent session updates.
+ saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
if saveErr != nil {
slog.Error("failed to save session title & usage", "error", saveErr)
return
@@ -84,6 +84,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) {
if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil {
return nil, fmt.Errorf("error preparing query UpdateSession: %w", err)
}
+ if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil {
+ return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err)
+ }
return &q, nil
}
@@ -189,6 +192,11 @@ func (q *Queries) Close() error {
err = fmt.Errorf("error closing updateSessionStmt: %w", cerr)
}
}
+ if q.updateSessionTitleAndUsageStmt != nil {
+ if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil {
+ err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr)
+ }
+ }
return err
}
@@ -226,53 +234,55 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar
}
type Queries struct {
- db DBTX
- tx *sql.Tx
- createFileStmt *sql.Stmt
- createMessageStmt *sql.Stmt
- createSessionStmt *sql.Stmt
- deleteFileStmt *sql.Stmt
- deleteMessageStmt *sql.Stmt
- deleteSessionStmt *sql.Stmt
- deleteSessionFilesStmt *sql.Stmt
- deleteSessionMessagesStmt *sql.Stmt
- getFileStmt *sql.Stmt
- getFileByPathAndSessionStmt *sql.Stmt
- getMessageStmt *sql.Stmt
- getSessionByIDStmt *sql.Stmt
- listFilesByPathStmt *sql.Stmt
- listFilesBySessionStmt *sql.Stmt
- listLatestSessionFilesStmt *sql.Stmt
- listMessagesBySessionStmt *sql.Stmt
- listNewFilesStmt *sql.Stmt
- listSessionsStmt *sql.Stmt
- updateMessageStmt *sql.Stmt
- updateSessionStmt *sql.Stmt
+ db DBTX
+ tx *sql.Tx
+ createFileStmt *sql.Stmt
+ createMessageStmt *sql.Stmt
+ createSessionStmt *sql.Stmt
+ deleteFileStmt *sql.Stmt
+ deleteMessageStmt *sql.Stmt
+ deleteSessionStmt *sql.Stmt
+ deleteSessionFilesStmt *sql.Stmt
+ deleteSessionMessagesStmt *sql.Stmt
+ getFileStmt *sql.Stmt
+ getFileByPathAndSessionStmt *sql.Stmt
+ getMessageStmt *sql.Stmt
+ getSessionByIDStmt *sql.Stmt
+ listFilesByPathStmt *sql.Stmt
+ listFilesBySessionStmt *sql.Stmt
+ listLatestSessionFilesStmt *sql.Stmt
+ listMessagesBySessionStmt *sql.Stmt
+ listNewFilesStmt *sql.Stmt
+ listSessionsStmt *sql.Stmt
+ updateMessageStmt *sql.Stmt
+ updateSessionStmt *sql.Stmt
+ updateSessionTitleAndUsageStmt *sql.Stmt
}
func (q *Queries) WithTx(tx *sql.Tx) *Queries {
return &Queries{
- db: tx,
- tx: tx,
- createFileStmt: q.createFileStmt,
- createMessageStmt: q.createMessageStmt,
- createSessionStmt: q.createSessionStmt,
- deleteFileStmt: q.deleteFileStmt,
- deleteMessageStmt: q.deleteMessageStmt,
- deleteSessionStmt: q.deleteSessionStmt,
- deleteSessionFilesStmt: q.deleteSessionFilesStmt,
- deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
- getFileStmt: q.getFileStmt,
- getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
- getMessageStmt: q.getMessageStmt,
- getSessionByIDStmt: q.getSessionByIDStmt,
- listFilesByPathStmt: q.listFilesByPathStmt,
- listFilesBySessionStmt: q.listFilesBySessionStmt,
- listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
- listMessagesBySessionStmt: q.listMessagesBySessionStmt,
- listNewFilesStmt: q.listNewFilesStmt,
- listSessionsStmt: q.listSessionsStmt,
- updateMessageStmt: q.updateMessageStmt,
- updateSessionStmt: q.updateSessionStmt,
+ db: tx,
+ tx: tx,
+ createFileStmt: q.createFileStmt,
+ createMessageStmt: q.createMessageStmt,
+ createSessionStmt: q.createSessionStmt,
+ deleteFileStmt: q.deleteFileStmt,
+ deleteMessageStmt: q.deleteMessageStmt,
+ deleteSessionStmt: q.deleteSessionStmt,
+ deleteSessionFilesStmt: q.deleteSessionFilesStmt,
+ deleteSessionMessagesStmt: q.deleteSessionMessagesStmt,
+ getFileStmt: q.getFileStmt,
+ getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt,
+ getMessageStmt: q.getMessageStmt,
+ getSessionByIDStmt: q.getSessionByIDStmt,
+ listFilesByPathStmt: q.listFilesByPathStmt,
+ listFilesBySessionStmt: q.listFilesBySessionStmt,
+ listLatestSessionFilesStmt: q.listLatestSessionFilesStmt,
+ listMessagesBySessionStmt: q.listMessagesBySessionStmt,
+ listNewFilesStmt: q.listNewFilesStmt,
+ listSessionsStmt: q.listSessionsStmt,
+ updateMessageStmt: q.updateMessageStmt,
+ updateSessionStmt: q.updateSessionStmt,
+ updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt,
}
}
@@ -29,6 +29,7 @@ type Querier interface {
ListSessions(ctx context.Context) ([]Session, error)
UpdateMessage(ctx context.Context, arg UpdateMessageParams) error
UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error)
+ UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error
}
var _ Querier = (*Queries)(nil)
@@ -199,3 +199,32 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S
)
return i, err
}
+
+const updateSessionTitleAndUsage = `-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+ title = ?,
+ prompt_tokens = prompt_tokens + ?,
+ completion_tokens = completion_tokens + ?,
+ cost = cost + ?
+WHERE id = ?
+`
+
+type UpdateSessionTitleAndUsageParams struct {
+ Title string `json:"title"`
+ PromptTokens int64 `json:"prompt_tokens"`
+ CompletionTokens int64 `json:"completion_tokens"`
+ Cost float64 `json:"cost"`
+ ID string `json:"id"`
+}
+
+func (q *Queries) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error {
+ _, err := q.exec(ctx, q.updateSessionTitleAndUsageStmt, updateSessionTitleAndUsage,
+ arg.Title,
+ arg.PromptTokens,
+ arg.CompletionTokens,
+ arg.Cost,
+ arg.ID,
+ )
+ return err
+}
@@ -46,6 +46,15 @@ SET
WHERE id = ?
RETURNING *;
+-- name: UpdateSessionTitleAndUsage :exec
+UPDATE sessions
+SET
+ title = ?,
+ prompt_tokens = prompt_tokens + ?,
+ completion_tokens = completion_tokens + ?,
+ cost = cost + ?
+WHERE id = ?;
+
-- name: DeleteSession :exec
DELETE FROM sessions
@@ -407,6 +407,15 @@ func (m *Message) SetToolResults(tr []ToolResult) {
}
}
+// Clone returns a deep copy of the message with an independent Parts slice.
+// This prevents race conditions when the message is modified concurrently.
+func (m *Message) Clone() Message {
+ clone := *m
+ clone.Parts = make([]ContentPart, len(m.Parts))
+ copy(clone.Parts, m.Parts)
+ return clone
+}
+
func (m *Message) AddFinish(reason FinishReason, message, details string) {
// remove any existing finish part
for i, part := range m.Parts {
@@ -51,7 +51,9 @@ func (s *service) Delete(ctx context.Context, id string) error {
if err != nil {
return err
}
- s.Publish(pubsub.DeletedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.DeletedEvent, message.Clone())
return nil
}
@@ -85,7 +87,9 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes
if err != nil {
return Message{}, err
}
- s.Publish(pubsub.CreatedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.CreatedEvent, message.Clone())
return message, nil
}
@@ -124,7 +128,9 @@ func (s *service) Update(ctx context.Context, message Message) error {
return err
}
message.UpdatedAt = time.Now().Unix()
- s.Publish(pubsub.UpdatedEvent, message)
+ // Clone the message before publishing to avoid race conditions with
+ // concurrent modifications to the Parts slice.
+ s.Publish(pubsub.UpdatedEvent, message.Clone())
return nil
}
@@ -92,22 +92,17 @@ func (b *Broker[T]) GetSubscriberCount() int {
func (b *Broker[T]) Publish(t EventType, payload T) {
b.mu.RLock()
+ defer b.mu.RUnlock()
+
select {
case <-b.done:
- b.mu.RUnlock()
return
default:
}
- subscribers := make([]chan Event[T], 0, len(b.subs))
- for sub := range b.subs {
- subscribers = append(subscribers, sub)
- }
- b.mu.RUnlock()
-
event := Event[T]{Type: t, Payload: payload}
- for _, sub := range subscribers {
+ for sub := range b.subs {
select {
case sub <- event:
default:
@@ -50,6 +50,7 @@ type Service interface {
Get(ctx context.Context, id string) (Session, error)
List(ctx context.Context) ([]Session, error)
Save(ctx context.Context, session Session) (Session, error)
+ UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
Delete(ctx context.Context, id string) error
// Agent tool session management
@@ -156,6 +157,18 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) {
return session, nil
}
+// UpdateTitleAndUsage updates only the title and usage fields atomically.
+// This is safer than fetching, modifying, and saving the entire session.
+func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
+ return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
+ ID: sessionID,
+ Title: title,
+ PromptTokens: promptTokens,
+ CompletionTokens: completionTokens,
+ Cost: cost,
+ })
+}
+
func (s *service) List(ctx context.Context) ([]Session, error) {
dbSessions, err := s.q.ListSessions(ctx)
if err != nil {