session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/charmbracelet/crush/internal/db"
  8	"github.com/charmbracelet/crush/internal/pubsub"
  9	"github.com/google/uuid"
 10)
 11
// Session is the service-level view of a stored chat session, as produced by
// fromDBItem from a db.Session row.
type Session struct {
	// ID is the primary key: a random UUID for Create, the tool call ID for
	// CreateTaskSession, or "title-<parent>" for CreateTitleSession.
	ID string
	// ParentSessionID is empty for top-level sessions (the nullable DB column
	// is flattened to its String value).
	ParentSessionID string
	Title           string
	// MessageCount is read from the DB but never written by Save; presumably
	// maintained by the db layer — TODO confirm.
	MessageCount     int64
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is empty when there is no summary; Save stores "" as NULL.
	SummaryMessageID string
	// Cost units are not visible here — NOTE(review): confirm currency/scale.
	Cost float64
	// CreatedAt and UpdatedAt are presumably Unix timestamps set by the db
	// layer — confirm the unit (seconds vs milliseconds).
	CreatedAt int64
	UpdatedAt int64
}
 24
// Service persists chat sessions and publishes change events for them.
type Service interface {
	// Embeds the subscriber side of the pubsub package (the identifier is
	// spelled "Suscriber" there). The default implementation publishes
	// CreatedEvent, UpdatedEvent and DeletedEvent with the affected Session.
	pubsub.Suscriber[Session]
	// Create stores a new top-level session with a generated ID and the given title.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession stores a child of parentSessionID whose ID is
	// "title-" + parentSessionID. NOTE(review): because the ID is deterministic,
	// a second call for the same parent presumably fails on the primary key.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession stores a child of parentSessionID whose ID is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get loads a single session by ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns the sessions selected by the db layer's ListSessions query
	// (which subset that is, versus ListAll, is decided by the query).
	List(ctx context.Context) ([]Session, error)
	// ListAll returns the sessions selected by the ListAllSessions query.
	ListAll(ctx context.Context) ([]Session, error)
	// ListChildren returns the sessions whose parent is parentSessionID.
	ListChildren(ctx context.Context, parentSessionID string) ([]Session, error)
	// Save updates Title, token counts, SummaryMessageID and Cost, and returns
	// the stored row.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID; it fails if the session
	// cannot be loaded first.
	Delete(ctx context.Context, id string) error
}
 37
// service is the default Service implementation.
type service struct {
	// Broker publishes session events; its methods (Publish, and the
	// subscriber side required by Service) are promoted through embedding.
	*pubsub.Broker[Session]
	// q executes the generated database queries.
	q db.Querier
}
 42
 43func (s *service) Create(ctx context.Context, title string) (Session, error) {
 44	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 45		ID:    uuid.New().String(),
 46		Title: title,
 47	})
 48	if err != nil {
 49		return Session{}, err
 50	}
 51	session := s.fromDBItem(dbSession)
 52	s.Publish(pubsub.CreatedEvent, session)
 53	return session, nil
 54}
 55
 56func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 57	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 58		ID:              toolCallID,
 59		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 60		Title:           title,
 61	})
 62	if err != nil {
 63		return Session{}, err
 64	}
 65	session := s.fromDBItem(dbSession)
 66	s.Publish(pubsub.CreatedEvent, session)
 67	return session, nil
 68}
 69
 70func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 71	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 72		ID:              "title-" + parentSessionID,
 73		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 74		Title:           "Generate a title",
 75	})
 76	if err != nil {
 77		return Session{}, err
 78	}
 79	session := s.fromDBItem(dbSession)
 80	s.Publish(pubsub.CreatedEvent, session)
 81	return session, nil
 82}
 83
 84func (s *service) Delete(ctx context.Context, id string) error {
 85	session, err := s.Get(ctx, id)
 86	if err != nil {
 87		return err
 88	}
 89	err = s.q.DeleteSession(ctx, session.ID)
 90	if err != nil {
 91		return err
 92	}
 93	s.Publish(pubsub.DeletedEvent, session)
 94	return nil
 95}
 96
 97func (s *service) Get(ctx context.Context, id string) (Session, error) {
 98	dbSession, err := s.q.GetSessionByID(ctx, id)
 99	if err != nil {
100		return Session{}, err
101	}
102	return s.fromDBItem(dbSession), nil
103}
104
105func (s *service) Save(ctx context.Context, session Session) (Session, error) {
106	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
107		ID:               session.ID,
108		Title:            session.Title,
109		PromptTokens:     session.PromptTokens,
110		CompletionTokens: session.CompletionTokens,
111		SummaryMessageID: sql.NullString{
112			String: session.SummaryMessageID,
113			Valid:  session.SummaryMessageID != "",
114		},
115		Cost: session.Cost,
116	})
117	if err != nil {
118		return Session{}, err
119	}
120	session = s.fromDBItem(dbSession)
121	s.Publish(pubsub.UpdatedEvent, session)
122	return session, nil
123}
124
125func (s *service) List(ctx context.Context) ([]Session, error) {
126	dbSessions, err := s.q.ListSessions(ctx)
127	if err != nil {
128		return nil, err
129	}
130	sessions := make([]Session, len(dbSessions))
131	for i, dbSession := range dbSessions {
132		sessions[i] = s.fromDBItem(dbSession)
133	}
134	return sessions, nil
135}
136
137func (s *service) ListAll(ctx context.Context) ([]Session, error) {
138	dbSessions, err := s.q.ListAllSessions(ctx)
139	if err != nil {
140		return nil, err
141	}
142	sessions := make([]Session, len(dbSessions))
143	for i, dbSession := range dbSessions {
144		sessions[i] = s.fromDBItem(dbSession)
145	}
146	return sessions, nil
147}
148
149func (s *service) ListChildren(ctx context.Context, parentSessionID string) ([]Session, error) {
150	dbSessions, err := s.q.ListChildSessions(ctx, sql.NullString{
151		String: parentSessionID,
152		Valid:  true,
153	})
154	if err != nil {
155		return nil, err
156	}
157	sessions := make([]Session, len(dbSessions))
158	for i, dbSession := range dbSessions {
159		sessions[i] = s.fromDBItem(dbSession)
160	}
161	return sessions, nil
162}
163
164func (s service) fromDBItem(item db.Session) Session {
165	return Session{
166		ID:               item.ID,
167		ParentSessionID:  item.ParentSessionID.String,
168		Title:            item.Title,
169		MessageCount:     item.MessageCount,
170		PromptTokens:     item.PromptTokens,
171		CompletionTokens: item.CompletionTokens,
172		SummaryMessageID: item.SummaryMessageID.String,
173		Cost:             item.Cost,
174		CreatedAt:        item.CreatedAt,
175		UpdatedAt:        item.UpdatedAt,
176	}
177}
178
179func NewService(q db.Querier) Service {
180	broker := pubsub.NewBroker[Session]()
181	return &service{
182		broker,
183		q,
184	}
185}