1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10 "sync"
11
12 "github.com/charmbracelet/crush/internal/db"
13 "github.com/charmbracelet/crush/internal/event"
14 "github.com/charmbracelet/crush/internal/pubsub"
15 "github.com/google/uuid"
16 "github.com/zeebo/xxh3"
17)
18
// TodoStatus is the state of a single Todo item. It is serialized as its
// underlying string value.
type TodoStatus string

// Known TodoStatus values. HasIncompleteTodos treats every status other than
// TodoStatusCompleted (including unknown strings) as incomplete.
const (
	// TodoStatusPending serializes as "pending".
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress serializes as "in_progress".
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted serializes as "completed".
	TodoStatusCompleted TodoStatus = "completed"
)
26
27// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
28func HashID(id string) string {
29 h := xxh3.New()
30 h.WriteString(id)
31 return fmt.Sprintf("%x", h.Sum(nil))
32}
33
// Todo is a single task-list entry attached to a Session. A session's todos
// are persisted as a JSON array (see marshalTodos / unmarshalTodos) using the
// JSON field names in the struct tags below.
type Todo struct {
	// Content is the todo's text.
	Content string `json:"content"`
	// Status is the todo's current TodoStatus.
	Status TodoStatus `json:"status"`
	// ActiveForm is an alternate wording of the todo — presumably the phrasing
	// shown while it is in progress; TODO confirm with the producer.
	ActiveForm string `json:"active_form"`
}
39
40// HasIncompleteTodos returns true if there are any non-completed todos.
41func HasIncompleteTodos(todos []Todo) bool {
42 for _, todo := range todos {
43 if todo.Status != TodoStatusCompleted {
44 return true
45 }
46 }
47 return false
48}
49
// Session is the service-level view of a stored session row, plus the
// EstimatedUsage flag that the service tracks only in memory.
type Session struct {
	// ID identifies the session. Create assigns a random UUID; task and title
	// sessions use IDs derived from their caller's inputs.
	ID string
	// ParentSessionID is empty for top-level sessions (stored as NULL) and
	// set for child sessions such as task and title sessions.
	ParentSessionID string
	Title           string
	// MessageCount is read from the database row; Save does not write it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// EstimatedUsage is not persisted: the service keeps it in memory and
	// re-applies it on Get, GetLast, List and Save.
	EstimatedUsage bool
	// SummaryMessageID is empty when no summary message is set (stored as
	// NULL).
	SummaryMessageID string
	Cost             float64
	// Todos is persisted as a JSON array; an empty list is stored as NULL.
	Todos []Todo
	// CreatedAt and UpdatedAt come straight from the database row —
	// presumably Unix timestamps; TODO confirm the unit.
	CreatedAt int64
	UpdatedAt int64
}
64
// Service manages sessions: creation, lookup, persistence, renaming and
// deletion. It also embeds a pubsub.Subscriber so callers can subscribe to
// Session events, and it defines the agent-tool session ID format.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID whose ID is
	// "title-" followed by the parent ID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID that uses
	// toolCallID as its own ID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// GetLast returns the session selected by the "last session" query
	// (presumably the most recent one — TODO confirm ordering).
	GetLast(ctx context.Context) (Session, error)
	// List returns all sessions returned by the list query.
	List(ctx context.Context) ([]Session, error)
	// Save persists the session's title, token usage, summary message, cost
	// and todos, and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage writes only the title and usage fields for sessionID.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Rename changes only the title of the session with the given ID.
	Rename(ctx context.Context, id string, title string) error
	// Delete removes the session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management
	// CreateAgentToolSessionID builds an ID of the form "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID splits an ID built by CreateAgentToolSessionID;
	// ok is false when sessionID does not contain exactly one "$$" separator.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether ParseAgentToolSessionID accepts sessionID.
	IsAgentToolSession(sessionID string) bool
}
83
// service is the Service implementation built by NewService. The embedded
// broker publishes Created/Updated/Deleted events for sessions.
type service struct {
	*pubsub.Broker[Session]
	// db is the raw connection; Delete uses it to open a transaction.
	db *sql.DB
	// q runs the session queries; Delete rebinds it to a transaction via WithTx.
	q *db.Queries

	// Estimated usage stays in memory so fetch-modify-save paths (e.g.,
	// updating todos or parent-session cost) do not rebuild a session from
	// SQLite and incorrectly clear the UI "~" marker.
	// estimatedUsageMu guards estimatedUsage. Only sessions whose flag is true
	// have an entry; setting the flag to false deletes the key.
	estimatedUsageMu sync.RWMutex
	estimatedUsage   map[string]bool
}
95
96func (s *service) Create(ctx context.Context, title string) (Session, error) {
97 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
98 ID: uuid.New().String(),
99 Title: title,
100 })
101 if err != nil {
102 return Session{}, err
103 }
104 session := s.fromDBItem(dbSession)
105 s.Publish(pubsub.CreatedEvent, session)
106 event.SessionCreated()
107 return session, nil
108}
109
110func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
111 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
112 ID: toolCallID,
113 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
114 Title: title,
115 })
116 if err != nil {
117 return Session{}, err
118 }
119 session := s.fromDBItem(dbSession)
120 s.Publish(pubsub.CreatedEvent, session)
121 return session, nil
122}
123
124func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
125 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
126 ID: "title-" + parentSessionID,
127 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
128 Title: "Generate a title",
129 })
130 if err != nil {
131 return Session{}, err
132 }
133 session := s.fromDBItem(dbSession)
134 s.Publish(pubsub.CreatedEvent, session)
135 return session, nil
136}
137
138func (s *service) Delete(ctx context.Context, id string) error {
139 tx, err := s.db.BeginTx(ctx, nil)
140 if err != nil {
141 return fmt.Errorf("beginning transaction: %w", err)
142 }
143 defer tx.Rollback() //nolint:errcheck
144
145 qtx := s.q.WithTx(tx)
146
147 dbSession, err := qtx.GetSessionByID(ctx, id)
148 if err != nil {
149 return err
150 }
151 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
152 return fmt.Errorf("deleting session messages: %w", err)
153 }
154 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
155 return fmt.Errorf("deleting session files: %w", err)
156 }
157 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
158 return fmt.Errorf("deleting session: %w", err)
159 }
160 if err = tx.Commit(); err != nil {
161 return fmt.Errorf("committing transaction: %w", err)
162 }
163
164 session := s.fromDBItem(dbSession)
165 s.clearEstimatedUsageState(dbSession.ID)
166 s.Publish(pubsub.DeletedEvent, session)
167 event.SessionDeleted()
168 return nil
169}
170
171func (s *service) Get(ctx context.Context, id string) (Session, error) {
172 dbSession, err := s.q.GetSessionByID(ctx, id)
173 if err != nil {
174 return Session{}, err
175 }
176 session := s.fromDBItem(dbSession)
177 s.applyEstimatedUsageState(&session)
178 return session, nil
179}
180
181func (s *service) GetLast(ctx context.Context) (Session, error) {
182 dbSession, err := s.q.GetLastSession(ctx)
183 if err != nil {
184 return Session{}, err
185 }
186 session := s.fromDBItem(dbSession)
187 s.applyEstimatedUsageState(&session)
188 return session, nil
189}
190
191func (s *service) Save(ctx context.Context, session Session) (Session, error) {
192 todosJSON, err := marshalTodos(session.Todos)
193 if err != nil {
194 return Session{}, err
195 }
196
197 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
198 ID: session.ID,
199 Title: session.Title,
200 PromptTokens: session.PromptTokens,
201 CompletionTokens: session.CompletionTokens,
202 SummaryMessageID: sql.NullString{
203 String: session.SummaryMessageID,
204 Valid: session.SummaryMessageID != "",
205 },
206 Cost: session.Cost,
207 Todos: sql.NullString{
208 String: todosJSON,
209 Valid: todosJSON != "",
210 },
211 })
212 if err != nil {
213 return Session{}, err
214 }
215 estimatedUsage := session.EstimatedUsage
216 s.setEstimatedUsageState(session.ID, estimatedUsage)
217 session = s.fromDBItem(dbSession)
218 session.EstimatedUsage = estimatedUsage
219 s.Publish(pubsub.UpdatedEvent, session)
220 return session, nil
221}
222
223// UpdateTitleAndUsage updates only the title and usage fields atomically.
224// This is safer than fetching, modifying, and saving the entire session.
225func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
226 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
227 ID: sessionID,
228 Title: title,
229 PromptTokens: promptTokens,
230 CompletionTokens: completionTokens,
231 Cost: cost,
232 })
233}
234
235// Rename updates only the title of a session without touching updated_at or
236// usage fields.
237func (s *service) Rename(ctx context.Context, id string, title string) error {
238 return s.q.RenameSession(ctx, db.RenameSessionParams{
239 ID: id,
240 Title: title,
241 })
242}
243
244func (s *service) List(ctx context.Context) ([]Session, error) {
245 dbSessions, err := s.q.ListSessions(ctx)
246 if err != nil {
247 return nil, err
248 }
249 sessions := make([]Session, len(dbSessions))
250 for i, dbSession := range dbSessions {
251 sessions[i] = s.fromDBItem(dbSession)
252 s.applyEstimatedUsageState(&sessions[i])
253 }
254 return sessions, nil
255}
256
257func (s *service) applyEstimatedUsageState(session *Session) {
258 s.estimatedUsageMu.RLock()
259 session.EstimatedUsage = s.estimatedUsage[session.ID]
260 s.estimatedUsageMu.RUnlock()
261}
262
263func (s *service) setEstimatedUsageState(sessionID string, estimatedUsage bool) {
264 s.estimatedUsageMu.Lock()
265 defer s.estimatedUsageMu.Unlock()
266 if estimatedUsage {
267 s.estimatedUsage[sessionID] = true
268 return
269 }
270 delete(s.estimatedUsage, sessionID)
271}
272
273func (s *service) clearEstimatedUsageState(sessionID string) {
274 s.estimatedUsageMu.Lock()
275 delete(s.estimatedUsage, sessionID)
276 s.estimatedUsageMu.Unlock()
277}
278
279func (s *service) fromDBItem(item db.Session) Session {
280 todos, err := unmarshalTodos(item.Todos.String)
281 if err != nil {
282 slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
283 }
284 return Session{
285 ID: item.ID,
286 ParentSessionID: item.ParentSessionID.String,
287 Title: item.Title,
288 MessageCount: item.MessageCount,
289 PromptTokens: item.PromptTokens,
290 CompletionTokens: item.CompletionTokens,
291 SummaryMessageID: item.SummaryMessageID.String,
292 Cost: item.Cost,
293 Todos: todos,
294 CreatedAt: item.CreatedAt,
295 UpdatedAt: item.UpdatedAt,
296 }
297}
298
299func marshalTodos(todos []Todo) (string, error) {
300 if len(todos) == 0 {
301 return "", nil
302 }
303 data, err := json.Marshal(todos)
304 if err != nil {
305 return "", err
306 }
307 return string(data), nil
308}
309
310func unmarshalTodos(data string) ([]Todo, error) {
311 if data == "" {
312 return []Todo{}, nil
313 }
314 var todos []Todo
315 if err := json.Unmarshal([]byte(data), &todos); err != nil {
316 return []Todo{}, err
317 }
318 return todos, nil
319}
320
321func NewService(q *db.Queries, conn *sql.DB) Service {
322 broker := pubsub.NewBroker[Session]()
323 return &service{
324 Broker: broker,
325 db: conn,
326 q: q,
327 estimatedUsage: make(map[string]bool),
328 }
329}
330
331// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
332func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
333 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
334}
335
336// ParseAgentToolSessionID parses an agent tool session ID into its components
337func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
338 parts := strings.Split(sessionID, "$$")
339 if len(parts) != 2 {
340 return "", "", false
341 }
342 return parts[0], parts[1], true
343}
344
345// IsAgentToolSession checks if a session ID follows the agent tool session format
346func (s *service) IsAgentToolSession(sessionID string) bool {
347 _, _, ok := s.ParseAgentToolSessionID(sessionID)
348 return ok
349}