session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/charmbracelet/crush/internal/db"
  8	"github.com/charmbracelet/crush/internal/pubsub"
  9	"github.com/google/uuid"
 10)
 11
// Session is the service-level view of a stored session row, with nullable
// database columns flattened to plain strings.
type Session struct {
	// ID uniquely identifies the session.
	ID string
	// ParentSessionID is the ID of the session this one was spawned from.
	// It is the empty string when the database column is NULL.
	ParentSessionID string
	// Title is the human-readable session title.
	Title string
	// MessageCount is copied from the database row; Save does not write it.
	MessageCount int64
	// PromptTokens is the prompt token counter persisted by Save.
	PromptTokens int64
	// CompletionTokens is the completion token counter persisted by Save.
	CompletionTokens int64
	// SummaryMessageID is the ID of the summary message, or the empty string
	// when the database column is NULL. Save stores an empty value as NULL.
	SummaryMessageID string
	// Cost is the accumulated cost persisted by Save.
	// NOTE(review): the unit/currency is not visible in this file.
	Cost float64
	// CreatedAt is the creation timestamp as stored in the database row.
	// NOTE(review): presumably Unix time — confirm the unit against the schema.
	CreatedAt int64
	// UpdatedAt is the last-update timestamp as stored in the database row.
	// NOTE(review): presumably Unix time — confirm the unit against the schema.
	UpdatedAt int64
}
 24
// Service is the session store API: it creates, reads, lists, updates,
// deletes and searches sessions, and embeds the pubsub subscriber interface
// for Session events.
type Service interface {
	pubsub.Suscriber[Session]
	// Create stores a new session with a generated ID and the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession stores a child session of parentSessionID whose ID is
	// "title-" followed by the parent ID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession stores a child session of parentSessionID that uses
	// toolCallID as its ID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the sessions produced by the store's default listing query.
	// NOTE(review): presumably top-level sessions only — confirm against the query.
	List(ctx context.Context) ([]Session, error)
	// ListAll returns the sessions produced by the store's "all sessions" query.
	ListAll(ctx context.Context) ([]Session, error)
	// ListChildren returns the sessions whose parent is parentSessionID.
	ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
	// Save persists the mutable fields of session and returns the stored row.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error
	// SearchByTitle returns sessions whose title matches titlePattern.
	SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error)
	// SearchByText returns sessions whose content matches textPattern.
	SearchByText(ctx context.Context, textPattern string) ([]Session, error)
	// SearchByTitleAndText returns sessions matching both patterns.
	SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error)
}
 40
// service is the Service implementation backed by a db.Querier. The embedded
// broker supplies the publish/subscribe methods for Session events.
type service struct {
	*pubsub.Broker[Session]
	// q runs the session queries against the database.
	q db.Querier
}
 45
 46func (s *service) Create(ctx context.Context, title string) (Session, error) {
 47	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 48		ID:    uuid.New().String(),
 49		Title: title,
 50	})
 51	if err != nil {
 52		return Session{}, err
 53	}
 54	session := s.fromDBItem(dbSession)
 55	s.Publish(pubsub.CreatedEvent, session)
 56	return session, nil
 57}
 58
 59func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 60	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 61		ID:              toolCallID,
 62		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 63		Title:           title,
 64	})
 65	if err != nil {
 66		return Session{}, err
 67	}
 68	session := s.fromDBItem(dbSession)
 69	s.Publish(pubsub.CreatedEvent, session)
 70	return session, nil
 71}
 72
 73func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 74	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 75		ID:              "title-" + parentSessionID,
 76		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 77		Title:           "Generate a title",
 78	})
 79	if err != nil {
 80		return Session{}, err
 81	}
 82	session := s.fromDBItem(dbSession)
 83	s.Publish(pubsub.CreatedEvent, session)
 84	return session, nil
 85}
 86
 87func (s *service) Delete(ctx context.Context, id string) error {
 88	session, err := s.Get(ctx, id)
 89	if err != nil {
 90		return err
 91	}
 92	err = s.q.DeleteSession(ctx, session.ID)
 93	if err != nil {
 94		return err
 95	}
 96	s.Publish(pubsub.DeletedEvent, session)
 97	return nil
 98}
 99
100func (s *service) Get(ctx context.Context, id string) (Session, error) {
101	dbSession, err := s.q.GetSessionByID(ctx, id)
102	if err != nil {
103		return Session{}, err
104	}
105	return s.fromDBItem(dbSession), nil
106}
107
108func (s *service) Save(ctx context.Context, session Session) (Session, error) {
109	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
110		ID:               session.ID,
111		Title:            session.Title,
112		PromptTokens:     session.PromptTokens,
113		CompletionTokens: session.CompletionTokens,
114		SummaryMessageID: sql.NullString{
115			String: session.SummaryMessageID,
116			Valid:  session.SummaryMessageID != "",
117		},
118		Cost: session.Cost,
119	})
120	if err != nil {
121		return Session{}, err
122	}
123	session = s.fromDBItem(dbSession)
124	s.Publish(pubsub.UpdatedEvent, session)
125	return session, nil
126}
127
128func (s *service) List(ctx context.Context) ([]Session, error) {
129	dbSessions, err := s.q.ListSessions(ctx)
130	if err != nil {
131		return nil, err
132	}
133	sessions := make([]Session, len(dbSessions))
134	for i, dbSession := range dbSessions {
135		sessions[i] = s.fromDBItem(dbSession)
136	}
137	return sessions, nil
138}
139
140func (s *service) ListAll(ctx context.Context) ([]Session, error) {
141	dbSessions, err := s.q.ListAllSessions(ctx)
142	if err != nil {
143		return nil, err
144	}
145	sessions := make([]Session, len(dbSessions))
146	for i, dbSession := range dbSessions {
147		sessions[i] = s.fromDBItem(dbSession)
148	}
149	return sessions, nil
150}
151
152func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) {
153	dbSessions, err := s.q.ListChildSessions(ctx, sql.NullString{
154		String: parentSessionID,
155		Valid:  true,
156	})
157	if err != nil {
158		return nil, err
159	}
160	sessions := make([]Session, len(dbSessions))
161	for i, dbSession := range dbSessions {
162		sessions[i] = s.fromDBItem(dbSession)
163	}
164	return sessions, nil
165}
166
167func (s *service) SearchByTitle(ctx context.Context, titlePattern string) ([]Session, error) {
168	dbSessions, err := s.q.SearchSessionsByTitle(ctx, "%"+titlePattern+"%")
169	if err != nil {
170		return nil, err
171	}
172	sessions := make([]Session, len(dbSessions))
173	for i, dbSession := range dbSessions {
174		sessions[i] = s.fromDBItem(dbSession)
175	}
176	return sessions, nil
177}
178
179func (s *service) SearchByText(ctx context.Context, textPattern string) ([]Session, error) {
180	dbSessions, err := s.q.SearchSessionsByText(ctx, "%"+textPattern+"%")
181	if err != nil {
182		return nil, err
183	}
184	sessions := make([]Session, len(dbSessions))
185	for i, dbSession := range dbSessions {
186		sessions[i] = s.fromDBItem(dbSession)
187	}
188	return sessions, nil
189}
190
191func (s *service) SearchByTitleAndText(ctx context.Context, titlePattern, textPattern string) ([]Session, error) {
192	dbSessions, err := s.q.SearchSessionsByTitleAndText(ctx, db.SearchSessionsByTitleAndTextParams{
193		Title: "%" + titlePattern + "%",
194		Parts: "%" + textPattern + "%",
195	})
196	if err != nil {
197		return nil, err
198	}
199	sessions := make([]Session, len(dbSessions))
200	for i, dbSession := range dbSessions {
201		sessions[i] = s.fromDBItem(dbSession)
202	}
203	return sessions, nil
204}
205
206func (s service) fromDBItem(item db.Session) Session {
207	return Session{
208		ID:               item.ID,
209		ParentSessionID:  item.ParentSessionID.String,
210		Title:            item.Title,
211		MessageCount:     item.MessageCount,
212		PromptTokens:     item.PromptTokens,
213		CompletionTokens: item.CompletionTokens,
214		SummaryMessageID: item.SummaryMessageID.String,
215		Cost:             item.Cost,
216		CreatedAt:        item.CreatedAt,
217		UpdatedAt:        item.UpdatedAt,
218	}
219}
220
221func NewService(q db.Querier) Service {
222	broker := pubsub.NewBroker[Session]()
223	return &service{
224		broker,
225		q,
226	}
227}