1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/config"
12 "github.com/charmbracelet/crush/internal/db"
13 "github.com/charmbracelet/crush/internal/event"
14 "github.com/charmbracelet/crush/internal/pubsub"
15 "github.com/google/uuid"
16 "github.com/zeebo/xxh3"
17)
18
// TodoStatus is the lifecycle state of a single Todo. It is serialized as a
// plain string (see the Todo.Status JSON tag).
type TodoStatus string

const (
	// TodoStatusPending marks a todo that has not been started yet.
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress marks a todo that is currently being worked on.
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted marks a finished todo. It is the only status that
	// HasIncompleteTodos treats as done.
	TodoStatusCompleted TodoStatus = "completed"
)
26
27// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
28func HashID(id string) string {
29 h := xxh3.New()
30 h.WriteString(id)
31 return fmt.Sprintf("%x", h.Sum(nil))
32}
33
// Todo is a single task-list entry attached to a Session. A session's todos
// are stored as a JSON array in its todos column (see marshalTodos and
// unmarshalTodos), using the JSON keys declared in the tags below.
type Todo struct {
	// Content is the todo text; it is not interpreted in this file.
	Content string `json:"content"`
	// Status is the todo's lifecycle state.
	Status TodoStatus `json:"status"`
	// ActiveForm is stored alongside Content; presumably an alternate
	// phrasing shown while the todo is in progress — TODO confirm with UI.
	ActiveForm string `json:"active_form"`
}
39
40// HasIncompleteTodos returns true if there are any non-completed todos.
41func HasIncompleteTodos(todos []Todo) bool {
42 for _, todo := range todos {
43 if todo.Status != TodoStatusCompleted {
44 return true
45 }
46 }
47 return false
48}
49
// Session is the in-memory form of a session row (db.Session), as built by
// fromDBItem.
type Session struct {
	ID string
	// ParentSessionID is empty for sessions created via Create; task and
	// title sessions set it to the session that spawned them.
	ParentSessionID string
	Title string
	// MessageCount is read from the database row; Save does not write it.
	MessageCount int64
	PromptTokens int64
	CompletionTokens int64
	// SummaryMessageID is empty when unset; Save stores "" as SQL NULL.
	SummaryMessageID string
	Cost float64
	// Todos is decoded from the JSON todos column; Save re-encodes it and
	// stores an empty list as SQL NULL.
	Todos []Todo
	// Models is the per-session model selection. Save does not persist it;
	// use SaveWithModels or UpdateSessionModels.
	Models map[config.SelectedModelType]config.SelectedModel
	// CreatedAt and UpdatedAt are copied verbatim from the database row;
	// presumably Unix timestamps — TODO confirm the unit.
	CreatedAt int64
	UpdatedAt int64
}
64
// Service manages sessions: creating them, reading them back, persisting
// changes and deleting them. It embeds pubsub.Subscriber so callers can
// subscribe to Session events. The SQL-backed implementation in this file
// publishes Created, Updated and Deleted events from its create, Save and
// Delete paths.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID that is
	// used to generate that session's title.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID
	// is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// GetLast returns the session selected by the underlying "last session"
	// query.
	GetLast(ctx context.Context) (Session, error)
	// List returns the sessions produced by the underlying list query.
	List(ctx context.Context) ([]Session, error)
	// Save persists the session's title, usage, summary message and todos
	// (but not Models) and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// SaveWithModels saves the session and also persists models; the two
	// writes are not atomic.
	SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error)
	// UpdateSessionModels persists only the model selection of session id.
	UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error
	// UpdateTitleAndUsage updates only the title and usage fields in a
	// single query, avoiding a read-modify-write of the whole session.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Rename changes only a session's title.
	Rename(ctx context.Context, id string, title string) error
	// Delete removes a session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management. Agent tool session IDs have the form
	// "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	IsAgentToolSession(sessionID string) bool
}
85
// service is the SQL-backed implementation of Service returned by
// NewService.
type service struct {
	// Broker supplies Publish (used by the mutating methods) and the
	// subscription side required by Service.
	*pubsub.Broker[Session]
	// db opens transactions for multi-statement operations (see Delete).
	db *sql.DB
	// q runs the generated queries outside of a transaction.
	q *db.Queries
}
91
92func (s *service) Create(ctx context.Context, title string) (Session, error) {
93 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
94 ID: uuid.New().String(),
95 Title: title,
96 })
97 if err != nil {
98 return Session{}, err
99 }
100 session := s.fromDBItem(dbSession)
101 s.Publish(pubsub.CreatedEvent, session)
102 event.SessionCreated()
103 return session, nil
104}
105
106func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
107 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
108 ID: toolCallID,
109 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
110 Title: title,
111 })
112 if err != nil {
113 return Session{}, err
114 }
115 session := s.fromDBItem(dbSession)
116 s.Publish(pubsub.CreatedEvent, session)
117 return session, nil
118}
119
120func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
121 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
122 ID: "title-" + parentSessionID,
123 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
124 Title: "Generate a title",
125 })
126 if err != nil {
127 return Session{}, err
128 }
129 session := s.fromDBItem(dbSession)
130 s.Publish(pubsub.CreatedEvent, session)
131 return session, nil
132}
133
134func (s *service) Delete(ctx context.Context, id string) error {
135 tx, err := s.db.BeginTx(ctx, nil)
136 if err != nil {
137 return fmt.Errorf("beginning transaction: %w", err)
138 }
139 defer tx.Rollback() //nolint:errcheck
140
141 qtx := s.q.WithTx(tx)
142
143 dbSession, err := qtx.GetSessionByID(ctx, id)
144 if err != nil {
145 return err
146 }
147 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
148 return fmt.Errorf("deleting session messages: %w", err)
149 }
150 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
151 return fmt.Errorf("deleting session files: %w", err)
152 }
153 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
154 return fmt.Errorf("deleting session: %w", err)
155 }
156 if err = tx.Commit(); err != nil {
157 return fmt.Errorf("committing transaction: %w", err)
158 }
159
160 session := s.fromDBItem(dbSession)
161 s.Publish(pubsub.DeletedEvent, session)
162 event.SessionDeleted()
163 return nil
164}
165
166func (s *service) Get(ctx context.Context, id string) (Session, error) {
167 dbSession, err := s.q.GetSessionByID(ctx, id)
168 if err != nil {
169 return Session{}, err
170 }
171 return s.fromDBItem(dbSession), nil
172}
173
174func (s *service) GetLast(ctx context.Context) (Session, error) {
175 dbSession, err := s.q.GetLastSession(ctx)
176 if err != nil {
177 return Session{}, err
178 }
179 return s.fromDBItem(dbSession), nil
180}
181
182func (s *service) Save(ctx context.Context, session Session) (Session, error) {
183 todosJSON, err := marshalTodos(session.Todos)
184 if err != nil {
185 return Session{}, err
186 }
187
188 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
189 ID: session.ID,
190 Title: session.Title,
191 PromptTokens: session.PromptTokens,
192 CompletionTokens: session.CompletionTokens,
193 SummaryMessageID: sql.NullString{
194 String: session.SummaryMessageID,
195 Valid: session.SummaryMessageID != "",
196 },
197 Cost: session.Cost,
198 Todos: sql.NullString{
199 String: todosJSON,
200 Valid: todosJSON != "",
201 },
202 })
203 if err != nil {
204 return Session{}, err
205 }
206 session = s.fromDBItem(dbSession)
207 s.Publish(pubsub.UpdatedEvent, session)
208 return session, nil
209}
210
211// UpdateTitleAndUsage updates only the title and usage fields atomically.
212// This is safer than fetching, modifying, and saving the entire session.
213func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
214 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
215 ID: sessionID,
216 Title: title,
217 PromptTokens: promptTokens,
218 CompletionTokens: completionTokens,
219 Cost: cost,
220 })
221}
222
223// Rename updates only the title of a session without touching updated_at or
224// usage fields.
225func (s *service) Rename(ctx context.Context, id string, title string) error {
226 return s.q.RenameSession(ctx, db.RenameSessionParams{
227 ID: id,
228 Title: title,
229 })
230}
231
232func (s *service) List(ctx context.Context) ([]Session, error) {
233 dbSessions, err := s.q.ListSessions(ctx)
234 if err != nil {
235 return nil, err
236 }
237 sessions := make([]Session, len(dbSessions))
238 for i, dbSession := range dbSessions {
239 sessions[i] = s.fromDBItem(dbSession)
240 }
241 return sessions, nil
242}
243
244func (s service) fromDBItem(item db.Session) Session {
245 todos, err := unmarshalTodos(item.Todos.String)
246 if err != nil {
247 slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
248 }
249 models, err := unmarshalModels(item.Models.String)
250 if err != nil {
251 slog.Error("Failed to unmarshal models", "session_id", item.ID, "error", err)
252 }
253 return Session{
254 ID: item.ID,
255 ParentSessionID: item.ParentSessionID.String,
256 Title: item.Title,
257 MessageCount: item.MessageCount,
258 PromptTokens: item.PromptTokens,
259 CompletionTokens: item.CompletionTokens,
260 SummaryMessageID: item.SummaryMessageID.String,
261 Cost: item.Cost,
262 Todos: todos,
263 Models: models,
264 CreatedAt: item.CreatedAt,
265 UpdatedAt: item.UpdatedAt,
266 }
267}
268
269func marshalTodos(todos []Todo) (string, error) {
270 if len(todos) == 0 {
271 return "", nil
272 }
273 data, err := json.Marshal(todos)
274 if err != nil {
275 return "", err
276 }
277 return string(data), nil
278}
279
280func unmarshalTodos(data string) ([]Todo, error) {
281 if data == "" {
282 return []Todo{}, nil
283 }
284 var todos []Todo
285 if err := json.Unmarshal([]byte(data), &todos); err != nil {
286 return []Todo{}, err
287 }
288 return todos, nil
289}
290
291func marshalModels(models map[config.SelectedModelType]config.SelectedModel) (string, error) {
292 if len(models) == 0 {
293 return "", nil
294 }
295 data, err := json.Marshal(models)
296 if err != nil {
297 return "", err
298 }
299 return string(data), nil
300}
301
302func unmarshalModels(data string) (map[config.SelectedModelType]config.SelectedModel, error) {
303 if data == "" {
304 return nil, nil
305 }
306 var models map[config.SelectedModelType]config.SelectedModel
307 if err := json.Unmarshal([]byte(data), &models); err != nil {
308 return nil, err
309 }
310 return models, nil
311}
312
313func (s *service) UpdateSessionModels(ctx context.Context, id string, models map[config.SelectedModelType]config.SelectedModel) error {
314 modelsJSON, err := marshalModels(models)
315 if err != nil {
316 return fmt.Errorf("failed to marshal models: %w", err)
317 }
318 _, err = s.q.UpdateSessionModels(ctx, db.UpdateSessionModelsParams{
319 Models: sql.NullString{String: modelsJSON, Valid: modelsJSON != ""},
320 ID: id,
321 })
322 return err
323}
324
325// SaveWithModels saves the session and then persists the models column as a
326// second operation. This is intentionally non-atomic: if the models update
327// fails, the session fields are still saved (which is equivalent to the
328// pre-feature behavior where models were never persisted). The next agent turn
329// will retry the models write, so transient failures are self-healing.
330func (s *service) SaveWithModels(ctx context.Context, session Session, models map[config.SelectedModelType]config.SelectedModel) (Session, error) {
331 saved, err := s.Save(ctx, session)
332 if err != nil {
333 return Session{}, err
334 }
335 if err := s.UpdateSessionModels(ctx, session.ID, models); err != nil {
336 return Session{}, fmt.Errorf("failed to persist models: %w", err)
337 }
338 return saved, nil
339}
340
341func NewService(q *db.Queries, conn *sql.DB) Service {
342 broker := pubsub.NewBroker[Session]()
343 return &service{
344 Broker: broker,
345 db: conn,
346 q: q,
347 }
348}
349
350// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
351func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
352 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
353}
354
355// ParseAgentToolSessionID parses an agent tool session ID into its components
356func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
357 parts := strings.Split(sessionID, "$$")
358 if len(parts) != 2 {
359 return "", "", false
360 }
361 return parts[0], parts[1], true
362}
363
364// IsAgentToolSession checks if a session ID follows the agent tool session format
365func (s *service) IsAgentToolSession(sessionID string) bool {
366 _, _, ok := s.ParseAgentToolSessionID(sessionID)
367 return ok
368}