1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15)
16
// TodoStatus is the progress state of a single Todo item.
type TodoStatus string

// Recognized TodoStatus values. They are stored verbatim in the todo JSON
// (see the `json:"status"` tag on Todo.Status).
const (
	TodoStatusPending    TodoStatus = "pending"
	TodoStatusInProgress TodoStatus = "in_progress"
	TodoStatusCompleted  TodoStatus = "completed"
)
24
// Todo is one entry in a session's todo list. A session's todos are persisted
// together as a single JSON array (see marshalTodos / unmarshalTodos), using
// the json tags below as the field names.
type Todo struct {
	// Content is the text of the task.
	Content string `json:"content"`
	// Status is expected to be one of the TodoStatus constants.
	Status TodoStatus `json:"status"`
	// NOTE(review): presumably an alternate phrasing of Content shown while
	// the item is in progress — confirm against the producer of this data.
	ActiveForm string `json:"active_form"`
}
30
// Session is the in-memory view of a persisted session row, as produced by
// service.fromDBItem.
type Session struct {
	// ID is the primary key. Create assigns a random UUID, CreateTaskSession
	// uses the tool call ID, and CreateTitleSession uses "title-" + parent ID.
	ID string
	// ParentSessionID is empty for top-level sessions (stored as NULL).
	ParentSessionID string
	Title           string
	// MessageCount is read from the database; Save does not write it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is empty when there is no summary (stored as NULL).
	SummaryMessageID string
	Cost             float64
	// Todos is decoded from a JSON column; see Todo.
	Todos []Todo
	// NOTE(review): presumably Unix timestamps maintained by the database —
	// confirm units against the schema.
	CreatedAt int64
	UpdatedAt int64
}
44
// Service manages sessions and lets callers subscribe to Session events.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID with the
	// ID "title-" + parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID is
	// toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get loads a single session by ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns all sessions.
	List(ctx context.Context) ([]Session, error)
	// Save persists the mutable fields of session and returns the stored row.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage updates only the title, token counts and cost.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Delete removes a session together with its messages and files.
	Delete(ctx context.Context, id string) error

	// Agent tool session management
	// CreateAgentToolSessionID builds an ID of the form "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	// ParseAgentToolSessionID reverses CreateAgentToolSessionID.
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	// IsAgentToolSession reports whether sessionID has the agent tool format.
	IsAgentToolSession(sessionID string) bool
}
61
// service is the database-backed implementation of Service.
type service struct {
	// Broker supplies Publish, used to emit Session events.
	// NOTE(review): presumably also satisfies the pubsub.Subscriber[Session]
	// part of Service — confirm.
	*pubsub.Broker[Session]
	// db is used to open transactions (see Delete).
	db *sql.DB
	// q runs the generated queries.
	q *db.Queries
}
67
// Create inserts a new top-level session (no parent) with a random UUID and
// the given title. On success it publishes a CreatedEvent carrying the stored
// session and calls event.SessionCreated; on failure the query error is
// returned unwrapped and nothing is published.
func (s *service) Create(ctx context.Context, title string) (Session, error) {
	params := db.CreateSessionParams{
		ID:    uuid.New().String(),
		Title: title,
	}
	row, err := s.q.CreateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}

	created := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, created)
	event.SessionCreated()
	return created, nil
}
81
// CreateTaskSession creates a child session of parentSessionID whose ID is the
// tool call ID that spawned it, then publishes a CreatedEvent. Unlike Create,
// it does not call event.SessionCreated. Query errors are returned unwrapped.
func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
	parent := sql.NullString{String: parentSessionID, Valid: true}
	row, err := s.q.CreateSession(ctx, db.CreateSessionParams{
		ID:              toolCallID,
		ParentSessionID: parent,
		Title:           title,
	})
	if err != nil {
		return Session{}, err
	}

	child := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, child)
	return child, nil
}
95
// CreateTitleSession creates the child session of parentSessionID with the
// fixed ID "title-" + parentSessionID and the title "Generate a title", then
// publishes a CreatedEvent. Because the ID is derived from the parent, a
// second call for the same parent will collide with the first row (the error
// comes from the query and is returned unwrapped).
func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
	params := db.CreateSessionParams{
		ID:              "title-" + parentSessionID,
		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
		Title:           "Generate a title",
	}
	row, err := s.q.CreateSession(ctx, params)
	if err != nil {
		return Session{}, err
	}

	titled := s.fromDBItem(row)
	s.Publish(pubsub.CreatedEvent, titled)
	return titled, nil
}
109
// Delete removes the session identified by id along with its messages and
// files, all inside one transaction. If the session cannot be loaded, the
// lookup error is returned as-is; failures in later steps are wrapped with a
// description of the step. Only after a successful commit does Delete publish
// a DeletedEvent (carrying the session as it was read before deletion) and
// call event.SessionDeleted.
func (s *service) Delete(ctx context.Context, id string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("beginning transaction: %w", err)
	}
	// Deferred unconditionally: after a successful Commit, Rollback is a
	// no-op that returns sql.ErrTxDone, which is deliberately ignored.
	defer tx.Rollback() //nolint:errcheck

	txq := s.q.WithTx(tx)
	target, err := txq.GetSessionByID(ctx, id)
	if err != nil {
		return err
	}
	sessionID := target.ID

	// Dependent rows go before the session row itself.
	// NOTE(review): presumably required by foreign keys — confirm schema.
	if err := txq.DeleteSessionMessages(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session messages: %w", err)
	}
	if err := txq.DeleteSessionFiles(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session files: %w", err)
	}
	if err := txq.DeleteSession(ctx, sessionID); err != nil {
		return fmt.Errorf("deleting session: %w", err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing transaction: %w", err)
	}

	s.Publish(pubsub.DeletedEvent, s.fromDBItem(target))
	event.SessionDeleted()
	return nil
}
141
// Get loads the session with the given ID. Errors from the query are returned
// unwrapped, so callers can match them directly.
func (s *service) Get(ctx context.Context, id string) (Session, error) {
	row, err := s.q.GetSessionByID(ctx, id)
	if err == nil {
		return s.fromDBItem(row), nil
	}
	return Session{}, err
}
149
// Save writes the mutable fields of session back to the database: title,
// token counts, summary message ID, cost and todos. It publishes an
// UpdatedEvent with the stored row and returns that row. An empty
// SummaryMessageID and an empty todo list are stored as NULL. Other fields,
// such as MessageCount and ParentSessionID, are not written.
func (s *service) Save(ctx context.Context, session Session) (Session, error) {
	encodedTodos, err := marshalTodos(session.Todos)
	if err != nil {
		return Session{}, err
	}

	// Empty strings are persisted as SQL NULL.
	nullIfEmpty := func(v string) sql.NullString {
		return sql.NullString{String: v, Valid: v != ""}
	}

	row, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
		ID:               session.ID,
		Title:            session.Title,
		PromptTokens:     session.PromptTokens,
		CompletionTokens: session.CompletionTokens,
		SummaryMessageID: nullIfEmpty(session.SummaryMessageID),
		Cost:             session.Cost,
		Todos:            nullIfEmpty(encodedTodos),
	})
	if err != nil {
		return Session{}, err
	}

	saved := s.fromDBItem(row)
	s.Publish(pubsub.UpdatedEvent, saved)
	return saved, nil
}
178
// UpdateTitleAndUsage updates only the title and usage fields (prompt tokens,
// completion tokens and cost) of sessionID, using a dedicated query. This
// avoids a fetch-modify-Save round trip, which could overwrite other fields
// with stale values. The query error, if any, is returned unwrapped.
//
// NOTE(review): unlike Save, this publishes no UpdatedEvent, so subscribers
// are not told about the new title/usage — confirm that is intended.
func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
	params := db.UpdateSessionTitleAndUsageParams{
		ID:               sessionID,
		Title:            title,
		PromptTokens:     promptTokens,
		CompletionTokens: completionTokens,
		Cost:             cost,
	}
	if err := s.q.UpdateSessionTitleAndUsage(ctx, params); err != nil {
		return err
	}
	return nil
}
190
// List returns every session, in the order the ListSessions query yields
// them. On success the result is never nil, even when there are no sessions.
func (s *service) List(ctx context.Context) ([]Session, error) {
	rows, err := s.q.ListSessions(ctx)
	if err != nil {
		return nil, err
	}

	out := make([]Session, 0, len(rows))
	for _, row := range rows {
		out = append(out, s.fromDBItem(row))
	}
	return out, nil
}
202
// fromDBItem converts a database row into a Session.
//
// Nullable columns (ParentSessionID, SummaryMessageID) are read through their
// .String field, which is "" for NULL. Todos are decoded from their JSON
// column. If decoding fails, the error is logged and the session keeps
// whatever unmarshalTodos returned alongside the error, rather than failing
// the whole read.
//
// Uses a pointer receiver for consistency with every other *service method.
func (s *service) fromDBItem(item db.Session) Session {
	todos, err := unmarshalTodos(item.Todos.String)
	if err != nil {
		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
	}
	return Session{
		ID:               item.ID,
		ParentSessionID:  item.ParentSessionID.String,
		Title:            item.Title,
		MessageCount:     item.MessageCount,
		PromptTokens:     item.PromptTokens,
		CompletionTokens: item.CompletionTokens,
		SummaryMessageID: item.SummaryMessageID.String,
		Cost:             item.Cost,
		Todos:            todos,
		CreatedAt:        item.CreatedAt,
		UpdatedAt:        item.UpdatedAt,
	}
}
222
// marshalTodos encodes todos as a JSON array for storage. A nil or empty slice
// encodes to the empty string, which Save stores as NULL. Any json.Marshal
// error is returned together with "".
func marshalTodos(todos []Todo) (string, error) {
	if len(todos) > 0 {
		encoded, err := json.Marshal(todos)
		if err != nil {
			return "", err
		}
		return string(encoded), nil
	}
	return "", nil
}
233
// unmarshalTodos decodes a JSON array of todos as produced by marshalTodos.
//
// An empty string yields an empty, non-nil slice. On a decode error it also
// returns an empty, non-nil slice (never partially decoded data) along with
// the error. The JSON literal "null" decodes to a nil slice.
func unmarshalTodos(data string) ([]Todo, error) {
	empty := []Todo{}
	if len(data) == 0 {
		return empty, nil
	}

	var parsed []Todo
	err := json.Unmarshal([]byte(data), &parsed)
	if err != nil {
		return empty, err
	}
	return parsed, nil
}
244
// NewService returns a Service backed by the generated queries q and the
// database connection conn (used to open transactions). A fresh pubsub broker
// is created to deliver Session events.
func NewService(q *db.Queries, conn *sql.DB) Service {
	svc := &service{q: q, db: conn}
	svc.Broker = pubsub.NewBroker[Session]()
	return svc
}
253
254// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
255func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
256 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
257}
258
// ParseAgentToolSessionID splits an ID built by CreateAgentToolSessionID back
// into its message ID and tool call ID. It reports ok=false, with both IDs
// empty, unless sessionID contains the "$$" separator exactly once. Either
// side of the separator may be empty.
func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
	// Exactly one separator is required so the split is unambiguous.
	if strings.Count(sessionID, "$$") != 1 {
		return "", "", false
	}
	messageID, toolCallID, _ = strings.Cut(sessionID, "$$")
	return messageID, toolCallID, true
}
267
268// IsAgentToolSession checks if a session ID follows the agent tool session format
269func (s *service) IsAgentToolSession(sessionID string) bool {
270 _, _, ok := s.ParseAgentToolSessionID(sessionID)
271 return ok
272}