1package session
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8
9 "github.com/charmbracelet/crush/internal/db"
10 "github.com/charmbracelet/crush/internal/event"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/google/uuid"
13)
14
// Session is the service-level representation of a stored session row, as
// produced by fromDBItem.
type Session struct {
	ID string

	// ParentSessionID is empty for sessions without a parent (a NULL column
	// is converted to "").
	ParentSessionID string
	Title           string

	// MessageCount is read back from storage; Save does not send it in the
	// update.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64

	// SummaryMessageID is empty when no summary message is set; Save stores
	// an empty value as NULL.
	SummaryMessageID string

	// NOTE(review): the unit of Cost is not visible here (presumably a
	// currency amount) — confirm.
	Cost float64

	// NOTE(review): presumably Unix timestamps set by storage; confirm the
	// unit against the db schema.
	CreatedAt int64
	UpdatedAt int64
}
27
// Service manages persisted sessions and lets callers subscribe to Session
// change events through the embedded pubsub.Suscriber.
type Service interface {
	pubsub.Suscriber[Session]
	// Create persists a new top-level session with the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession creates the child session of parentSessionID that is
	// used for title generation.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession creates a child session of parentSessionID whose ID
	// is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get loads a single session by ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns all stored sessions.
	List(ctx context.Context) ([]Session, error)
	// Save persists the session's mutable fields and returns the stored
	// result.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error

	// Agent tool session management: agent tool session IDs combine a
	// message ID and a tool call ID into a single string.
	CreateAgentToolSessionID(messageID, toolCallID string) string
	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
	IsAgentToolSession(sessionID string) bool
}
43
// service is the db-backed implementation of Service returned by NewService.
type service struct {
	// The embedded broker's Publish method emits Session events.
	// NOTE(review): presumably it also supplies the subscribe side of
	// pubsub.Suscriber — confirm in the pubsub package.
	*pubsub.Broker[Session]
	// q performs all session persistence queries.
	q db.Querier
}
48
49func (s *service) Create(ctx context.Context, title string) (Session, error) {
50 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
51 ID: uuid.New().String(),
52 Title: title,
53 })
54 if err != nil {
55 return Session{}, err
56 }
57 session := s.fromDBItem(dbSession)
58 s.Publish(pubsub.CreatedEvent, session)
59 event.SessionCreated()
60 return session, nil
61}
62
63func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
64 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
65 ID: toolCallID,
66 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
67 Title: title,
68 })
69 if err != nil {
70 return Session{}, err
71 }
72 session := s.fromDBItem(dbSession)
73 s.Publish(pubsub.CreatedEvent, session)
74 return session, nil
75}
76
77func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
78 dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
79 ID: "title-" + parentSessionID,
80 ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
81 Title: "Generate a title",
82 })
83 if err != nil {
84 return Session{}, err
85 }
86 session := s.fromDBItem(dbSession)
87 s.Publish(pubsub.CreatedEvent, session)
88 return session, nil
89}
90
91func (s *service) Delete(ctx context.Context, id string) error {
92 session, err := s.Get(ctx, id)
93 if err != nil {
94 return err
95 }
96 err = s.q.DeleteSession(ctx, session.ID)
97 if err != nil {
98 return err
99 }
100 s.Publish(pubsub.DeletedEvent, session)
101 event.SessionDeleted()
102 return nil
103}
104
105func (s *service) Get(ctx context.Context, id string) (Session, error) {
106 dbSession, err := s.q.GetSessionByID(ctx, id)
107 if err != nil {
108 return Session{}, err
109 }
110 return s.fromDBItem(dbSession), nil
111}
112
113func (s *service) Save(ctx context.Context, session Session) (Session, error) {
114 dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
115 ID: session.ID,
116 Title: session.Title,
117 PromptTokens: session.PromptTokens,
118 CompletionTokens: session.CompletionTokens,
119 SummaryMessageID: sql.NullString{
120 String: session.SummaryMessageID,
121 Valid: session.SummaryMessageID != "",
122 },
123 Cost: session.Cost,
124 })
125 if err != nil {
126 return Session{}, err
127 }
128 session = s.fromDBItem(dbSession)
129 s.Publish(pubsub.UpdatedEvent, session)
130 return session, nil
131}
132
133func (s *service) List(ctx context.Context) ([]Session, error) {
134 dbSessions, err := s.q.ListSessions(ctx)
135 if err != nil {
136 return nil, err
137 }
138 sessions := make([]Session, len(dbSessions))
139 for i, dbSession := range dbSessions {
140 sessions[i] = s.fromDBItem(dbSession)
141 }
142 return sessions, nil
143}
144
145func (s service) fromDBItem(item db.Session) Session {
146 return Session{
147 ID: item.ID,
148 ParentSessionID: item.ParentSessionID.String,
149 Title: item.Title,
150 MessageCount: item.MessageCount,
151 PromptTokens: item.PromptTokens,
152 CompletionTokens: item.CompletionTokens,
153 SummaryMessageID: item.SummaryMessageID.String,
154 Cost: item.Cost,
155 CreatedAt: item.CreatedAt,
156 UpdatedAt: item.UpdatedAt,
157 }
158}
159
160func NewService(q db.Querier) Service {
161 broker := pubsub.NewBroker[Session]()
162 return &service{
163 broker,
164 q,
165 }
166}
167
168// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
169func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
170 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
171}
172
173// ParseAgentToolSessionID parses an agent tool session ID into its components
174func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
175 parts := strings.Split(sessionID, "$$")
176 if len(parts) != 2 {
177 return "", "", false
178 }
179 return parts[0], parts[1], true
180}
181
182// IsAgentToolSession checks if a session ID follows the agent tool session format
183func (s *service) IsAgentToolSession(sessionID string) bool {
184 _, _, ok := s.ParseAgentToolSessionID(sessionID)
185 return ok
186}