1package session
2
3import (
4 "context"
5 "database/sql"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/event"
9 "github.com/charmbracelet/crush/internal/pubsub"
10 "github.com/google/uuid"
11)
12
// Session is the service-level representation of a stored chat session. It
// is built from a db.Session row by (*service).fromDBItem.
type Session struct {
	// ParentSessionID is "" for sessions made by Service.Create; task and
	// title sessions carry the ID of the session that spawned them.
	ID, ParentSessionID, Title string

	// Message and token counters as read from the database row.
	MessageCount, PromptTokens, CompletionTokens int64

	// SummaryMessageID is "" when no summary message is recorded; Save
	// writes an empty value back as SQL NULL.
	SummaryMessageID string
	// Cost is stored verbatim; its unit is not visible here — confirm.
	Cost float64

	// Timestamps copied from the row. NOTE(review): presumably Unix time,
	// but the unit is not established in this file — confirm against the
	// schema.
	CreatedAt, UpdatedAt int64
}
25
26type Service interface {
27 pubsub.Suscriber[Session]
28 Create(ctx context.Context, title string) (Session, error)
29 CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
30 CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
31 Get(ctx context.Context, id string) (Session, error)
32 List(ctx context.Context) ([]Session, error)
33 Save(ctx context.Context, session Session) (Session, error)
34 Delete(ctx context.Context, id string) error
35}
36
// service is the concrete Service returned by NewService.
type service struct {
	// The embedded broker supplies Publish, which the methods below use to
	// announce session changes. It presumably also provides the
	// pubsub.Suscriber methods that Service requires — confirm in pubsub.
	*pubsub.Broker[Session]
	// q performs all session persistence (create/get/list/update/delete).
	q db.Querier
}
41
42func (s *service) Create(ctx context.Context, title string) (Session, error) {
43 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
44 ID: uuid.New().String(),
45 Title: title,
46 })
47 if err != nil {
48 return Session{}, err
49 }
50 session := s.fromDBItem(dbSession)
51 s.Publish(pubsub.CreatedEvent, session)
52 event.SessionCreated()
53 return session, nil
54}
55
56func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
57 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
58 ID: toolCallID,
59 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
60 Title: title,
61 })
62 if err != nil {
63 return Session{}, err
64 }
65 session := s.fromDBItem(dbSession)
66 s.Publish(pubsub.CreatedEvent, session)
67 return session, nil
68}
69
70func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
71 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
72 ID: "title-" + parentSessionID,
73 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
74 Title: "Generate a title",
75 })
76 if err != nil {
77 return Session{}, err
78 }
79 session := s.fromDBItem(dbSession)
80 s.Publish(pubsub.CreatedEvent, session)
81 return session, nil
82}
83
84func (s *service) Delete(ctx context.Context, id string) error {
85 session, err := s.Get(ctx, id)
86 if err != nil {
87 return err
88 }
89 err = s.q.DeleteSession(ctx, session.ID)
90 if err != nil {
91 return err
92 }
93 s.Publish(pubsub.DeletedEvent, session)
94 event.SessionDeleted()
95 return nil
96}
97
98func (s *service) Get(ctx context.Context, id string) (Session, error) {
99 dbSession, err := s.q.GetSessionByID(ctx, id)
100 if err != nil {
101 return Session{}, err
102 }
103 return s.fromDBItem(dbSession), nil
104}
105
106func (s *service) Save(ctx context.Context, session Session) (Session, error) {
107 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
108 ID: session.ID,
109 Title: session.Title,
110 PromptTokens: session.PromptTokens,
111 CompletionTokens: session.CompletionTokens,
112 SummaryMessageID: sql.NullString{
113 String: session.SummaryMessageID,
114 Valid: session.SummaryMessageID != "",
115 },
116 Cost: session.Cost,
117 })
118 if err != nil {
119 return Session{}, err
120 }
121 session = s.fromDBItem(dbSession)
122 s.Publish(pubsub.UpdatedEvent, session)
123 return session, nil
124}
125
126func (s *service) List(ctx context.Context) ([]Session, error) {
127 dbSessions, err := s.q.ListSessions(ctx)
128 if err != nil {
129 return nil, err
130 }
131 sessions := make([]Session, len(dbSessions))
132 for i, dbSession := range dbSessions {
133 sessions[i] = s.fromDBItem(dbSession)
134 }
135 return sessions, nil
136}
137
138func (s service) fromDBItem(item db.Session) Session {
139 return Session{
140 ID: item.ID,
141 ParentSessionID: item.ParentSessionID.String,
142 Title: item.Title,
143 MessageCount: item.MessageCount,
144 PromptTokens: item.PromptTokens,
145 CompletionTokens: item.CompletionTokens,
146 SummaryMessageID: item.SummaryMessageID.String,
147 Cost: item.Cost,
148 CreatedAt: item.CreatedAt,
149 UpdatedAt: item.UpdatedAt,
150 }
151}
152
153func NewService(q db.Querier) Service {
154 broker := pubsub.NewBroker[Session]()
155 return &service{
156 broker,
157 q,
158 }
159}