1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/charmbracelet/crush/internal/db"
  8	"github.com/charmbracelet/crush/internal/event"
  9	"github.com/charmbracelet/crush/internal/proto"
 10	"github.com/charmbracelet/crush/internal/pubsub"
 11	"github.com/google/uuid"
 12)
 13
// Session is an alias for proto.Session, so code using this package can
// refer to it as session.Session; the two names denote the identical type.
type Session = proto.Session
 15
 16type Service interface {
 17	pubsub.Suscriber[Session]
 18	Create(ctx context.Context, title string) (Session, error)
 19	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 20	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 21	Get(ctx context.Context, id string) (Session, error)
 22	List(ctx context.Context) ([]Session, error)
 23	Save(ctx context.Context, session Session) (Session, error)
 24	Delete(ctx context.Context, id string) error
 25}
 26
// service is the db-backed implementation of Service returned by NewService.
type service struct {
	// The embedded broker provides Publish (used by every mutating method)
	// and the pubsub.Suscriber methods that Service requires.
	*pubsub.Broker[Session]
	// q performs all session reads and writes against the database.
	q db.Querier
}
 31
 32func (s *service) Create(ctx context.Context, title string) (Session, error) {
 33	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 34		ID:    uuid.New().String(),
 35		Title: title,
 36	})
 37	if err != nil {
 38		return Session{}, err
 39	}
 40	session := s.fromDBItem(dbSession)
 41	s.Publish(pubsub.CreatedEvent, session)
 42	event.SessionCreated()
 43	return session, nil
 44}
 45
 46func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 47	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 48		ID:              toolCallID,
 49		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 50		Title:           title,
 51	})
 52	if err != nil {
 53		return Session{}, err
 54	}
 55	session := s.fromDBItem(dbSession)
 56	s.Publish(pubsub.CreatedEvent, session)
 57	return session, nil
 58}
 59
 60func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 61	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 62		ID:              "title-" + parentSessionID,
 63		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 64		Title:           "Generate a title",
 65	})
 66	if err != nil {
 67		return Session{}, err
 68	}
 69	session := s.fromDBItem(dbSession)
 70	s.Publish(pubsub.CreatedEvent, session)
 71	return session, nil
 72}
 73
 74func (s *service) Delete(ctx context.Context, id string) error {
 75	session, err := s.Get(ctx, id)
 76	if err != nil {
 77		return err
 78	}
 79	err = s.q.DeleteSession(ctx, session.ID)
 80	if err != nil {
 81		return err
 82	}
 83	s.Publish(pubsub.DeletedEvent, session)
 84	event.SessionDeleted()
 85	return nil
 86}
 87
 88func (s *service) Get(ctx context.Context, id string) (Session, error) {
 89	dbSession, err := s.q.GetSessionByID(ctx, id)
 90	if err != nil {
 91		return Session{}, err
 92	}
 93	return s.fromDBItem(dbSession), nil
 94}
 95
 96func (s *service) Save(ctx context.Context, session Session) (Session, error) {
 97	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
 98		ID:               session.ID,
 99		Title:            session.Title,
100		PromptTokens:     session.PromptTokens,
101		CompletionTokens: session.CompletionTokens,
102		SummaryMessageID: sql.NullString{
103			String: session.SummaryMessageID,
104			Valid:  session.SummaryMessageID != "",
105		},
106		Cost: session.Cost,
107	})
108	if err != nil {
109		return Session{}, err
110	}
111	session = s.fromDBItem(dbSession)
112	s.Publish(pubsub.UpdatedEvent, session)
113	return session, nil
114}
115
116func (s *service) List(ctx context.Context) ([]Session, error) {
117	dbSessions, err := s.q.ListSessions(ctx)
118	if err != nil {
119		return nil, err
120	}
121	sessions := make([]Session, len(dbSessions))
122	for i, dbSession := range dbSessions {
123		sessions[i] = s.fromDBItem(dbSession)
124	}
125	return sessions, nil
126}
127
128func (s service) fromDBItem(item db.Session) Session {
129	return Session{
130		ID:               item.ID,
131		ParentSessionID:  item.ParentSessionID.String,
132		Title:            item.Title,
133		MessageCount:     item.MessageCount,
134		PromptTokens:     item.PromptTokens,
135		CompletionTokens: item.CompletionTokens,
136		SummaryMessageID: item.SummaryMessageID.String,
137		Cost:             item.Cost,
138		CreatedAt:        item.CreatedAt,
139		UpdatedAt:        item.UpdatedAt,
140	}
141}
142
143func NewService(q db.Querier) Service {
144	broker := pubsub.NewBroker[Session]()
145	return &service{
146		broker,
147		q,
148	}
149}