1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15)
16
// TodoStatus is the progress state of a Todo item. Being a plain string
// type, it is written verbatim into a session's serialized todo list.
type TodoStatus string

// TodoStatus values defined by this package.
const (
	TodoStatusPending    TodoStatus = "pending"
	TodoStatusInProgress TodoStatus = "in_progress"
	TodoStatusCompleted  TodoStatus = "completed"
)

// Todo is one entry of a session's todo list. A session's todos are
// persisted as a JSON array of Todo values (see marshalTodos), so the json
// tags below are the stored field names and must stay stable.
//
// NOTE(review): ActiveForm looks like an alternate phrasing of Content meant
// for display while the item is in progress — confirm with callers.
type Todo struct {
	Content    string     `json:"content"`
	Status     TodoStatus `json:"status"`
	ActiveForm string     `json:"active_form"`
}
30
// Session is the in-memory form of a stored session row, as built by
// fromDBItem.
//
// Field notes:
//   - ParentSessionID is "" when the stored parent column is NULL.
//   - SummaryMessageID is "" when no summary message is set; Save writes ""
//     as NULL.
//   - MessageCount is read from storage but is never written by Save.
//   - Todos is persisted as a JSON array; Save writes an empty list as NULL.
//   - CreatedAt and UpdatedAt are copied straight from the row.
//     NOTE(review): their unit (presumably Unix seconds) is not visible
//     here — confirm against the schema.
type Session struct {
	ID               string
	ParentSessionID  string
	Title            string
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	SummaryMessageID string
	Cost             float64
	Todos            []Todo
	CreatedAt        int64
	UpdatedAt        int64
}
44
// Service manages stored sessions. Through the embedded pubsub.Subscriber,
// callers can receive the Session events the implementation publishes when
// sessions are created, saved or deleted.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates the child session of parentSessionID that is
	// used to generate that parent's title.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID is
	// toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the stored sessions.
	List(ctx context.Context) ([]Session, error)
	// Save writes the session's title, usage, cost, summary message ID and
	// todos, and returns the session as stored.
	Save(ctx context.Context, session Session) (Session, error)
	// UpdateTitleAndUsage updates only the title and token/cost usage of a
	// session, leaving its other fields untouched.
	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error

	// Agent tool session management. Agent tool session IDs have the form
	// "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	IsAgentToolSession(sessionID string) bool
}
61
// service is the database-backed implementation of Service.
type service struct {
	// Broker publishes Session events to subscribers; embedding it provides
	// both Publish and the pubsub.Subscriber methods.
	*pubsub.Broker[Session]
	// q runs the session queries against the database.
	q db.Querier
}
66
67func (s *service) Create(ctx context.Context, title string) (Session, error) {
68 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
69 ID: uuid.New().String(),
70 Title: title,
71 })
72 if err != nil {
73 return Session{}, err
74 }
75 session := s.fromDBItem(dbSession)
76 s.Publish(pubsub.CreatedEvent, session)
77 event.SessionCreated()
78 return session, nil
79}
80
81func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
82 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
83 ID: toolCallID,
84 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
85 Title: title,
86 })
87 if err != nil {
88 return Session{}, err
89 }
90 session := s.fromDBItem(dbSession)
91 s.Publish(pubsub.CreatedEvent, session)
92 return session, nil
93}
94
95func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
96 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
97 ID: "title-" + parentSessionID,
98 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
99 Title: "Generate a title",
100 })
101 if err != nil {
102 return Session{}, err
103 }
104 session := s.fromDBItem(dbSession)
105 s.Publish(pubsub.CreatedEvent, session)
106 return session, nil
107}
108
109func (s *service) Delete(ctx context.Context, id string) error {
110 session, err := s.Get(ctx, id)
111 if err != nil {
112 return err
113 }
114 err = s.q.DeleteSession(ctx, session.ID)
115 if err != nil {
116 return err
117 }
118 s.Publish(pubsub.DeletedEvent, session)
119 event.SessionDeleted()
120 return nil
121}
122
123func (s *service) Get(ctx context.Context, id string) (Session, error) {
124 dbSession, err := s.q.GetSessionByID(ctx, id)
125 if err != nil {
126 return Session{}, err
127 }
128 return s.fromDBItem(dbSession), nil
129}
130
131func (s *service) Save(ctx context.Context, session Session) (Session, error) {
132 todosJSON, err := marshalTodos(session.Todos)
133 if err != nil {
134 return Session{}, err
135 }
136
137 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
138 ID: session.ID,
139 Title: session.Title,
140 PromptTokens: session.PromptTokens,
141 CompletionTokens: session.CompletionTokens,
142 SummaryMessageID: sql.NullString{
143 String: session.SummaryMessageID,
144 Valid: session.SummaryMessageID != "",
145 },
146 Cost: session.Cost,
147 Todos: sql.NullString{
148 String: todosJSON,
149 Valid: todosJSON != "",
150 },
151 })
152 if err != nil {
153 return Session{}, err
154 }
155 session = s.fromDBItem(dbSession)
156 s.Publish(pubsub.UpdatedEvent, session)
157 return session, nil
158}
159
160// UpdateTitleAndUsage updates only the title and usage fields atomically.
161// This is safer than fetching, modifying, and saving the entire session.
162func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
163 return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
164 ID: sessionID,
165 Title: title,
166 PromptTokens: promptTokens,
167 CompletionTokens: completionTokens,
168 Cost: cost,
169 })
170}
171
172func (s *service) List(ctx context.Context) ([]Session, error) {
173 dbSessions, err := s.q.ListSessions(ctx)
174 if err != nil {
175 return nil, err
176 }
177 sessions := make([]Session, len(dbSessions))
178 for i, dbSession := range dbSessions {
179 sessions[i] = s.fromDBItem(dbSession)
180 }
181 return sessions, nil
182}
183
184func (s service) fromDBItem(item db.Session) Session {
185 todos, err := unmarshalTodos(item.Todos.String)
186 if err != nil {
187 slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
188 }
189 return Session{
190 ID: item.ID,
191 ParentSessionID: item.ParentSessionID.String,
192 Title: item.Title,
193 MessageCount: item.MessageCount,
194 PromptTokens: item.PromptTokens,
195 CompletionTokens: item.CompletionTokens,
196 SummaryMessageID: item.SummaryMessageID.String,
197 Cost: item.Cost,
198 Todos: todos,
199 CreatedAt: item.CreatedAt,
200 UpdatedAt: item.UpdatedAt,
201 }
202}
203
204func marshalTodos(todos []Todo) (string, error) {
205 if len(todos) == 0 {
206 return "", nil
207 }
208 data, err := json.Marshal(todos)
209 if err != nil {
210 return "", err
211 }
212 return string(data), nil
213}
214
215func unmarshalTodos(data string) ([]Todo, error) {
216 if data == "" {
217 return []Todo{}, nil
218 }
219 var todos []Todo
220 if err := json.Unmarshal([]byte(data), &todos); err != nil {
221 return []Todo{}, err
222 }
223 return todos, nil
224}
225
226func NewService(q db.Querier) Service {
227 broker := pubsub.NewBroker[Session]()
228 return &service{
229 broker,
230 q,
231 }
232}
233
234// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
235func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
236 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
237}
238
239// ParseAgentToolSessionID parses an agent tool session ID into its components
240func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
241 parts := strings.Split(sessionID, "$$")
242 if len(parts) != 2 {
243 return "", "", false
244 }
245 return parts[0], parts[1], true
246}
247
248// IsAgentToolSession checks if a session ID follows the agent tool session format
249func (s *service) IsAgentToolSession(sessionID string) bool {
250 _, _, ok := s.ParseAgentToolSessionID(sessionID)
251 return ok
252}