1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15 "github.com/zeebo/xxh3"
16)
17
// TodoStatus is the status of a single Todo. Its string value is what gets
// written to the persisted todos JSON (see the Todo "status" tag).
type TodoStatus string

// TodoStatus values. HasIncompleteTodos treats every status other than
// TodoStatusCompleted as incomplete.
const (
	// TodoStatusPending is persisted as "pending".
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress is persisted as "in_progress".
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted is persisted as "completed".
	TodoStatusCompleted TodoStatus = "completed"
)
25
// HashID hashes a session ID (a UUID) with XXH3 and returns the digest
// bytes rendered as lowercase hexadecimal.
func HashID(id string) string {
	hasher := xxh3.New()
	// The write results are deliberately discarded.
	_, _ = hasher.WriteString(id)
	digest := hasher.Sum(nil)
	return fmt.Sprintf("%x", digest)
}
32
// Todo is one entry of a session's todo list. A session's todos are stored
// as a JSON array (see marshalTodos/unmarshalTodos), so the json tags below
// are the persisted field names.
type Todo struct {
	Content string     `json:"content"`
	Status  TodoStatus `json:"status"`
	// ActiveForm is persisted under "active_form". NOTE(review): presumably
	// an alternate phrasing of Content shown while the todo is in progress —
	// confirm against the producer of these todos.
	ActiveForm string `json:"active_form"`
}
38
39// HasIncompleteTodos returns true if there are any non-completed todos.
40func HasIncompleteTodos(todos []Todo) bool {
41 for _, todo := range todos {
42 if todo.Status != TodoStatusCompleted {
43 return true
44 }
45 }
46 return false
47}
48
// Session is the service-level view of a stored session, as produced by
// fromDBItem from a db.Session row.
type Session struct {
	ID string
	// ParentSessionID is empty for sessions created via Create (no parent is
	// set); CreateTaskSession and CreateTitleSession set it.
	ParentSessionID string
	Title           string
	// MessageCount is read from the row; Save does not write it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is empty when the column is NULL; Save stores an
	// empty value as NULL.
	SummaryMessageID string
	Cost             float64
	Todos            []Todo
	// NOTE(review): timestamp unit (seconds vs milliseconds) is not visible
	// here — confirm against the schema.
	CreatedAt int64
	UpdatedAt int64
}
62
// Service creates, reads, updates and deletes sessions and lets callers
// subscribe to Session events.
type Service interface {
	pubsub.Subscriber[Session]
	// Create creates a new session with a freshly generated ID and the given
	// title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID whose ID
	// is "title-" + parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID
	// is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the sessions yielded by the ListSessions query.
	// NOTE(review): any filtering/ordering is defined by that query.
	List(ctx context.Context) ([]Session, error)
	// Save writes the title, token counts, summary message ID, cost and todos
	// of session and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage updates only the title and usage fields of a
	// session.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Rename updates only the title of a session.
	Rename(ctx context.Context, id string, title string) error
	// Delete removes a session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management
	// CreateAgentToolSessionID builds an ID of the form
	// "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID splits an ID built by CreateAgentToolSessionID;
	// ok is false unless the ID contains exactly one "$$" separator.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether ParseAgentToolSessionID accepts
	// sessionID.
	IsAgentToolSession(sessionID string) bool
}
80
// service is the SQL-backed implementation of Service.
type service struct {
	// The embedded Broker provides Publish, which the mutating methods use to
	// emit session events; presumably it also satisfies the
	// pubsub.Subscriber part of Service.
	*pubsub.Broker[Session]
	// db is the raw connection, used to begin transactions (see Delete).
	db *sql.DB
	// q executes the database queries; WithTx rebinds it to a transaction.
	q *db.Queries
}
86
// Create inserts a new session with a random UUID and the given title,
// publishes a CreatedEvent for it, and records event.SessionCreated.
func (s *service) Create(ctx context.Context, title string) (Session, error) {
	params := db.CreateSessionParams{
		ID:    uuid.New().String(),
		Title: title,
	}
	row, err := s.q.CreateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}

	created := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, created)
	event.SessionCreated()
	return created, nil
}
100
// CreateTaskSession inserts a child session whose ID is toolCallID and whose
// parent is parentSessionID, then publishes a CreatedEvent. The parent
// column is always marked non-NULL, even when parentSessionID is empty.
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	parent := sql.NullString{String: parentSessionID, Valid: true}
	row, err := s.q.CreateSession(ctx, db.CreateSessionParams{
		ID:              toolCallID,
		ParentSessionID: parent,
		Title:           title,
	})
	if err != nil {
		return Session{}, err
	}

	child := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, child)
	return child, nil
}
114
115func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
116 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
117 ID: "title-" + parentSessionID,
118 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
119 Title: "Generate a title",
120 })
121 if err != nil {
122 return Session{}, err
123 }
124 session := s.fromDBItem(dbSession)
125 s.Publish(pubsub.CreatedEvent, session)
126 return session, nil
127}
128
// Delete removes the session identified by id along with its messages and
// files, all inside one transaction. The DeletedEvent and
// event.SessionDeleted are emitted only after the commit succeeds. An error
// from looking up the session is returned unwrapped.
func (s *service) Delete(ctx context.Context, id string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// The deferred Rollback covers every early return below; its error is
	// intentionally not checked.
	defer tx.Rollback() //nolint:errcheck

	txq := s.q.WithTx(tx)

	row, err := txq.GetSessionByID(ctx, id)
	if err != nil {
		return err
	}
	sessionID := row.ID

	// Dependent rows go first, then the session itself.
	if err := txq.DeleteSessionMessages(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session messages: %w", err)
	}
	if err := txq.DeleteSessionFiles(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session files: %w", err)
	}
	if err := txq.DeleteSession(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}

	s.Publish(pubsub.DeletedEvent, s.fromDBItem(row))
	event.SessionDeleted()
	return nil
}
160
161func (s *service) Get(ctx context.Context, id string) (Session, error) {
162 dbSession, err := s.q.GetSessionByID(ctx, id)
163 if err != nil {
164 return Session{}, err
165 }
166 return s.fromDBItem(dbSession), nil
167}
168
// Save writes the title, token counts, summary message ID, cost and todos of
// session to its row, publishes an UpdatedEvent with the stored result, and
// returns that result. An empty SummaryMessageID or an empty todo list is
// stored as NULL.
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
	encodedTodos, err := marshalTodos(session.Todos)
	if err != nil {
		return Session{}, err
	}

	summary := sql.NullString{String: session.SummaryMessageID, Valid: session.SummaryMessageID != ""}
	todos := sql.NullString{String: encodedTodos, Valid: encodedTodos != ""}

	row, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
		ID:               session.ID,
		Title:            session.Title,
		PromptTokens:     session.PromptTokens,
		CompletionTokens: session.CompletionTokens,
		SummaryMessageID: summary,
		Cost:             session.Cost,
		Todos:            todos,
	})
	if err != nil {
		return Session{}, err
	}

	updated := s.fromDBItem(row)
	s.Publish(pubsub.UpdatedEvent, updated)
	return updated, nil
}
197
// UpdateTitleAndUsage sets only the title, prompt/completion token counts
// and cost of one session in a single query. That is safer than fetching,
// modifying and saving the entire session. Unlike Save, it publishes no
// event.
func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
	params := db.UpdateSessionTitleAndUsageParams{
		ID:               sessionID,
		Title:            title,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		Cost:             cost,
	}
	return s.q.UpdateSessionTitleAndUsage(ctx, params)
}
209
// Rename changes only the title of a session. According to the
// RenameSession query's contract, it leaves updated_at and the usage fields
// untouched. No event is published.
func (s *service) Rename(ctx context.Context, id string, title string) error {
	params := db.RenameSessionParams{
		ID:    id,
		Title: title,
	}
	return s.q.RenameSession(ctx, params)
}
218
// List converts every row returned by the ListSessions query into a Session,
// preserving the query's order. On success the result is never nil, even
// when there are no rows.
func (s *service) List(ctx context.Context) ([]Session, error) {
	rows, err := s.q.ListSessions(ctx)
	if err != nil {
		return nil, err
	}

	out := make([]Session, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
230
// fromDBItem converts a db.Session row into a Session.
//
// Nullable string columns (ParentSessionID, SummaryMessageID, Todos) are read
// through their String field, so a NULL value becomes "". Todos are decoded
// with unmarshalTodos. A decode failure is logged rather than returned and
// leaves Todos as an empty list, so the rest of the row is still usable.
//
// It uses a pointer receiver, matching every other service method.
func (s *service) fromDBItem(item db.Session) Session {
	todos, err := unmarshalTodos(item.Todos.String)
	if err != nil {
		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
	}
	return Session{
		ID:               item.ID,
		ParentSessionID:  item.ParentSessionID.String,
		Title:            item.Title,
		MessageCount:     item.MessageCount,
		PromptTokens:     item.PromptTokens,
		CompletionTokens: item.CompletionTokens,
		SummaryMessageID: item.SummaryMessageID.String,
		Cost:             item.Cost,
		Todos:            todos,
		CreatedAt:        item.CreatedAt,
		UpdatedAt:        item.UpdatedAt,
	}
}
250
// marshalTodos encodes todos as a JSON array. A nil or empty list encodes to
// "", which Save stores as NULL, instead of "[]".
func marshalTodos(todos []Todo) (string, error) {
	if len(todos) > 0 {
		encoded, err := json.Marshal(todos)
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	}
	return "", nil
}
261
// unmarshalTodos decodes a JSON array of todos. An empty string (a NULL
// column) yields an empty, non-nil slice. Invalid JSON yields an empty,
// non-nil slice together with the decode error, so no partial result is
// ever returned.
func unmarshalTodos(data string) ([]Todo, error) {
	empty := []Todo{}
	if len(data) == 0 {
		return empty, nil
	}

	var parsed []Todo
	if err := json.Unmarshal([]byte(data), &parsed); err != nil {
		return empty, err
	}
	return parsed, nil
}
272
// NewService returns a Service that runs queries through q, opens
// transactions on conn, and publishes session events on a fresh broker.
func NewService(q *db.Queries, conn *sql.DB) Service {
	svc := &service{
		Broker: pubsub.NewBroker[Session](),
		db:     conn,
		q:      q,
	}
	return svc
}
281
282// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
283func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
284 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
285}
286
287// ParseAgentToolSessionID parses an agent tool session ID into its components
288func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
289 parts := strings.Split(sessionID, "$$")
290 if len(parts) != 2 {
291 return "", "", false
292 }
293 return parts[0], parts[1], true
294}
295
296// IsAgentToolSession checks if a session ID follows the agent tool session format
297func (s *service) IsAgentToolSession(sessionID string) bool {
298 _, _, ok := s.ParseAgentToolSessionID(sessionID)
299 return ok
300}