1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15 "github.com/zeebo/xxh3"
16)
17
// TodoStatus is the progress state recorded on a Todo. HasIncompleteTodos
// treats every value other than TodoStatusCompleted as unfinished.
type TodoStatus string

// Known TodoStatus values, as they appear in the JSON-encoded todo list.
const (
	TodoStatusPending    TodoStatus = "pending"
	TodoStatusInProgress TodoStatus = "in_progress"
	TodoStatusCompleted  TodoStatus = "completed"
)
25
26// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
27func HashID(id string) string {
28 h := xxh3.New()
29 h.WriteString(id)
30 return fmt.Sprintf("%x", h.Sum(nil))
31}
32
33type Todo struct {
34 Content string `json:"content"`
35 Status TodoStatus `json:"status"`
36 ActiveForm string `json:"active_form"`
37}
38
39// HasIncompleteTodos returns true if there are any non-completed todos.
40func HasIncompleteTodos(todos []Todo) bool {
41 for _, todo := range todos {
42 if todo.Status != TodoStatusCompleted {
43 return true
44 }
45 }
46 return false
47}
48
49type Session struct {
50 ID string
51 ParentSessionID string
52 Title string
53 MessageCount int64
54 PromptTokens int64
55 CompletionTokens int64
56 EstimatedUsage bool
57 SummaryMessageID string
58 Cost float64
59 Todos []Todo
60 CreatedAt int64
61 UpdatedAt int64
62}
63
64type Service interface {
65 pubsub.Subscriber[Session]
66 Create(ctx context.Context, title string) (Session, error)
67 CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
68 CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
69 Get(ctx context.Context, id string) (Session, error)
70 GetLast(ctx context.Context) (Session, error)
71 List(ctx context.Context) ([]Session, error)
72 Save(ctx context.Context, session Session) (Session, error)
73 UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
74 Rename(ctx context.Context, id string, title string) error
75 Delete(ctx context.Context, id string) error
76
77 // Agent tool session management
78 CreateAgentToolSessionID(messageID, toolCallID string) string
79 ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
80 IsAgentToolSession(sessionID string) bool
81}
82
83type service struct {
84 *pubsub.Broker[Session]
85 db *sql.DB
86 q *db.Queries
87}
88
89func (s *service) Create(ctx context.Context, title string) (Session, error) {
90 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
91 ID: uuid.New().String(),
92 Title: title,
93 })
94 if err != nil {
95 return Session{}, err
96 }
97 session := s.fromDBItem(dbSession)
98 s.Publish(pubsub.CreatedEvent, session)
99 event.SessionCreated()
100 return session, nil
101}
102
103func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
104 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
105 ID: toolCallID,
106 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
107 Title: title,
108 })
109 if err != nil {
110 return Session{}, err
111 }
112 session := s.fromDBItem(dbSession)
113 s.Publish(pubsub.CreatedEvent, session)
114 return session, nil
115}
116
117func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
118 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
119 ID: "title-" + parentSessionID,
120 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
121 Title: "Generate a title",
122 })
123 if err != nil {
124 return Session{}, err
125 }
126 session := s.fromDBItem(dbSession)
127 s.Publish(pubsub.CreatedEvent, session)
128 return session, nil
129}
130
131func (s *service) Delete(ctx context.Context, id string) error {
132 tx, err := s.db.BeginTx(ctx, nil)
133 if err != nil {
134 return fmt.Errorf("beginning transaction: %w", err)
135 }
136 defer tx.Rollback() //nolint:errcheck
137
138 qtx := s.q.WithTx(tx)
139
140 dbSession, err := qtx.GetSessionByID(ctx, id)
141 if err != nil {
142 return err
143 }
144 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
145 return fmt.Errorf("deleting session messages: %w", err)
146 }
147 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
148 return fmt.Errorf("deleting session files: %w", err)
149 }
150 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
151 return fmt.Errorf("deleting session: %w", err)
152 }
153 if err = tx.Commit(); err != nil {
154 return fmt.Errorf("committing transaction: %w", err)
155 }
156
157 session := s.fromDBItem(dbSession)
158 s.Publish(pubsub.DeletedEvent, session)
159 event.SessionDeleted()
160 return nil
161}
162
163func (s *service) Get(ctx context.Context, id string) (Session, error) {
164 dbSession, err := s.q.GetSessionByID(ctx, id)
165 if err != nil {
166 return Session{}, err
167 }
168 return s.fromDBItem(dbSession), nil
169}
170
171func (s *service) GetLast(ctx context.Context) (Session, error) {
172 dbSession, err := s.q.GetLastSession(ctx)
173 if err != nil {
174 return Session{}, err
175 }
176 return s.fromDBItem(dbSession), nil
177}
178
179func (s *service) Save(ctx context.Context, session Session) (Session, error) {
180 todosJSON, err := marshalTodos(session.Todos)
181 if err != nil {
182 return Session{}, err
183 }
184
185 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
186 ID: session.ID,
187 Title: session.Title,
188 PromptTokens: session.PromptTokens,
189 CompletionTokens: session.CompletionTokens,
190 SummaryMessageID: sql.NullString{
191 String: session.SummaryMessageID,
192 Valid: session.SummaryMessageID != "",
193 },
194 Cost: session.Cost,
195 Todos: sql.NullString{
196 String: todosJSON,
197 Valid: todosJSON != "",
198 },
199 })
200 if err != nil {
201 return Session{}, err
202 }
203 estimatedUsage := session.EstimatedUsage
204 session = s.fromDBItem(dbSession)
205 session.EstimatedUsage = estimatedUsage
206 s.Publish(pubsub.UpdatedEvent, session)
207 return session, nil
208}
209
210// UpdateTitleAndUsage updates only the title and usage fields atomically.
211// This is safer than fetching, modifying, and saving the entire session.
212func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
213 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
214 ID: sessionID,
215 Title: title,
216 PromptTokens: promptTokens,
217 CompletionTokens: completionTokens,
218 Cost: cost,
219 })
220}
221
222// Rename updates only the title of a session without touching updated_at or
223// usage fields.
224func (s *service) Rename(ctx context.Context, id string, title string) error {
225 return s.q.RenameSession(ctx, db.RenameSessionParams{
226 ID: id,
227 Title: title,
228 })
229}
230
231func (s *service) List(ctx context.Context) ([]Session, error) {
232 dbSessions, err := s.q.ListSessions(ctx)
233 if err != nil {
234 return nil, err
235 }
236 sessions := make([]Session, len(dbSessions))
237 for i, dbSession := range dbSessions {
238 sessions[i] = s.fromDBItem(dbSession)
239 }
240 return sessions, nil
241}
242
243func (s service) fromDBItem(item db.Session) Session {
244 todos, err := unmarshalTodos(item.Todos.String)
245 if err != nil {
246 slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
247 }
248 return Session{
249 ID: item.ID,
250 ParentSessionID: item.ParentSessionID.String,
251 Title: item.Title,
252 MessageCount: item.MessageCount,
253 PromptTokens: item.PromptTokens,
254 CompletionTokens: item.CompletionTokens,
255 SummaryMessageID: item.SummaryMessageID.String,
256 Cost: item.Cost,
257 Todos: todos,
258 CreatedAt: item.CreatedAt,
259 UpdatedAt: item.UpdatedAt,
260 }
261}
262
263func marshalTodos(todos []Todo) (string, error) {
264 if len(todos) == 0 {
265 return "", nil
266 }
267 data, err := json.Marshal(todos)
268 if err != nil {
269 return "", err
270 }
271 return string(data), nil
272}
273
274func unmarshalTodos(data string) ([]Todo, error) {
275 if data == "" {
276 return []Todo{}, nil
277 }
278 var todos []Todo
279 if err := json.Unmarshal([]byte(data), &todos); err != nil {
280 return []Todo{}, err
281 }
282 return todos, nil
283}
284
285func NewService(q *db.Queries, conn *sql.DB) Service {
286 broker := pubsub.NewBroker[Session]()
287 return &service{
288 Broker: broker,
289 db: conn,
290 q: q,
291 }
292}
293
294// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
295func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
296 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
297}
298
299// ParseAgentToolSessionID parses an agent tool session ID into its components
300func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
301 parts := strings.Split(sessionID, "$$")
302 if len(parts) != 2 {
303 return "", "", false
304 }
305 return parts[0], parts[1], true
306}
307
308// IsAgentToolSession checks if a session ID follows the agent tool session format
309func (s *service) IsAgentToolSession(sessionID string) bool {
310 _, _, ok := s.ParseAgentToolSessionID(sessionID)
311 return ok
312}