1package session
2
3import (
4 "context"
5 "database/sql"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/pubsub"
9 "github.com/google/uuid"
10)
11
// Session is the service-level view of a stored chat session.
//
// Nullable columns of the underlying row (parent session, summary message)
// are flattened to plain strings; a NULL column surfaces as "".
type Session struct {
	ID              string
	ParentSessionID string // "" when the row has no parent session
	Title           string

	// Counters copied verbatim from the stored row.
	MessageCount, PromptTokens, CompletionTokens int64

	SummaryMessageID string // "" when the row has no summary message
	Cost             float64

	// Row timestamps as stored; units are defined by the database layer.
	CreatedAt, UpdatedAt int64
}
24
25type Service interface {
26 pubsub.Suscriber[Session]
27 Create(ctx context.Context, title string) (Session, error)
28 CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
29 CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
30 Get(ctx context.Context, id string) (Session, error)
31 List(ctx context.Context) ([]Session, error)
32 Save(ctx context.Context, session Session) (Session, error)
33 Delete(ctx context.Context, id string) error
34}
35
36type service struct {
37 *pubsub.Broker[Session]
38 q db.Querier
39}
40
41func (s *service) Create(ctx context.Context, title string) (Session, error) {
42 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
43 ID: uuid.New().String(),
44 Title: title,
45 })
46 if err != nil {
47 return Session{}, err
48 }
49 session := s.fromDBItem(dbSession)
50 s.Publish(pubsub.CreatedEvent, session)
51 return session, nil
52}
53
54func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
55 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
56 ID: toolCallID,
57 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
58 Title: title,
59 })
60 if err != nil {
61 return Session{}, err
62 }
63 session := s.fromDBItem(dbSession)
64 s.Publish(pubsub.CreatedEvent, session)
65 return session, nil
66}
67
68func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
69 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
70 ID: "title-" + parentSessionID,
71 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
72 Title: "Generate a title",
73 })
74 if err != nil {
75 return Session{}, err
76 }
77 session := s.fromDBItem(dbSession)
78 s.Publish(pubsub.CreatedEvent, session)
79 return session, nil
80}
81
82func (s *service) Delete(ctx context.Context, id string) error {
83 session, err := s.Get(ctx, id)
84 if err != nil {
85 return err
86 }
87 err = s.q.DeleteSession(ctx, session.ID)
88 if err != nil {
89 return err
90 }
91 s.Publish(pubsub.DeletedEvent, session)
92 return nil
93}
94
95func (s *service) Get(ctx context.Context, id string) (Session, error) {
96 dbSession, err := s.q.GetSessionByID(ctx, id)
97 if err != nil {
98 return Session{}, err
99 }
100 return s.fromDBItem(dbSession), nil
101}
102
103func (s *service) Save(ctx context.Context, session Session) (Session, error) {
104 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
105 ID: session.ID,
106 Title: session.Title,
107 PromptTokens: session.PromptTokens,
108 CompletionTokens: session.CompletionTokens,
109 SummaryMessageID: sql.NullString{
110 String: session.SummaryMessageID,
111 Valid: session.SummaryMessageID != "",
112 },
113 Cost: session.Cost,
114 })
115 if err != nil {
116 return Session{}, err
117 }
118 session = s.fromDBItem(dbSession)
119 s.Publish(pubsub.UpdatedEvent, session)
120 return session, nil
121}
122
123func (s *service) List(ctx context.Context) ([]Session, error) {
124 dbSessions, err := s.q.ListSessions(ctx)
125 if err != nil {
126 return nil, err
127 }
128 sessions := make([]Session, len(dbSessions))
129 for i, dbSession := range dbSessions {
130 sessions[i] = s.fromDBItem(dbSession)
131 }
132 return sessions, nil
133}
134
135func (s service) fromDBItem(item db.Session) Session {
136 return Session{
137 ID: item.ID,
138 ParentSessionID: item.ParentSessionID.String,
139 Title: item.Title,
140 MessageCount: item.MessageCount,
141 PromptTokens: item.PromptTokens,
142 CompletionTokens: item.CompletionTokens,
143 SummaryMessageID: item.SummaryMessageID.String,
144 Cost: item.Cost,
145 CreatedAt: item.CreatedAt,
146 UpdatedAt: item.UpdatedAt,
147 }
148}
149
150func NewService(q db.Querier) Service {
151 broker := pubsub.NewBroker[Session]()
152 return &service{
153 broker,
154 q,
155 }
156}