1package session
2
3import (
4 "context"
5 "database/sql"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/pubsub"
9 "github.com/google/uuid"
10)
11
// Session is the service-level view of a stored session record. Values are
// built from db.Session rows by fromDBItem.
type Session struct {
	ID string
	// ParentSessionID is "" when the row has no parent: only the String half
	// of the nullable column is copied. CreateTaskSession and
	// CreateTitleSession always set it.
	ParentSessionID string
	Title           string
	// MessageCount is read back from the store; Save never writes it.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is "" when unset; Save sends "" as an invalid (NULL)
	// sql.NullString.
	SummaryMessageID string
	// Cost is copied from the store as-is. NOTE(review): unit/currency is not
	// visible here — confirm.
	Cost float64
	// CreatedAt and UpdatedAt are copied from the store and never written by
	// this service; presumably Unix timestamps — confirm units.
	CreatedAt int64
	UpdatedAt int64
}
24
// Service manages session records. The implementation returned by NewService
// publishes a Created, Updated or Deleted event after each successful write,
// carrying the Session as read back from the store.
type Service interface {
	// pubsub.Suscriber (sic — the spelling comes from the pubsub package)
	// presumably lets callers subscribe to the session events described above.
	pubsub.Suscriber[Session]
	// Create stores a new session with a freshly generated UUID, the given
	// title, and no parent.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession stores a child of parentSessionID whose ID is
	// "title-" + parentSessionID and whose title is fixed.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession stores a child of parentSessionID whose ID is
	// toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given id; query errors are passed
	// through unchanged.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the rows of the ListSessions query. NOTE(review): how this
	// differs from ListAll is defined in the SQL, not here.
	List(ctx context.Context) ([]Session, error)
	// ListAll returns the rows of the ListAllSessions query.
	ListAll(ctx context.Context) ([]Session, error)
	// ListChildren returns the sessions the ListChildSessions query selects
	// for parentSessionID.
	ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
	// Save writes the title, token counts, summary message ID and cost for
	// session.ID, then returns the updated row.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given id after first loading it.
	Delete(ctx context.Context, id string) error
}
37
// service is the database-backed Service implementation built by NewService.
type service struct {
	// The embedded broker supplies Publish (called after every successful
	// write) and the subscriber side of the Service interface.
	*pubsub.Broker[Session]
	// q executes the session queries (CreateSession, GetSessionByID, ...).
	q db.Querier
}
42
43func (s *service) Create(ctx context.Context, title string) (Session, error) {
44 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
45 ID: uuid.New().String(),
46 Title: title,
47 })
48 if err != nil {
49 return Session{}, err
50 }
51 session := s.fromDBItem(dbSession)
52 s.Publish(pubsub.CreatedEvent, session)
53 return session, nil
54}
55
56func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
57 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
58 ID: toolCallID,
59 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
60 Title: title,
61 })
62 if err != nil {
63 return Session{}, err
64 }
65 session := s.fromDBItem(dbSession)
66 s.Publish(pubsub.CreatedEvent, session)
67 return session, nil
68}
69
70func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
71 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
72 ID: "title-" + parentSessionID,
73 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
74 Title: "Generate a title",
75 })
76 if err != nil {
77 return Session{}, err
78 }
79 session := s.fromDBItem(dbSession)
80 s.Publish(pubsub.CreatedEvent, session)
81 return session, nil
82}
83
84func (s *service) Delete(ctx context.Context, id string) error {
85 session, err := s.Get(ctx, id)
86 if err != nil {
87 return err
88 }
89 err = s.q.DeleteSession(ctx, session.ID)
90 if err != nil {
91 return err
92 }
93 s.Publish(pubsub.DeletedEvent, session)
94 return nil
95}
96
97func (s *service) Get(ctx context.Context, id string) (Session, error) {
98 dbSession, err := s.q.GetSessionByID(ctx, id)
99 if err != nil {
100 return Session{}, err
101 }
102 return s.fromDBItem(dbSession), nil
103}
104
105func (s *service) Save(ctx context.Context, session Session) (Session, error) {
106 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
107 ID: session.ID,
108 Title: session.Title,
109 PromptTokens: session.PromptTokens,
110 CompletionTokens: session.CompletionTokens,
111 SummaryMessageID: sql.NullString{
112 String: session.SummaryMessageID,
113 Valid: session.SummaryMessageID != "",
114 },
115 Cost: session.Cost,
116 })
117 if err != nil {
118 return Session{}, err
119 }
120 session = s.fromDBItem(dbSession)
121 s.Publish(pubsub.UpdatedEvent, session)
122 return session, nil
123}
124
125func (s *service) List(ctx context.Context) ([]Session, error) {
126 dbSessions, err := s.q.ListSessions(ctx)
127 if err != nil {
128 return nil, err
129 }
130 sessions := make([]Session, len(dbSessions))
131 for i, dbSession := range dbSessions {
132 sessions[i] = s.fromDBItem(dbSession)
133 }
134 return sessions, nil
135}
136
137func (s *service) ListAll(ctx context.Context) ([]Session, error) {
138 dbSessions, err := s.q.ListAllSessions(ctx)
139 if err != nil {
140 return nil, err
141 }
142 sessions := make([]Session, len(dbSessions))
143 for i, dbSession := range dbSessions {
144 sessions[i] = s.fromDBItem(dbSession)
145 }
146 return sessions, nil
147}
148
149func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) {
150 dbSessions, err := s.q.ListChildSessions(ctx, sql.NullString{
151 String: parentSessionID,
152 Valid: true,
153 })
154 if err != nil {
155 return nil, err
156 }
157 sessions := make([]Session, len(dbSessions))
158 for i, dbSession := range dbSessions {
159 sessions[i] = s.fromDBItem(dbSession)
160 }
161 return sessions, nil
162}
163
164func (s service) fromDBItem(item db.Session) Session {
165 return Session{
166 ID: item.ID,
167 ParentSessionID: item.ParentSessionID.String,
168 Title: item.Title,
169 MessageCount: item.MessageCount,
170 PromptTokens: item.PromptTokens,
171 CompletionTokens: item.CompletionTokens,
172 SummaryMessageID: item.SummaryMessageID.String,
173 Cost: item.Cost,
174 CreatedAt: item.CreatedAt,
175 UpdatedAt: item.UpdatedAt,
176 }
177}
178
179func NewService(q db.Querier) Service {
180 broker := pubsub.NewBroker[Session]()
181 return &service{
182 broker,
183 q,
184 }
185}