1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15 "github.com/zeebo/xxh3"
16)
17
// TodoStatus is the string-backed status of a Todo item.
type TodoStatus string
19
// TodoStatus values declared by this package. HasIncompleteTodos treats every
// status other than TodoStatusCompleted as incomplete.
const (
	// TodoStatusPending is the "pending" status.
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress is the "in_progress" status.
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted is the "completed" status.
	TodoStatusCompleted TodoStatus = "completed"
)
25
26// HashID returns the XXH3 hash of a session ID (UUID) as a hex string.
27func HashID(id string) string {
28 h := xxh3.New()
29 h.WriteString(id)
30 return fmt.Sprintf("%x", h.Sum(nil))
31}
32
// Todo is a single todo entry attached to a session. It is serialized to and
// from JSON using the field tags below.
type Todo struct {
	// Content is the todo text, stored under the "content" JSON key.
	Content string `json:"content"`
	// Status is the todo's current TodoStatus, stored under "status".
	Status TodoStatus `json:"status"`
	// ActiveForm is stored under "active_form".
	// NOTE(review): presumably the wording shown while the todo is being
	// worked on — confirm with callers.
	ActiveForm string `json:"active_form"`
}
38
39// HasIncompleteTodos returns true if there are any non-completed todos.
40func HasIncompleteTodos(todos []Todo) bool {
41 for _, todo := range todos {
42 if todo.Status != TodoStatusCompleted {
43 return true
44 }
45 }
46 return false
47}
48
// Session is the service-level representation of a stored session. Values are
// built from db.Session rows by fromDBItem.
type Session struct {
	// ID is the session identifier.
	ID string
	// ParentSessionID is the string value of the row's nullable parent
	// session ID.
	ParentSessionID string
	// Title is the session title.
	Title string
	// MessageCount is copied from the row's message count.
	MessageCount int64
	// PromptTokens is the stored prompt token count.
	PromptTokens int64
	// CompletionTokens is the stored completion token count.
	CompletionTokens int64
	// SummaryMessageID is the string value of the row's nullable summary
	// message ID; Save stores it as NULL when it is empty.
	SummaryMessageID string
	// Cost is the stored cost value.
	// NOTE(review): unit/currency is not visible here — confirm.
	Cost float64
	// Todos is the todo list decoded from the row's JSON todos column.
	Todos []Todo
	// CreatedAt is copied from the row.
	// NOTE(review): timestamp unit is not visible here — confirm against
	// the db schema.
	CreatedAt int64
	// UpdatedAt is copied from the row; same unit caveat as CreatedAt.
	UpdatedAt int64
}
62
// Service is the session management API. It embeds pubsub.Subscriber[Session]
// and adds methods to create, read, update, and delete sessions, plus helpers
// for agent tool session IDs.
type Service interface {
	pubsub.Subscriber[Session]
	// Create creates a new session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID with the
	// given title; toolCallID is supplied by the caller.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given id.
	Get(ctx context.Context, id string) (Session, error)
	// GetLast returns the last session.
	// NOTE(review): what "last" means is decided by the underlying query —
	// confirm.
	GetLast(ctx context.Context) (Session, error)
	// List returns the stored sessions.
	List(ctx context.Context) ([]Session, error)
	// Save persists the given session and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage updates the title, token counts, and cost of the
	// session identified by sessionID.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Rename changes the title of the session identified by id.
	Rename(ctx context.Context, id string, title string) error
	// Delete removes the session identified by id.
	Delete(ctx context.Context, id string) error

	// Agent tool session management

	// CreateAgentToolSessionID builds a session ID from a message ID and a
	// tool call ID.
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID splits an agent tool session ID back into its
	// message ID and tool call ID; ok is false when the ID does not parse.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether sessionID parses as an agent tool
	// session ID.
	IsAgentToolSession(sessionID string) bool
}
81
// service is the Service implementation returned by NewService.
type service struct {
	// Broker is the embedded pubsub broker used to publish Session events.
	*pubsub.Broker[Session]
	// db is the database handle; Delete uses it to begin a transaction.
	db *sql.DB
	// q is the query handle used for session reads and writes.
	q *db.Queries
}
87
88func (s *service) Create(ctx context.Context, title string) (Session, error) {
89 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
90 ID: uuid.New().String(),
91 Title: title,
92 })
93 if err != nil {
94 return Session{}, err
95 }
96 session := s.fromDBItem(dbSession)
97 s.Publish(pubsub.CreatedEvent, session)
98 event.SessionCreated()
99 return session, nil
100}
101
102func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
103 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
104 ID: toolCallID,
105 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
106 Title: title,
107 })
108 if err != nil {
109 return Session{}, err
110 }
111 session := s.fromDBItem(dbSession)
112 s.Publish(pubsub.CreatedEvent, session)
113 return session, nil
114}
115
116func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
117 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
118 ID: "title-" + parentSessionID,
119 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
120 Title: "Generate a title",
121 })
122 if err != nil {
123 return Session{}, err
124 }
125 session := s.fromDBItem(dbSession)
126 s.Publish(pubsub.CreatedEvent, session)
127 return session, nil
128}
129
130func (s *service) Delete(ctx context.Context, id string) error {
131 tx, err := s.db.BeginTx(ctx, nil)
132 if err != nil {
133 return fmt.Errorf("beginning transaction: %w", err)
134 }
135 defer tx.Rollback() //nolint:errcheck
136
137 qtx := s.q.WithTx(tx)
138
139 dbSession, err := qtx.GetSessionByID(ctx, id)
140 if err != nil {
141 return err
142 }
143 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
144 return fmt.Errorf("deleting session messages: %w", err)
145 }
146 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
147 return fmt.Errorf("deleting session files: %w", err)
148 }
149 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
150 return fmt.Errorf("deleting session: %w", err)
151 }
152 if err = tx.Commit(); err != nil {
153 return fmt.Errorf("committing transaction: %w", err)
154 }
155
156 session := s.fromDBItem(dbSession)
157 s.Publish(pubsub.DeletedEvent, session)
158 event.SessionDeleted()
159 return nil
160}
161
162func (s *service) Get(ctx context.Context, id string) (Session, error) {
163 dbSession, err := s.q.GetSessionByID(ctx, id)
164 if err != nil {
165 return Session{}, err
166 }
167 return s.fromDBItem(dbSession), nil
168}
169
170func (s *service) GetLast(ctx context.Context) (Session, error) {
171 dbSession, err := s.q.GetLastSession(ctx)
172 if err != nil {
173 return Session{}, err
174 }
175 return s.fromDBItem(dbSession), nil
176}
177
178func (s *service) Save(ctx context.Context, session Session) (Session, error) {
179 todosJSON, err := marshalTodos(session.Todos)
180 if err != nil {
181 return Session{}, err
182 }
183
184 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
185 ID: session.ID,
186 Title: session.Title,
187 PromptTokens: session.PromptTokens,
188 CompletionTokens: session.CompletionTokens,
189 SummaryMessageID: sql.NullString{
190 String: session.SummaryMessageID,
191 Valid: session.SummaryMessageID != "",
192 },
193 Cost: session.Cost,
194 Todos: sql.NullString{
195 String: todosJSON,
196 Valid: todosJSON != "",
197 },
198 })
199 if err != nil {
200 return Session{}, err
201 }
202 session = s.fromDBItem(dbSession)
203 s.Publish(pubsub.UpdatedEvent, session)
204 return session, nil
205}
206
207// UpdateTitleAndUsage updates only the title and usage fields atomically.
208// This is safer than fetching, modifying, and saving the entire session.
209func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
210 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
211 ID: sessionID,
212 Title: title,
213 PromptTokens: promptTokens,
214 CompletionTokens: completionTokens,
215 Cost: cost,
216 })
217}
218
219// Rename updates only the title of a session without touching updated_at or
220// usage fields.
221func (s *service) Rename(ctx context.Context, id string, title string) error {
222 return s.q.RenameSession(ctx, db.RenameSessionParams{
223 ID: id,
224 Title: title,
225 })
226}
227
228func (s *service) List(ctx context.Context) ([]Session, error) {
229 dbSessions, err := s.q.ListSessions(ctx)
230 if err != nil {
231 return nil, err
232 }
233 sessions := make([]Session, len(dbSessions))
234 for i, dbSession := range dbSessions {
235 sessions[i] = s.fromDBItem(dbSession)
236 }
237 return sessions, nil
238}
239
240func (s service) fromDBItem(item db.Session) Session {
241 todos, err := unmarshalTodos(item.Todos.String)
242 if err != nil {
243 slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
244 }
245 return Session{
246 ID: item.ID,
247 ParentSessionID: item.ParentSessionID.String,
248 Title: item.Title,
249 MessageCount: item.MessageCount,
250 PromptTokens: item.PromptTokens,
251 CompletionTokens: item.CompletionTokens,
252 SummaryMessageID: item.SummaryMessageID.String,
253 Cost: item.Cost,
254 Todos: todos,
255 CreatedAt: item.CreatedAt,
256 UpdatedAt: item.UpdatedAt,
257 }
258}
259
260func marshalTodos(todos []Todo) (string, error) {
261 if len(todos) == 0 {
262 return "", nil
263 }
264 data, err := json.Marshal(todos)
265 if err != nil {
266 return "", err
267 }
268 return string(data), nil
269}
270
271func unmarshalTodos(data string) ([]Todo, error) {
272 if data == "" {
273 return []Todo{}, nil
274 }
275 var todos []Todo
276 if err := json.Unmarshal([]byte(data), &todos); err != nil {
277 return []Todo{}, err
278 }
279 return todos, nil
280}
281
282func NewService(q *db.Queries, conn *sql.DB) Service {
283 broker := pubsub.NewBroker[Session]()
284 return &service{
285 Broker: broker,
286 db: conn,
287 q: q,
288 }
289}
290
291// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
292func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
293 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
294}
295
296// ParseAgentToolSessionID parses an agent tool session ID into its components
297func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
298 parts := strings.Split(sessionID, "$$")
299 if len(parts) != 2 {
300 return "", "", false
301 }
302 return parts[0], parts[1], true
303}
304
305// IsAgentToolSession checks if a session ID follows the agent tool session format
306func (s *service) IsAgentToolSession(sessionID string) bool {
307 _, _, ok := s.ParseAgentToolSessionID(sessionID)
308 return ok
309}