session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/charmbracelet/crush/internal/db"
  8	"github.com/charmbracelet/crush/internal/proto"
  9	"github.com/charmbracelet/crush/internal/pubsub"
 10	"github.com/google/uuid"
 11)
 12
 13type Session = proto.Session
 14
 15type Service interface {
 16	pubsub.Suscriber[Session]
 17	Create(ctx context.Context, title string) (Session, error)
 18	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 19	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 20	Get(ctx context.Context, id string) (Session, error)
 21	List(ctx context.Context) ([]Session, error)
 22	Save(ctx context.Context, session Session) (Session, error)
 23	Delete(ctx context.Context, id string) error
 24}
 25
 26type service struct {
 27	*pubsub.Broker[Session]
 28	q db.Querier
 29}
 30
 31func (s *service) Create(ctx context.Context, title string) (Session, error) {
 32	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 33		ID:    uuid.New().String(),
 34		Title: title,
 35	})
 36	if err != nil {
 37		return Session{}, err
 38	}
 39	session := s.fromDBItem(dbSession)
 40	s.Publish(pubsub.CreatedEvent, session)
 41	return session, nil
 42}
 43
 44func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 45	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 46		ID:              toolCallID,
 47		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 48		Title:           title,
 49	})
 50	if err != nil {
 51		return Session{}, err
 52	}
 53	session := s.fromDBItem(dbSession)
 54	s.Publish(pubsub.CreatedEvent, session)
 55	return session, nil
 56}
 57
 58func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 59	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 60		ID:              "title-" + parentSessionID,
 61		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 62		Title:           "Generate a title",
 63	})
 64	if err != nil {
 65		return Session{}, err
 66	}
 67	session := s.fromDBItem(dbSession)
 68	s.Publish(pubsub.CreatedEvent, session)
 69	return session, nil
 70}
 71
 72func (s *service) Delete(ctx context.Context, id string) error {
 73	session, err := s.Get(ctx, id)
 74	if err != nil {
 75		return err
 76	}
 77	err = s.q.DeleteSession(ctx, session.ID)
 78	if err != nil {
 79		return err
 80	}
 81	s.Publish(pubsub.DeletedEvent, session)
 82	return nil
 83}
 84
 85func (s *service) Get(ctx context.Context, id string) (Session, error) {
 86	dbSession, err := s.q.GetSessionByID(ctx, id)
 87	if err != nil {
 88		return Session{}, err
 89	}
 90	return s.fromDBItem(dbSession), nil
 91}
 92
 93func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 94	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
 95		ID:               session.ID,
 96		Title:            session.Title,
 97		PromptTokens:     session.PromptTokens,
 98		CompletionTokens: session.CompletionTokens,
 99		SummaryMessageID: sql.NullString{
100			String: session.SummaryMessageID,
101			Valid:  session.SummaryMessageID != "",
102		},
103		Cost: session.Cost,
104	})
105	if err != nil {
106		return Session{}, err
107	}
108	session = s.fromDBItem(dbSession)
109	s.Publish(pubsub.UpdatedEvent, session)
110	return session, nil
111}
112
113func (s *service) List(ctx context.Context) ([]Session, error) {
114	dbSessions, err := s.q.ListSessions(ctx)
115	if err != nil {
116		return nil, err
117	}
118	sessions := make([]Session, len(dbSessions))
119	for i, dbSession := range dbSessions {
120		sessions[i] = s.fromDBItem(dbSession)
121	}
122	return sessions, nil
123}
124
125func (s service) fromDBItem(item db.Session) Session {
126	return Session{
127		ID:               item.ID,
128		ParentSessionID:  item.ParentSessionID.String,
129		Title:            item.Title,
130		MessageCount:     item.MessageCount,
131		PromptTokens:     item.PromptTokens,
132		CompletionTokens: item.CompletionTokens,
133		SummaryMessageID: item.SummaryMessageID.String,
134		Cost:             item.Cost,
135		CreatedAt:        item.CreatedAt,
136		UpdatedAt:        item.UpdatedAt,
137	}
138}
139
140func NewService(q db.Querier) Service {
141	broker := pubsub.NewBroker[Session]()
142	return &service{
143		broker,
144		q,
145	}
146}