session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6
  7	"github.com/charmbracelet/crush/internal/db"
  8	"github.com/charmbracelet/crush/internal/pubsub"
  9	"github.com/google/uuid"
 10)
 11
// Session is the service-level representation of a stored session, built from
// a db.Session row by fromDBItem.
//
//   - ParentSessionID is "" for sessions made by Create (which leaves it unset);
//     CreateTaskSession and CreateTitleSession set it to the owning session's ID.
//   - SummaryMessageID is "" when no summary message is recorded; Save writes
//     "" to the database as SQL NULL.
//   - MessageCount, CreatedAt and UpdatedAt are read back from the database and
//     are not sent by Save. NOTE(review): CreatedAt/UpdatedAt look like Unix
//     timestamps — confirm the unit against the schema.
//   - PromptTokens, CompletionTokens and Cost are persisted by Save; the unit of
//     Cost is not visible from this package.
type Session struct {
	ID               string  `json:"id"`
	ParentSessionID  string  `json:"parent_session_id"`
	Title            string  `json:"title"`
	MessageCount     int64   `json:"message_count"`
	PromptTokens     int64   `json:"prompt_tokens"`
	CompletionTokens int64   `json:"completion_tokens"`
	SummaryMessageID string  `json:"summary_message_id"`
	Cost             float64 `json:"cost"`
	CreatedAt        int64   `json:"created_at"`
	UpdatedAt        int64   `json:"updated_at"`
}
 24
// Service manages persisted sessions. It also embeds pubsub.Suscriber[Session]
// (presumably the subscription side of the pubsub broker — see package pubsub).
// The implementation returned by NewService publishes a CreatedEvent from each
// Create* method, an UpdatedEvent from Save and a DeletedEvent from Delete.
type Service interface {
	pubsub.Suscriber[Session]
	// Create stores a new top-level session (no parent) with a generated ID.
	Create(ctx context.Context, title string) (Session, error)
	// CreateTitleSession stores a child of parentSessionID whose ID is
	// "title-" + parentSessionID.
	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
	// CreateTaskSession stores a child of parentSessionID whose ID is toolCallID.
	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
	// Get returns the session with the given ID.
	Get(ctx context.Context, id string) (Session, error)
	// List returns all sessions, in the order the underlying query yields them.
	List(ctx context.Context) ([]Session, error)
	// Save persists the session's title, token counts, summary message ID and
	// cost, and returns the stored result.
	Save(ctx context.Context, session Session) (Session, error)
	// Delete removes the session with the given ID.
	Delete(ctx context.Context, id string) error
}
 35
// service is the database-backed implementation of Service. The embedded
// broker supplies the Publish method used to emit Session events, and q is the
// querier used for every read and write.
type service struct {
	*pubsub.Broker[Session]
	q db.Querier
}
 40
 41func (s *service) Create(ctx context.Context, title string) (Session, error) {
 42	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 43		ID:    uuid.New().String(),
 44		Title: title,
 45	})
 46	if err != nil {
 47		return Session{}, err
 48	}
 49	session := s.fromDBItem(dbSession)
 50	s.Publish(pubsub.CreatedEvent, session)
 51	return session, nil
 52}
 53
 54func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 55	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 56		ID:              toolCallID,
 57		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 58		Title:           title,
 59	})
 60	if err != nil {
 61		return Session{}, err
 62	}
 63	session := s.fromDBItem(dbSession)
 64	s.Publish(pubsub.CreatedEvent, session)
 65	return session, nil
 66}
 67
 68func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 69	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 70		ID:              "title-" + parentSessionID,
 71		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 72		Title:           "Generate a title",
 73	})
 74	if err != nil {
 75		return Session{}, err
 76	}
 77	session := s.fromDBItem(dbSession)
 78	s.Publish(pubsub.CreatedEvent, session)
 79	return session, nil
 80}
 81
 82func (s *service) Delete(ctx context.Context, id string) error {
 83	session, err := s.Get(ctx, id)
 84	if err != nil {
 85		return err
 86	}
 87	err = s.q.DeleteSession(ctx, session.ID)
 88	if err != nil {
 89		return err
 90	}
 91	s.Publish(pubsub.DeletedEvent, session)
 92	return nil
 93}
 94
 95func (s *service) Get(ctx context.Context, id string) (Session, error) {
 96	dbSession, err := s.q.GetSessionByID(ctx, id)
 97	if err != nil {
 98		return Session{}, err
 99	}
100	return s.fromDBItem(dbSession), nil
101}
102
103func (s *service) Save(ctx context.Context, session Session) (Session, error) {
104	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
105		ID:               session.ID,
106		Title:            session.Title,
107		PromptTokens:     session.PromptTokens,
108		CompletionTokens: session.CompletionTokens,
109		SummaryMessageID: sql.NullString{
110			String: session.SummaryMessageID,
111			Valid:  session.SummaryMessageID != "",
112		},
113		Cost: session.Cost,
114	})
115	if err != nil {
116		return Session{}, err
117	}
118	session = s.fromDBItem(dbSession)
119	s.Publish(pubsub.UpdatedEvent, session)
120	return session, nil
121}
122
123func (s *service) List(ctx context.Context) ([]Session, error) {
124	dbSessions, err := s.q.ListSessions(ctx)
125	if err != nil {
126		return nil, err
127	}
128	sessions := make([]Session, len(dbSessions))
129	for i, dbSession := range dbSessions {
130		sessions[i] = s.fromDBItem(dbSession)
131	}
132	return sessions, nil
133}
134
135func (s service) fromDBItem(item db.Session) Session {
136	return Session{
137		ID:               item.ID,
138		ParentSessionID:  item.ParentSessionID.String,
139		Title:            item.Title,
140		MessageCount:     item.MessageCount,
141		PromptTokens:     item.PromptTokens,
142		CompletionTokens: item.CompletionTokens,
143		SummaryMessageID: item.SummaryMessageID.String,
144		Cost:             item.Cost,
145		CreatedAt:        item.CreatedAt,
146		UpdatedAt:        item.UpdatedAt,
147	}
148}
149
150func NewService(q db.Querier) Service {
151	broker := pubsub.NewBroker[Session]()
152	return &service{
153		broker,
154		q,
155	}
156}