1package session
2
3import (
4 "context"
5 "database/sql"
6 "encoding/json"
7 "fmt"
8 "log/slog"
9 "strings"
10
11 "github.com/charmbracelet/crush/internal/db"
12 "github.com/charmbracelet/crush/internal/event"
13 "github.com/charmbracelet/crush/internal/pubsub"
14 "github.com/google/uuid"
15)
16
// TodoStatus is the state of a single Todo item. Values are plain strings so
// they round-trip unchanged through the JSON encoding of Todo.
type TodoStatus string

// The TodoStatus values defined by this package.
const (
	TodoStatusPending    = TodoStatus("pending")
	TodoStatusInProgress = TodoStatus("in_progress")
	TodoStatusCompleted  = TodoStatus("completed")
)
24
// Todo is one entry of a session's todo list. A session's todos are encoded
// as a JSON array of Todo values (using the tags below) when the session is
// saved, and decoded again when it is read back.
type Todo struct {
	// Content is the text of the todo item.
	Content string `json:"content"`
	// Status is the item's current TodoStatus.
	Status TodoStatus `json:"status"`
	// ActiveForm — NOTE(review): presumably an alternate phrasing of Content
	// shown while the item is in progress; confirm against the producer.
	ActiveForm string `json:"active_form"`
}
30
// Session is the in-memory view of a persisted session row.
type Session struct {
	// ID uniquely identifies the session. Create generates a random UUID;
	// task and title sessions use IDs supplied or derived by the caller.
	ID string
	// ParentSessionID is empty for top-level sessions (those made by Create).
	ParentSessionID string
	Title           string
	// MessageCount is read from the database; Save does not write it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is empty when the session has no summary message;
	// Save stores an empty value as NULL.
	SummaryMessageID string
	// Cost — NOTE(review): currency/units are not established here; confirm.
	Cost float64
	// Todos is decoded from the stored JSON; it is an empty slice when none
	// are stored or when the stored payload could not be decoded.
	Todos []Todo
	// CreatedAt and UpdatedAt are copied verbatim from the database.
	// NOTE(review): presumably Unix timestamps — confirm the unit.
	CreatedAt int64
	UpdatedAt int64
}
44
// Service creates, reads, updates and deletes sessions. It embeds a
// pubsub.Subscriber so callers can subscribe to the Session events the
// implementation publishes on create, save and delete.
type Service interface {
	pubsub.Subscriber[Session]
	// Create starts a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates a child session of parentSessionID whose ID
	// is derived from the parent's ID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID
	// is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get loads a single session by ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns all stored sessions.
	List(ctx context.Context) ([]Session, error)
	// Save persists the mutable fields of session and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error

	// Agent tool session management
	// Agent tool session IDs have the form "messageID$$toolCallID".
	CreateAgentToolSessionID(messageID, toolCallID string) string
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	IsAgentToolSession(sessionID string) bool
}
60
// service is the Service implementation backed by a db.Querier.
type service struct {
	// Broker publishes Session events (created, updated, deleted) to subscribers.
	*pubsub.Broker[Session]
	// q executes the session queries.
	q db.Querier
}
65
66func (s *service) Create(ctx context.Context, title string) (Session, error) {
67 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
68 ID: uuid.New().String(),
69 Title: title,
70 })
71 if err != nil {
72 return Session{}, err
73 }
74 session := s.fromDBItem(dbSession)
75 s.Publish(pubsub.CreatedEvent, session)
76 event.SessionCreated()
77 return session, nil
78}
79
80func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
81 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
82 ID: toolCallID,
83 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
84 Title: title,
85 })
86 if err != nil {
87 return Session{}, err
88 }
89 session := s.fromDBItem(dbSession)
90 s.Publish(pubsub.CreatedEvent, session)
91 return session, nil
92}
93
94func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
95 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
96 ID: "title-" + parentSessionID,
97 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
98 Title: "Generate a title",
99 })
100 if err != nil {
101 return Session{}, err
102 }
103 session := s.fromDBItem(dbSession)
104 s.Publish(pubsub.CreatedEvent, session)
105 return session, nil
106}
107
108func (s *service) Delete(ctx context.Context, id string) error {
109 session, err := s.Get(ctx, id)
110 if err != nil {
111 return err
112 }
113 err = s.q.DeleteSession(ctx, session.ID)
114 if err != nil {
115 return err
116 }
117 s.Publish(pubsub.DeletedEvent, session)
118 event.SessionDeleted()
119 return nil
120}
121
122func (s *service) Get(ctx context.Context, id string) (Session, error) {
123 dbSession, err := s.q.GetSessionByID(ctx, id)
124 if err != nil {
125 return Session{}, err
126 }
127 return s.fromDBItem(dbSession), nil
128}
129
130func (s *service) Save(ctx context.Context, session Session) (Session, error) {
131 todosJSON, err := marshalTodos(session.Todos)
132 if err != nil {
133 return Session{}, err
134 }
135
136 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
137 ID: session.ID,
138 Title: session.Title,
139 PromptTokens: session.PromptTokens,
140 CompletionTokens: session.CompletionTokens,
141 SummaryMessageID: sql.NullString{
142 String: session.SummaryMessageID,
143 Valid: session.SummaryMessageID != "",
144 },
145 Cost: session.Cost,
146 Todos: sql.NullString{
147 String: todosJSON,
148 Valid: todosJSON != "",
149 },
150 })
151 if err != nil {
152 return Session{}, err
153 }
154 session = s.fromDBItem(dbSession)
155 s.Publish(pubsub.UpdatedEvent, session)
156 return session, nil
157}
158
159func (s *service) List(ctx context.Context) ([]Session, error) {
160 dbSessions, err := s.q.ListSessions(ctx)
161 if err != nil {
162 return nil, err
163 }
164 sessions := make([]Session, len(dbSessions))
165 for i, dbSession := range dbSessions {
166 sessions[i] = s.fromDBItem(dbSession)
167 }
168 return sessions, nil
169}
170
171func (s service) fromDBItem(item db.Session) Session {
172 todos, err := unmarshalTodos(item.Todos.String)
173 if err != nil {
174 slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
175 }
176 return Session{
177 ID: item.ID,
178 ParentSessionID: item.ParentSessionID.String,
179 Title: item.Title,
180 MessageCount: item.MessageCount,
181 PromptTokens: item.PromptTokens,
182 CompletionTokens: item.CompletionTokens,
183 SummaryMessageID: item.SummaryMessageID.String,
184 Cost: item.Cost,
185 Todos: todos,
186 CreatedAt: item.CreatedAt,
187 UpdatedAt: item.UpdatedAt,
188 }
189}
190
191func marshalTodos(todos []Todo) (string, error) {
192 if len(todos) == 0 {
193 return "", nil
194 }
195 data, err := json.Marshal(todos)
196 if err != nil {
197 return "", err
198 }
199 return string(data), nil
200}
201
202func unmarshalTodos(data string) ([]Todo, error) {
203 if data == "" {
204 return []Todo{}, nil
205 }
206 var todos []Todo
207 if err := json.Unmarshal([]byte(data), &todos); err != nil {
208 return []Todo{}, err
209 }
210 return todos, nil
211}
212
213func NewService(q db.Querier) Service {
214 broker := pubsub.NewBroker[Session]()
215 return &service{
216 broker,
217 q,
218 }
219}
220
221// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
222func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
223 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
224}
225
226// ParseAgentToolSessionID parses an agent tool session ID into its components
227func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
228 parts := strings.Split(sessionID, "$$")
229 if len(parts) != 2 {
230 return "", "", false
231 }
232 return parts[0], parts[1], true
233}
234
235// IsAgentToolSession checks if a session ID follows the agent tool session format
236func (s *service) IsAgentToolSession(sessionID string) bool {
237 _, _, ok := s.ParseAgentToolSessionID(sessionID)
238 return ok
239}