1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/message"
14 "github.com/charmbracelet/crush/internal/pubsub"
15 "github.com/google/uuid"
16)
17
// TodoStatus is the progress state of a Todo item. It is persisted as the
// "status" field of each entry in a session's JSON-encoded todo list.
type TodoStatus string

// Known TodoStatus values. Nothing in this file validates that a Todo's
// Status is one of them.
const (
	// TodoStatusPending marks a todo that has not been started.
	TodoStatusPending    TodoStatus = "pending"
	// TodoStatusInProgress marks a todo that is currently being worked on.
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted marks a todo that has been finished.
	TodoStatusCompleted  TodoStatus = "completed"
)
25
// Todo is a single entry in a session's todo list. Session.Todos is stored in
// the database as a JSON array of these values, using the field tags below.
type Todo struct {
	Content    string     `json:"content"`     // text of the task
	Status     TodoStatus `json:"status"`      // expected to be one of the TodoStatus constants (not validated here)
	ActiveForm string     `json:"active_form"` // presumably an in-progress phrasing of Content — TODO confirm
}
31
// Session is the in-memory form of a session row. Values read from the
// database are built by the service's fromDBItem conversion.
type Session struct {
	ID               string
	// ParentSessionID is "" when the row has no parent (NULL column);
	// task and title sessions set it to the session that spawned them.
	ParentSessionID  string
	Title            string
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is "" when no summary message is recorded; Save
	// writes "" back as SQL NULL.
	SummaryMessageID string
	Cost             float64
	// Todos is stored as a JSON-encoded column; Save writes an empty list
	// as SQL NULL.
	Todos            []Todo
	// CreatedAt and UpdatedAt are copied verbatim from the row.
	// NOTE(review): the unit (seconds vs. milliseconds since the epoch) is
	// not visible here — confirm against the schema.
	CreatedAt        int64
	UpdatedAt        int64
}
45
// ForkMessageService is the subset of message operations that Fork needs.
// List returns the messages of a session; Fork treats the returned order as
// the conversation order (presumably chronological — TODO confirm). Copy
// duplicates one message into the session identified by sessionID and
// returns the copy.
type ForkMessageService interface {
	List(ctx context.Context, sessionID string) ([]message.Message, error)
	Copy(ctx context.Context, sessionID string, message message.Message) (message.Message, error)
}
50
// Service manages persisted sessions. It embeds pubsub.Subscriber[Session]
// so consumers can subscribe to session events. In this file's
// implementation, the Create* methods, Save and Delete publish
// Created/Updated/Deleted events, while UpdateTitleAndUsage publishes nothing.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new session, without a parent, titled title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child of parentSessionID titled
	// "Generate a title" (ID "title-"+parentSessionID in this implementation).
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child of parentSessionID whose session ID is
	// toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get loads a single session by ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the sessions produced by the underlying ListSessions
	// query (any filtering and ordering are defined there).
	List(ctx context.Context) ([]Session, error)
	// Save persists the mutable fields of session and returns the stored row.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage writes only the title, token counts and cost of
	// sessionID, avoiding a read-modify-write of the whole session.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Delete removes a session together with its messages and files.
	Delete(ctx context.Context, id string) error
	// Fork creates a new session holding copies of the source session's
	// messages that precede upToMessageID (upToMessageID itself is excluded).
	Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error)

	// Agent tool session management
	// Agent tool session IDs have the form "messageID$$toolCallID".

	// CreateAgentToolSessionID joins messageID and toolCallID into such an ID.
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID splits such an ID; ok reports whether it matched.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether sessionID has the agent tool format.
	IsAgentToolSession(sessionID string) bool
}
68
// service is the database-backed implementation of Service returned by
// NewService.
type service struct {
	// Broker supplies Publish (used by the mutating methods) and the
	// subscription side required by Service.
	*pubsub.Broker[Session]
	db *sql.DB     // used to open transactions (see Delete)
	q  *db.Queries // executes queries; wrapped with WithTx inside transactions
}
74
75func (s *service) Create(ctx context.Context, title string) (Session, error) {
76 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
77 ID: uuid.New().String(),
78 Title: title,
79 })
80 if err != nil {
81 return Session{}, err
82 }
83 session := s.fromDBItem(dbSession)
84 s.Publish(pubsub.CreatedEvent, session)
85 event.SessionCreated()
86 return session, nil
87}
88
89func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
90 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
91 ID: toolCallID,
92 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
93 Title: title,
94 })
95 if err != nil {
96 return Session{}, err
97 }
98 session := s.fromDBItem(dbSession)
99 s.Publish(pubsub.CreatedEvent, session)
100 return session, nil
101}
102
103func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
104 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
105 ID: "title-" + parentSessionID,
106 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
107 Title: "Generate a title",
108 })
109 if err != nil {
110 return Session{}, err
111 }
112 session := s.fromDBItem(dbSession)
113 s.Publish(pubsub.CreatedEvent, session)
114 return session, nil
115}
116
117func (s *service) Delete(ctx context.Context, id string) error {
118 tx, err := s.db.BeginTx(ctx, nil)
119 if err != nil {
120 return fmt.Errorf("beginning transaction: %w", err)
121 }
122 defer tx.Rollback() //nolint:errcheck
123
124 qtx := s.q.WithTx(tx)
125
126 dbSession, err := qtx.GetSessionByID(ctx, id)
127 if err != nil {
128 return err
129 }
130 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
131 return fmt.Errorf("deleting session messages: %w", err)
132 }
133 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
134 return fmt.Errorf("deleting session files: %w", err)
135 }
136 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
137 return fmt.Errorf("deleting session: %w", err)
138 }
139 if err = tx.Commit(); err != nil {
140 return fmt.Errorf("committing transaction: %w", err)
141 }
142
143 session := s.fromDBItem(dbSession)
144 s.Publish(pubsub.DeletedEvent, session)
145 event.SessionDeleted()
146 return nil
147}
148
149func (s *service) Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) {
150 messages, err := messageSvc.List(ctx, sourceSessionID)
151 if err != nil {
152 return Session{}, fmt.Errorf("listing messages: %w", err)
153 }
154
155 targetIndex := -1
156 for i, msg := range messages {
157 if msg.ID == upToMessageID {
158 targetIndex = i
159 break
160 }
161 }
162 if targetIndex == -1 {
163 return Session{}, fmt.Errorf("message not found: %s", upToMessageID)
164 }
165
166 sourceSession, err := s.Get(ctx, sourceSessionID)
167 if err != nil {
168 return Session{}, fmt.Errorf("getting source session: %w", err)
169 }
170
171 newSession, err := s.Create(ctx, "Forked: "+sourceSession.Title)
172 if err != nil {
173 return Session{}, fmt.Errorf("creating session: %w", err)
174 }
175
176 for i := 0; i < targetIndex; i++ {
177 _, err = messageSvc.Copy(ctx, newSession.ID, messages[i])
178 if err != nil {
179 _ = s.Delete(ctx, newSession.ID)
180 return Session{}, fmt.Errorf("copying message: %w", err)
181 }
182 }
183
184 return newSession, nil
185}
186
187func (s *service) Get(ctx context.Context, id string) (Session, error) {
188 dbSession, err := s.q.GetSessionByID(ctx, id)
189 if err != nil {
190 return Session{}, err
191 }
192 return s.fromDBItem(dbSession), nil
193}
194
195func (s *service) Save(ctx context.Context, session Session) (Session, error) {
196 todosJSON, err := marshalTodos(session.Todos)
197 if err != nil {
198 return Session{}, err
199 }
200
201 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
202 ID: session.ID,
203 Title: session.Title,
204 PromptTokens: session.PromptTokens,
205 CompletionTokens: session.CompletionTokens,
206 SummaryMessageID: sql.NullString{
207 String: session.SummaryMessageID,
208 Valid: session.SummaryMessageID != "",
209 },
210 Cost: session.Cost,
211 Todos: sql.NullString{
212 String: todosJSON,
213 Valid: todosJSON != "",
214 },
215 })
216 if err != nil {
217 return Session{}, err
218 }
219 session = s.fromDBItem(dbSession)
220 s.Publish(pubsub.UpdatedEvent, session)
221 return session, nil
222}
223
224// UpdateTitleAndUsage updates only the title and usage fields atomically.
225// This is safer than fetching, modifying, and saving the entire session.
226func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
227 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
228 ID: sessionID,
229 Title: title,
230 PromptTokens: promptTokens,
231 CompletionTokens: completionTokens,
232 Cost: cost,
233 })
234}
235
236func (s *service) List(ctx context.Context) ([]Session, error) {
237 dbSessions, err := s.q.ListSessions(ctx)
238 if err != nil {
239 return nil, err
240 }
241 sessions := make([]Session, len(dbSessions))
242 for i, dbSession := range dbSessions {
243 sessions[i] = s.fromDBItem(dbSession)
244 }
245 return sessions, nil
246}
247
248func (s service) fromDBItem(item db.Session) Session {
249 todos, err := unmarshalTodos(item.Todos.String)
250 if err != nil {
251 slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
252 }
253 return Session{
254 ID: item.ID,
255 ParentSessionID: item.ParentSessionID.String,
256 Title: item.Title,
257 MessageCount: item.MessageCount,
258 PromptTokens: item.PromptTokens,
259 CompletionTokens: item.CompletionTokens,
260 SummaryMessageID: item.SummaryMessageID.String,
261 Cost: item.Cost,
262 Todos: todos,
263 CreatedAt: item.CreatedAt,
264 UpdatedAt: item.UpdatedAt,
265 }
266}
267
268func marshalTodos(todos []Todo) (string, error) {
269 if len(todos) == 0 {
270 return "", nil
271 }
272 data, err := json.Marshal(todos)
273 if err != nil {
274 return "", err
275 }
276 return string(data), nil
277}
278
279func unmarshalTodos(data string) ([]Todo, error) {
280 if data == "" {
281 return []Todo{}, nil
282 }
283 var todos []Todo
284 if err := json.Unmarshal([]byte(data), &todos); err != nil {
285 return []Todo{}, err
286 }
287 return todos, nil
288}
289
290func NewService(q *db.Queries, conn *sql.DB) Service {
291 broker := pubsub.NewBroker[Session]()
292 return &service{
293 Broker: broker,
294 db: conn,
295 q: q,
296 }
297}
298
299// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
300func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
301 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
302}
303
304// ParseAgentToolSessionID parses an agent tool session ID into its components
305func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
306 parts := strings.Split(sessionID, "$$")
307 if len(parts) != 2 {
308 return "", "", false
309 }
310 return parts[0], parts[1], true
311}
312
313// IsAgentToolSession checks if a session ID follows the agent tool session format
314func (s *service) IsAgentToolSession(sessionID string) bool {
315 _, _, ok := s.ParseAgentToolSessionID(sessionID)
316 return ok
317}