diff --git a/internal/agent/agent.go b/internal/agent/agent.go index 62025b1943af245e94da6da744036e8040029c65..30539e5dba5a5b7aa66c2d2ffc35336dbe3d9774 100644 --- a/internal/agent/agent.go +++ b/internal/agent/agent.go @@ -171,10 +171,9 @@ func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy var wg sync.WaitGroup // Generate title if first message. if len(msgs) == 0 { + titleCtx := ctx // Copy to avoid race with ctx reassignment below. wg.Go(func() { - sessionLock.Lock() - a.generateTitle(ctx, ¤tSession, call.Prompt) - sessionLock.Unlock() + a.generateTitle(titleCtx, call.SessionID, call.Prompt) }) } @@ -723,7 +722,7 @@ func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.S return msgs, nil } -func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) { +func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, prompt string) { if prompt == "" { return } @@ -768,8 +767,7 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi return } - session.Title = title - + // Calculate usage and cost. var openrouterCost *float64 for _, step := range resp.Steps { stepCost := a.openrouterCost(step.ProviderMetadata) @@ -782,8 +780,27 @@ func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Sessi } } - a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost) - _, saveErr := a.sessions.Save(ctx, *session) + modelConfig := a.smallModel.CatwalkCfg + cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) + + modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) + + modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) + + modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens) + + if a.isClaudeCode() { + cost = 0 + } + + // Use override cost if available (e.g., from OpenRouter). + if openrouterCost != nil { + cost = *openrouterCost + } + + promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens + completionTokens := resp.TotalUsage.OutputTokens + resp.TotalUsage.CacheReadTokens + + // Atomically update only title and usage fields to avoid overriding other + // concurrent session updates. + saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost) if saveErr != nil { slog.Error("failed to save session title & usage", "error", saveErr) return diff --git a/internal/db/db.go b/internal/db/db.go index 6f57f2c2c6c7c2854e93fa6246cad6dbfcfa569c..7fa2e6528743dcb5485c0de9b4a3f2b46eb39376 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -84,6 +84,9 @@ func Prepare(ctx context.Context, db DBTX) (*Queries, error) { if q.updateSessionStmt, err = db.PrepareContext(ctx, updateSession); err != nil { return nil, fmt.Errorf("error preparing query UpdateSession: %w", err) } + if q.updateSessionTitleAndUsageStmt, err = db.PrepareContext(ctx, updateSessionTitleAndUsage); err != nil { + return nil, fmt.Errorf("error preparing query UpdateSessionTitleAndUsage: %w", err) + } return &q, nil } @@ -189,6 +192,11 @@ func (q *Queries) Close() error { err = fmt.Errorf("error closing updateSessionStmt: %w", cerr) } } + if q.updateSessionTitleAndUsageStmt != nil { + if cerr := q.updateSessionTitleAndUsageStmt.Close(); cerr != nil { + err = fmt.Errorf("error closing updateSessionTitleAndUsageStmt: %w", cerr) + } + } return err } @@ -226,53 +234,55 @@ func (q *Queries) queryRow(ctx context.Context, stmt *sql.Stmt, query string, ar } type Queries struct { - db DBTX - tx *sql.Tx - createFileStmt *sql.Stmt - createMessageStmt *sql.Stmt - createSessionStmt *sql.Stmt - deleteFileStmt *sql.Stmt - deleteMessageStmt *sql.Stmt - deleteSessionStmt *sql.Stmt - deleteSessionFilesStmt *sql.Stmt - deleteSessionMessagesStmt *sql.Stmt - getFileStmt *sql.Stmt - getFileByPathAndSessionStmt *sql.Stmt - getMessageStmt *sql.Stmt - getSessionByIDStmt *sql.Stmt - listFilesByPathStmt *sql.Stmt - listFilesBySessionStmt *sql.Stmt - listLatestSessionFilesStmt *sql.Stmt - listMessagesBySessionStmt *sql.Stmt - listNewFilesStmt *sql.Stmt - listSessionsStmt *sql.Stmt - updateMessageStmt *sql.Stmt - updateSessionStmt *sql.Stmt + db DBTX + tx *sql.Tx + createFileStmt *sql.Stmt + createMessageStmt *sql.Stmt + createSessionStmt *sql.Stmt + deleteFileStmt *sql.Stmt + deleteMessageStmt *sql.Stmt + deleteSessionStmt *sql.Stmt + deleteSessionFilesStmt *sql.Stmt + deleteSessionMessagesStmt *sql.Stmt + getFileStmt *sql.Stmt + getFileByPathAndSessionStmt *sql.Stmt + getMessageStmt *sql.Stmt + getSessionByIDStmt *sql.Stmt + listFilesByPathStmt *sql.Stmt + listFilesBySessionStmt *sql.Stmt + listLatestSessionFilesStmt *sql.Stmt + listMessagesBySessionStmt *sql.Stmt + listNewFilesStmt *sql.Stmt + listSessionsStmt *sql.Stmt + updateMessageStmt *sql.Stmt + updateSessionStmt *sql.Stmt + updateSessionTitleAndUsageStmt *sql.Stmt } func (q *Queries) WithTx(tx *sql.Tx) *Queries { return &Queries{ - db: tx, - tx: tx, - createFileStmt: q.createFileStmt, - createMessageStmt: q.createMessageStmt, - createSessionStmt: q.createSessionStmt, - deleteFileStmt: q.deleteFileStmt, - deleteMessageStmt: q.deleteMessageStmt, - deleteSessionStmt: q.deleteSessionStmt, - deleteSessionFilesStmt: q.deleteSessionFilesStmt, - deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, - getFileStmt: q.getFileStmt, - getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, - getMessageStmt: q.getMessageStmt, - getSessionByIDStmt: q.getSessionByIDStmt, - listFilesByPathStmt: q.listFilesByPathStmt, - listFilesBySessionStmt: q.listFilesBySessionStmt, - listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, - listMessagesBySessionStmt: q.listMessagesBySessionStmt, - listNewFilesStmt: q.listNewFilesStmt, - listSessionsStmt: q.listSessionsStmt, - updateMessageStmt: q.updateMessageStmt, - updateSessionStmt: q.updateSessionStmt, + db: tx, + tx: tx, + createFileStmt: q.createFileStmt, + createMessageStmt: q.createMessageStmt, + createSessionStmt: q.createSessionStmt, + deleteFileStmt: q.deleteFileStmt, + deleteMessageStmt: q.deleteMessageStmt, + deleteSessionStmt: q.deleteSessionStmt, + deleteSessionFilesStmt: q.deleteSessionFilesStmt, + deleteSessionMessagesStmt: q.deleteSessionMessagesStmt, + getFileStmt: q.getFileStmt, + getFileByPathAndSessionStmt: q.getFileByPathAndSessionStmt, + getMessageStmt: q.getMessageStmt, + getSessionByIDStmt: q.getSessionByIDStmt, + listFilesByPathStmt: q.listFilesByPathStmt, + listFilesBySessionStmt: q.listFilesBySessionStmt, + listLatestSessionFilesStmt: q.listLatestSessionFilesStmt, + listMessagesBySessionStmt: q.listMessagesBySessionStmt, + listNewFilesStmt: q.listNewFilesStmt, + listSessionsStmt: q.listSessionsStmt, + updateMessageStmt: q.updateMessageStmt, + updateSessionStmt: q.updateSessionStmt, + updateSessionTitleAndUsageStmt: q.updateSessionTitleAndUsageStmt, } } diff --git a/internal/db/querier.go b/internal/db/querier.go index 0978eb2c6e4c7b1aa80888530bb5169a1d2bcec3..dfa6d722535b4265f3f54331d1904523a648f562 100644 --- a/internal/db/querier.go +++ b/internal/db/querier.go @@ -29,6 +29,7 @@ type Querier interface { ListSessions(ctx context.Context) ([]Session, error) UpdateMessage(ctx context.Context, arg UpdateMessageParams) error UpdateSession(ctx context.Context, arg UpdateSessionParams) (Session, error) + UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error } var _ Querier = (*Queries)(nil) diff --git a/internal/db/sessions.sql.go b/internal/db/sessions.sql.go index 012e70b40e825fce3b5941420f992715c2bfd6c7..3b1ecbfecb3c5d947e84b1ec07f7a3f72b8d6139 100644 --- a/internal/db/sessions.sql.go +++ b/internal/db/sessions.sql.go @@ -199,3 +199,32 @@ func (q *Queries) UpdateSession(ctx context.Context, arg UpdateSessionParams) (S ) return i, err } + +const updateSessionTitleAndUsage = `-- name: UpdateSessionTitleAndUsage :exec +UPDATE sessions +SET + title = ?, + prompt_tokens = prompt_tokens + ?, + completion_tokens = completion_tokens + ?, + cost = cost + ? +WHERE id = ? +` + +type UpdateSessionTitleAndUsageParams struct { + Title string `json:"title"` + PromptTokens int64 `json:"prompt_tokens"` + CompletionTokens int64 `json:"completion_tokens"` + Cost float64 `json:"cost"` + ID string `json:"id"` +} + +func (q *Queries) UpdateSessionTitleAndUsage(ctx context.Context, arg UpdateSessionTitleAndUsageParams) error { + _, err := q.exec(ctx, q.updateSessionTitleAndUsageStmt, updateSessionTitleAndUsage, + arg.Title, + arg.PromptTokens, + arg.CompletionTokens, + arg.Cost, + arg.ID, + ) + return err +} diff --git a/internal/db/sql/sessions.sql b/internal/db/sql/sessions.sql index 9f67305981445d187a454f5b4ce5084c82f00a8a..54bc072a0dcd7462d805f30cf832714e1f7d7705 100644 --- a/internal/db/sql/sessions.sql +++ b/internal/db/sql/sessions.sql @@ -46,6 +46,15 @@ SET WHERE id = ? RETURNING *; +-- name: UpdateSessionTitleAndUsage :exec +UPDATE sessions +SET + title = ?, + prompt_tokens = prompt_tokens + ?, + completion_tokens = completion_tokens + ?, + cost = cost + ? +WHERE id = ?; + -- name: DeleteSession :exec DELETE FROM sessions diff --git a/internal/message/content.go b/internal/message/content.go index 358ad120d8f87109ea8888984ad236b155388788..7333f738c0aa685833c57cc97086e61928d3f51e 100644 --- a/internal/message/content.go +++ b/internal/message/content.go @@ -407,6 +407,15 @@ func (m *Message) SetToolResults(tr []ToolResult) { } } +// Clone returns a deep copy of the message with an independent Parts slice. +// This prevents race conditions when the message is modified concurrently. +func (m *Message) Clone() Message { + clone := *m + clone.Parts = make([]ContentPart, len(m.Parts)) + copy(clone.Parts, m.Parts) + return clone +} + func (m *Message) AddFinish(reason FinishReason, message, details string) { // remove any existing finish part for i, part := range m.Parts { diff --git a/internal/message/message.go b/internal/message/message.go index 97db7fe34d5a5169b1f6ef5686a3d27760b44909..a09d0acbf590e840541a7d5e057fb89513cc0618 100644 --- a/internal/message/message.go +++ b/internal/message/message.go @@ -51,7 +51,9 @@ func (s *service) Delete(ctx context.Context, id string) error { if err != nil { return err } - s.Publish(pubsub.DeletedEvent, message) + // Clone the message before publishing to avoid race conditions with + // concurrent modifications to the Parts slice. + s.Publish(pubsub.DeletedEvent, message.Clone()) return nil } @@ -85,7 +87,9 @@ func (s *service) Create(ctx context.Context, sessionID string, params CreateMes if err != nil { return Message{}, err } - s.Publish(pubsub.CreatedEvent, message) + // Clone the message before publishing to avoid race conditions with + // concurrent modifications to the Parts slice. + s.Publish(pubsub.CreatedEvent, message.Clone()) return message, nil } @@ -124,7 +128,9 @@ func (s *service) Update(ctx context.Context, message Message) error { return err } message.UpdatedAt = time.Now().Unix() - s.Publish(pubsub.UpdatedEvent, message) + // Clone the message before publishing to avoid race conditions with + // concurrent modifications to the Parts slice. + s.Publish(pubsub.UpdatedEvent, message.Clone()) return nil } diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 80948d3d515a4fb5dad0d4dc36adbbff4e502993..ed14cbfed6c8fd44355501e16457e0dd92a494bc 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -92,22 +92,17 @@ func (b *Broker[T]) GetSubscriberCount() int { func (b *Broker[T]) Publish(t EventType, payload T) { b.mu.RLock() + defer b.mu.RUnlock() + select { case <-b.done: - b.mu.RUnlock() return default: } - subscribers := make([]chan Event[T], 0, len(b.subs)) - for sub := range b.subs { - subscribers = append(subscribers, sub) - } - b.mu.RUnlock() - event := Event[T]{Type: t, Payload: payload} - for _, sub := range subscribers { + for sub := range b.subs { select { case sub <- event: default: diff --git a/internal/session/session.go b/internal/session/session.go index 6d5e9a437c3c5a996973823e8903f47ea1cee514..3792cc1d576cdd7ebd0dbf0b64670c746718da9c 100644 --- a/internal/session/session.go +++ b/internal/session/session.go @@ -50,6 +50,7 @@ type Service interface { Get(ctx context.Context, id string) (Session, error) List(ctx context.Context) ([]Session, error) Save(ctx context.Context, session Session) (Session, error) + UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error Delete(ctx context.Context, id string) error // Agent tool session management @@ -156,6 +157,18 @@ func (s *service) Save(ctx context.Context, session Session) (Session, error) { return session, nil } +// UpdateTitleAndUsage updates only the title and usage fields atomically. +// This is safer than fetching, modifying, and saving the entire session. +func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error { + return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{ + ID: sessionID, + Title: title, + PromptTokens: promptTokens, + CompletionTokens: completionTokens, + Cost: cost, + }) +} + func (s *service) List(ctx context.Context) ([]Session, error) { dbSessions, err := s.q.ListSessions(ctx) if err != nil {