1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15)
16
// TodoStatus is the state of a single Todo. It is serialized as a plain JSON
// string inside the session's stored todo list.
type TodoStatus string

const (
	// TodoStatusPending is serialized as "pending".
	TodoStatusPending TodoStatus = "pending"
	// TodoStatusInProgress is serialized as "in_progress".
	TodoStatusInProgress TodoStatus = "in_progress"
	// TodoStatusCompleted is serialized as "completed". It is the only status
	// HasIncompleteTodos treats as done; every other value counts as
	// incomplete.
	TodoStatusCompleted TodoStatus = "completed"
)
24
// Todo is one entry of a session's todo list. Sessions store their todos as a
// JSON array (see marshalTodos/unmarshalTodos), so the struct tags below
// define the persisted key names.
type Todo struct {
	// Content is the todo's text (JSON key "content").
	Content string `json:"content"`
	// Status is the todo's current state (JSON key "status").
	Status TodoStatus `json:"status"`
	// ActiveForm is stored under "active_form"; presumably an alternate
	// phrasing shown while the item is in progress — TODO confirm.
	ActiveForm string `json:"active_form"`
}
30
31// HasIncompleteTodos returns true if there are any non-completed todos.
32func HasIncompleteTodos(todos []Todo) bool {
33 for _, todo := range todos {
34 if todo.Status != TodoStatusCompleted {
35 return true
36 }
37 }
38 return false
39}
40
// Session is the service-level form of a stored session row, as produced by
// fromDBItem.
type Session struct {
	ID string
	// ParentSessionID is "" when the row has no parent (NULL column);
	// CreateTaskSession and CreateTitleSession set it.
	ParentSessionID  string
	Title            string
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is "" when unset; Save writes "" as SQL NULL.
	SummaryMessageID string
	Cost             float64
	// Todos is persisted as a JSON array; Save writes an empty list as SQL
	// NULL.
	Todos []Todo
	// CreatedAt and UpdatedAt are copied verbatim from the row; presumably
	// Unix timestamps — TODO confirm the unit.
	CreatedAt int64
	UpdatedAt int64
}
54
// Service manages chat sessions. The implementation in this package persists
// sessions through the database queries and publishes Created/Updated/Deleted
// events, which callers receive via the embedded pubsub.Subscriber.
type Service interface {
	pubsub.Subscriber[Session]
	// Create inserts a new top-level session whose ID is a fresh UUID.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child of parentSessionID whose ID is
	// "title-" + parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child of parentSessionID that uses
	// toolCallID as its ID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns all sessions in the order the ListSessions query yields.
	List(ctx context.Context) ([]Session, error)
	// Save persists the session's title, token usage, cost, summary message
	// ID and todos, and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage updates only the title and usage columns of one
	// session, without reading it first.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Delete removes the session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management. Agent tool session IDs have the form
	// "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	IsAgentToolSession(sessionID string) bool
}
71
// service is the database-backed implementation of Service. Session events
// are published through the embedded broker.
type service struct {
	*pubsub.Broker[Session]
	// db is the raw connection, used to begin transactions (see Delete).
	db *sql.DB
	// q executes the session queries; Delete rebinds it to a transaction via
	// WithTx.
	q *db.Queries
}
77
78func (s *service) Create(ctx context.Context, title string) (Session, error) {
79 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
80 ID: uuid.New().String(),
81 Title: title,
82 })
83 if err != nil {
84 return Session{}, err
85 }
86 session := s.fromDBItem(dbSession)
87 s.Publish(pubsub.CreatedEvent, session)
88 event.SessionCreated()
89 return session, nil
90}
91
92func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
93 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
94 ID: toolCallID,
95 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
96 Title: title,
97 })
98 if err != nil {
99 return Session{}, err
100 }
101 session := s.fromDBItem(dbSession)
102 s.Publish(pubsub.CreatedEvent, session)
103 return session, nil
104}
105
106func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
107 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
108 ID: "title-" + parentSessionID,
109 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
110 Title: "Generate a title",
111 })
112 if err != nil {
113 return Session{}, err
114 }
115 session := s.fromDBItem(dbSession)
116 s.Publish(pubsub.CreatedEvent, session)
117 return session, nil
118}
119
120func (s *service) Delete(ctx context.Context, id string) error {
121 tx, err := s.db.BeginTx(ctx, nil)
122 if err != nil {
123 return fmt.Errorf("beginning transaction: %w", err)
124 }
125 defer tx.Rollback() //nolint:errcheck
126
127 qtx := s.q.WithTx(tx)
128
129 dbSession, err := qtx.GetSessionByID(ctx, id)
130 if err != nil {
131 return err
132 }
133 if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
134 return fmt.Errorf("deleting session messages: %w", err)
135 }
136 if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
137 return fmt.Errorf("deleting session files: %w", err)
138 }
139 if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
140 return fmt.Errorf("deleting session: %w", err)
141 }
142 if err = tx.Commit(); err != nil {
143 return fmt.Errorf("committing transaction: %w", err)
144 }
145
146 session := s.fromDBItem(dbSession)
147 s.Publish(pubsub.DeletedEvent, session)
148 event.SessionDeleted()
149 return nil
150}
151
152func (s *service) Get(ctx context.Context, id string) (Session, error) {
153 dbSession, err := s.q.GetSessionByID(ctx, id)
154 if err != nil {
155 return Session{}, err
156 }
157 return s.fromDBItem(dbSession), nil
158}
159
160func (s *service) Save(ctx context.Context, session Session) (Session, error) {
161 todosJSON, err := marshalTodos(session.Todos)
162 if err != nil {
163 return Session{}, err
164 }
165
166 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
167 ID: session.ID,
168 Title: session.Title,
169 PromptTokens: session.PromptTokens,
170 CompletionTokens: session.CompletionTokens,
171 SummaryMessageID: sql.NullString{
172 String: session.SummaryMessageID,
173 Valid: session.SummaryMessageID != "",
174 },
175 Cost: session.Cost,
176 Todos: sql.NullString{
177 String: todosJSON,
178 Valid: todosJSON != "",
179 },
180 })
181 if err != nil {
182 return Session{}, err
183 }
184 session = s.fromDBItem(dbSession)
185 s.Publish(pubsub.UpdatedEvent, session)
186 return session, nil
187}
188
189// UpdateTitleAndUsage updates only the title and usage fields atomically.
190// This is safer than fetching, modifying, and saving the entire session.
191func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
192 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
193 ID: sessionID,
194 Title: title,
195 PromptTokens: promptTokens,
196 CompletionTokens: completionTokens,
197 Cost: cost,
198 })
199}
200
201func (s *service) List(ctx context.Context) ([]Session, error) {
202 dbSessions, err := s.q.ListSessions(ctx)
203 if err != nil {
204 return nil, err
205 }
206 sessions := make([]Session, len(dbSessions))
207 for i, dbSession := range dbSessions {
208 sessions[i] = s.fromDBItem(dbSession)
209 }
210 return sessions, nil
211}
212
213func (s service) fromDBItem(item db.Session) Session {
214 todos, err := unmarshalTodos(item.Todos.String)
215 if err != nil {
216 slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
217 }
218 return Session{
219 ID: item.ID,
220 ParentSessionID: item.ParentSessionID.String,
221 Title: item.Title,
222 MessageCount: item.MessageCount,
223 PromptTokens: item.PromptTokens,
224 CompletionTokens: item.CompletionTokens,
225 SummaryMessageID: item.SummaryMessageID.String,
226 Cost: item.Cost,
227 Todos: todos,
228 CreatedAt: item.CreatedAt,
229 UpdatedAt: item.UpdatedAt,
230 }
231}
232
233func marshalTodos(todos []Todo) (string, error) {
234 if len(todos) == 0 {
235 return "", nil
236 }
237 data, err := json.Marshal(todos)
238 if err != nil {
239 return "", err
240 }
241 return string(data), nil
242}
243
244func unmarshalTodos(data string) ([]Todo, error) {
245 if data == "" {
246 return []Todo{}, nil
247 }
248 var todos []Todo
249 if err := json.Unmarshal([]byte(data), &todos); err != nil {
250 return []Todo{}, err
251 }
252 return todos, nil
253}
254
255func NewService(q *db.Queries, conn *sql.DB) Service {
256 broker := pubsub.NewBroker[Session]()
257 return &service{
258 Broker: broker,
259 db: conn,
260 q: q,
261 }
262}
263
264// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
265func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
266 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
267}
268
269// ParseAgentToolSessionID parses an agent tool session ID into its components
270func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
271 parts := strings.Split(sessionID, "$$")
272 if len(parts) != 2 {
273 return "", "", false
274 }
275 return parts[0], parts[1], true
276}
277
278// IsAgentToolSession checks if a session ID follows the agent tool session format
279func (s *service) IsAgentToolSession(sessionID string) bool {
280 _, _, ok := s.ParseAgentToolSessionID(sessionID)
281 return ok
282}