1package session
2
3import (
4 "context"
5 "database/sql"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/pubsub"
9 "github.com/google/uuid"
10)
11
// Session is the application-level view of a stored session row. Nullable
// database columns are flattened to plain strings by fromDBItem, so a NULL
// value shows up here as "".
type Session struct {
	// ID identifies the session. Create uses a random UUID, CreateTaskSession
	// uses the tool call ID, and CreateTitleSession uses "title-" + parent ID.
	ID string
	// ParentSessionID is the owning session's ID, or "" for a session created
	// without a parent (e.g. via Create).
	ParentSessionID string
	Title           string
	// MessageCount is read back from storage; Save does not write it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID references a summary message, or "" when there is
	// none (Save stores "" as SQL NULL).
	SummaryMessageID string
	// Cost is the session cost as stored; NOTE(review): the unit/currency is
	// not established in this file — confirm.
	Cost float64
	// CreatedAt and UpdatedAt are int64 timestamps copied from the database
	// row; presumably Unix epoch values — TODO confirm the unit.
	CreatedAt int64
	UpdatedAt int64
}
24
25type Service interface {
26 pubsub.Suscriber[Session]
27 Create(ctx context.Context, title string) (Session, error)
28 CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
29 CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
30 Get(ctx context.Context, id string) (Session, error)
31 List(ctx context.Context) ([]Session, error)
32 ListAll(ctx context.Context) ([]Session, error)
33 ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
34 Save(ctx context.Context, session Session) (Session, error)
35 Delete(ctx context.Context, id string) error
36 SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error)
37 SearchByText(ctx context.Context, textPattern string) ([]Session, error)
38 SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error)
39}
40
// service implements Service on top of a db.Querier and broadcasts Session
// changes through the embedded broker.
type service struct {
	// Embedded broker: its Publish method is called after successful create,
	// save and delete operations. NOTE(review): presumably it also supplies the
	// pubsub.Suscriber methods that Service requires — confirm in pubsub.
	*pubsub.Broker[Session]
	// q executes the session queries (CreateSession, GetSessionByID,
	// UpdateSession, DeleteSession, the list and search queries).
	q db.Querier
}
45
46func (s *service) Create(ctx context.Context, title string) (Session, error) {
47 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
48 ID: uuid.New().String(),
49 Title: title,
50 })
51 if err != nil {
52 return Session{}, err
53 }
54 session := s.fromDBItem(dbSession)
55 s.Publish(pubsub.CreatedEvent, session)
56 return session, nil
57}
58
59func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
60 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
61 ID: toolCallID,
62 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
63 Title: title,
64 })
65 if err != nil {
66 return Session{}, err
67 }
68 session := s.fromDBItem(dbSession)
69 s.Publish(pubsub.CreatedEvent, session)
70 return session, nil
71}
72
73func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
74 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
75 ID: "title-" + parentSessionID,
76 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
77 Title: "Generate a title",
78 })
79 if err != nil {
80 return Session{}, err
81 }
82 session := s.fromDBItem(dbSession)
83 s.Publish(pubsub.CreatedEvent, session)
84 return session, nil
85}
86
87func (s *service) Delete(ctx context.Context, id string) error {
88 session, err := s.Get(ctx, id)
89 if err != nil {
90 return err
91 }
92 err = s.q.DeleteSession(ctx, session.ID)
93 if err != nil {
94 return err
95 }
96 s.Publish(pubsub.DeletedEvent, session)
97 return nil
98}
99
100func (s *service) Get(ctx context.Context, id string) (Session, error) {
101 dbSession, err := s.q.GetSessionByID(ctx, id)
102 if err != nil {
103 return Session{}, err
104 }
105 return s.fromDBItem(dbSession), nil
106}
107
108func (s *service) Save(ctx context.Context, session Session) (Session, error) {
109 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
110 ID: session.ID,
111 Title: session.Title,
112 PromptTokens: session.PromptTokens,
113 CompletionTokens: session.CompletionTokens,
114 SummaryMessageID: sql.NullString{
115 String: session.SummaryMessageID,
116 Valid: session.SummaryMessageID != "",
117 },
118 Cost: session.Cost,
119 })
120 if err != nil {
121 return Session{}, err
122 }
123 session = s.fromDBItem(dbSession)
124 s.Publish(pubsub.UpdatedEvent, session)
125 return session, nil
126}
127
128func (s *service) List(ctx context.Context) ([]Session, error) {
129 dbSessions, err := s.q.ListSessions(ctx)
130 if err != nil {
131 return nil, err
132 }
133 sessions := make([]Session, len(dbSessions))
134 for i, dbSession := range dbSessions {
135 sessions[i] = s.fromDBItem(dbSession)
136 }
137 return sessions, nil
138}
139
140func (s *service) ListAll(ctx context.Context) ([]Session, error) {
141 dbSessions, err := s.q.ListAllSessions(ctx)
142 if err != nil {
143 return nil, err
144 }
145 sessions := make([]Session, len(dbSessions))
146 for i, dbSession := range dbSessions {
147 sessions[i] = s.fromDBItem(dbSession)
148 }
149 return sessions, nil
150}
151
152func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) {
153 dbSessions, err := s.q.ListChildSessions(ctx, sql.NullString{
154 String: parentSessionID,
155 Valid: true,
156 })
157 if err != nil {
158 return nil, err
159 }
160 sessions := make([]Session, len(dbSessions))
161 for i, dbSession := range dbSessions {
162 sessions[i] = s.fromDBItem(dbSession)
163 }
164 return sessions, nil
165}
166
167func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) {
168 dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%")
169 if err != nil {
170 return nil, err
171 }
172 sessions := make([]Session, len(dbSessions))
173 for i, dbSession := range dbSessions {
174 sessions[i] = s.fromDBItem(dbSession)
175 }
176 return sessions, nil
177}
178
179func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) {
180 dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%")
181 if err != nil {
182 return nil, err
183 }
184 sessions := make([]Session, len(dbSessions))
185 for i, dbSession := range dbSessions {
186 sessions[i] = s.fromDBItem(dbSession)
187 }
188 return sessions, nil
189}
190
191func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) {
192 dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{
193 Title: "%" + titlePattern + "%",
194 Parts: "%" + textPattern + "%",
195 })
196 if err != nil {
197 return nil, err
198 }
199 sessions := make([]Session, len(dbSessions))
200 for i, dbSession := range dbSessions {
201 sessions[i] = s.fromDBItem(dbSession)
202 }
203 return sessions, nil
204}
205
206func (s service) fromDBItem(item db.Session) Session {
207 return Session{
208 ID: item.ID,
209 ParentSessionID: item.ParentSessionID.String,
210 Title: item.Title,
211 MessageCount: item.MessageCount,
212 PromptTokens: item.PromptTokens,
213 CompletionTokens: item.CompletionTokens,
214 SummaryMessageID: item.SummaryMessageID.String,
215 Cost: item.Cost,
216 CreatedAt: item.CreatedAt,
217 UpdatedAt: item.UpdatedAt,
218 }
219}
220
221func NewService(q db.Querier) Service {
222 broker := pubsub.NewBroker[Session]()
223 return &service{
224 broker,
225 q,
226 }
227}