1package session
2
3import (
4 "context"
5 "database/sql"
6
7 "github.com/charmbracelet/crush/internal/db"
8 "github.com/charmbracelet/crush/internal/event"
9 "github.com/charmbracelet/crush/internal/proto"
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/google/uuid"
12)
13
// Session is an alias for proto.Session, so values of the two types are
// interchangeable without conversion.
type Session = proto.Session
15
// Service is the session store API. It embeds the pubsub subscriber
// interface for Session values in addition to the CRUD-style operations
// declared below.
type Service interface {
	pubsub.Suscriber[Session]
	// Create stores a new session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession stores a child session of parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession stores a child session of parentSessionID whose ID
	// is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the stored sessions.
	List(ctx context.Context) ([]Session, error)
	// Save writes the session's mutable fields and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error
}
26
// service is the Service implementation in this file. It embeds a
// pubsub broker for Session values (used via s.Publish) and holds the
// database querier every operation goes through.
type service struct {
	*pubsub.Broker[Session]
	// q is the query interface used for all session reads and writes.
	q db.Querier
}
31
32func (s *service) Create(ctx context.Context, title string) (Session, error) {
33 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
34 ID: uuid.New().String(),
35 Title: title,
36 })
37 if err != nil {
38 return Session{}, err
39 }
40 session := s.fromDBItem(dbSession)
41 s.Publish(pubsub.CreatedEvent, session)
42 event.SessionCreated()
43 return session, nil
44}
45
46func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
47 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
48 ID: toolCallID,
49 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
50 Title: title,
51 })
52 if err != nil {
53 return Session{}, err
54 }
55 session := s.fromDBItem(dbSession)
56 s.Publish(pubsub.CreatedEvent, session)
57 return session, nil
58}
59
60func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
61 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
62 ID: "title-" + parentSessionID,
63 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
64 Title: "Generate a title",
65 })
66 if err != nil {
67 return Session{}, err
68 }
69 session := s.fromDBItem(dbSession)
70 s.Publish(pubsub.CreatedEvent, session)
71 return session, nil
72}
73
74func (s *service) Delete(ctx context.Context, id string) error {
75 session, err := s.Get(ctx, id)
76 if err != nil {
77 return err
78 }
79 err = s.q.DeleteSession(ctx, session.ID)
80 if err != nil {
81 return err
82 }
83 s.Publish(pubsub.DeletedEvent, session)
84 event.SessionDeleted()
85 return nil
86}
87
88func (s *service) Get(ctx context.Context, id string) (Session, error) {
89 dbSession, err := s.q.GetSessionByID(ctx, id)
90 if err != nil {
91 return Session{}, err
92 }
93 return s.fromDBItem(dbSession), nil
94}
95
96func (s *service) Save(ctx context.Context, session Session) (Session, error) {
97 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
98 ID: session.ID,
99 Title: session.Title,
100 PromptTokens: session.PromptTokens,
101 CompletionTokens: session.CompletionTokens,
102 SummaryMessageID: sql.NullString{
103 String: session.SummaryMessageID,
104 Valid: session.SummaryMessageID != "",
105 },
106 Cost: session.Cost,
107 })
108 if err != nil {
109 return Session{}, err
110 }
111 session = s.fromDBItem(dbSession)
112 s.Publish(pubsub.UpdatedEvent, session)
113 return session, nil
114}
115
116func (s *service) List(ctx context.Context) ([]Session, error) {
117 dbSessions, err := s.q.ListSessions(ctx)
118 if err != nil {
119 return nil, err
120 }
121 sessions := make([]Session, len(dbSessions))
122 for i, dbSession := range dbSessions {
123 sessions[i] = s.fromDBItem(dbSession)
124 }
125 return sessions, nil
126}
127
128func (s service) fromDBItem(item db.Session) Session {
129 return Session{
130 ID: item.ID,
131 ParentSessionID: item.ParentSessionID.String,
132 Title: item.Title,
133 MessageCount: item.MessageCount,
134 PromptTokens: item.PromptTokens,
135 CompletionTokens: item.CompletionTokens,
136 SummaryMessageID: item.SummaryMessageID.String,
137 Cost: item.Cost,
138 CreatedAt: item.CreatedAt,
139 UpdatedAt: item.UpdatedAt,
140 }
141}
142
143func NewService(q db.Querier) Service {
144 broker := pubsub.NewBroker[Session]()
145 return &service{
146 broker,
147 q,
148 }
149}