1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"fmt"
  7	"strings"
  8
  9	"github.com/charmbracelet/crush/internal/db"
 10	"github.com/charmbracelet/crush/internal/event"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/google/uuid"
 13)
 14
// Session is the service-level representation of a stored session. Values
// are built from db.Session rows by fromDBItem, which flattens the nullable
// ParentSessionID and SummaryMessageID columns into plain strings.
type Session struct {
	// ID identifies the session. Create assigns a random UUID, while
	// CreateTaskSession and CreateTitleSession use caller-derived IDs.
	ID               string
	// ParentSessionID links a task or title session to the session it was
	// created from. Create leaves it unset, so such sessions read back as "".
	ParentSessionID  string
	Title            string
	// MessageCount is only read back from the database; Save never writes it.
	// NOTE(review): presumably maintained by the database layer — confirm.
	MessageCount     int64
	// PromptTokens and CompletionTokens are token usage figures persisted by
	// Save.
	PromptTokens     int64
	CompletionTokens int64
	// SummaryMessageID is "" when no summary message is recorded; Save stores
	// "" as NULL.
	SummaryMessageID string
	// Cost is persisted by Save. NOTE(review): currency/units are not visible
	// here — confirm.
	Cost             float64
	// CreatedAt and UpdatedAt are never set by this package; they come back
	// from the query layer. NOTE(review): units (e.g. Unix seconds vs.
	// milliseconds) are not visible here — confirm.
	CreatedAt        int64
	UpdatedAt        int64
}
 27
 28type Service interface {
 29	pubsub.Suscriber[Session]
 30	Create(ctx context.Context, title string) (Session, error)
 31	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 32	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 33	Get(ctx context.Context, id string) (Session, error)
 34	List(ctx context.Context) ([]Session, error)
 35	Save(ctx context.Context, session Session) (Session, error)
 36	Delete(ctx context.Context, id string) error
 37
 38	// Agent tool session management
 39	CreateAgentToolSessionID(messageID, toolCallID string) string
 40	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 41	IsAgentToolSession(sessionID string) bool
 42}
 43
// service is the Service implementation backed by a db.Querier.
type service struct {
	// The embedded broker promotes Publish onto service; Publish is called
	// after every successful create, save and delete. Presumably it also
	// supplies the pubsub.Suscriber[Session] methods Service requires — confirm.
	*pubsub.Broker[Session]
	// q performs all persistence; service holds no other state.
	q db.Querier
}
 48
 49func (s *service) Create(ctx context.Context, title string) (Session, error) {
 50	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 51		ID:    uuid.New().String(),
 52		Title: title,
 53	})
 54	if err != nil {
 55		return Session{}, err
 56	}
 57	session := s.fromDBItem(dbSession)
 58	s.Publish(pubsub.CreatedEvent, session)
 59	event.SessionCreated()
 60	return session, nil
 61}
 62
 63func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 64	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 65		ID:              toolCallID,
 66		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 67		Title:           title,
 68	})
 69	if err != nil {
 70		return Session{}, err
 71	}
 72	session := s.fromDBItem(dbSession)
 73	s.Publish(pubsub.CreatedEvent, session)
 74	return session, nil
 75}
 76
 77func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 78	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 79		ID:              "title-" + parentSessionID,
 80		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 81		Title:           "Generate a title",
 82	})
 83	if err != nil {
 84		return Session{}, err
 85	}
 86	session := s.fromDBItem(dbSession)
 87	s.Publish(pubsub.CreatedEvent, session)
 88	return session, nil
 89}
 90
 91func (s *service) Delete(ctx context.Context, id string) error {
 92	session, err := s.Get(ctx, id)
 93	if err != nil {
 94		return err
 95	}
 96	err = s.q.DeleteSession(ctx, session.ID)
 97	if err != nil {
 98		return err
 99	}
100	s.Publish(pubsub.DeletedEvent, session)
101	event.SessionDeleted()
102	return nil
103}
104
105func (s *service) Get(ctx context.Context, id string) (Session, error) {
106	dbSession, err := s.q.GetSessionByID(ctx, id)
107	if err != nil {
108		return Session{}, err
109	}
110	return s.fromDBItem(dbSession), nil
111}
112
113func (s *service) Save(ctx context.Context, session Session) (Session, error) {
114	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
115		ID:               session.ID,
116		Title:            session.Title,
117		PromptTokens:     session.PromptTokens,
118		CompletionTokens: session.CompletionTokens,
119		SummaryMessageID: sql.NullString{
120			String: session.SummaryMessageID,
121			Valid:  session.SummaryMessageID != "",
122		},
123		Cost: session.Cost,
124	})
125	if err != nil {
126		return Session{}, err
127	}
128	session = s.fromDBItem(dbSession)
129	s.Publish(pubsub.UpdatedEvent, session)
130	return session, nil
131}
132
133func (s *service) List(ctx context.Context) ([]Session, error) {
134	dbSessions, err := s.q.ListSessions(ctx)
135	if err != nil {
136		return nil, err
137	}
138	sessions := make([]Session, len(dbSessions))
139	for i, dbSession := range dbSessions {
140		sessions[i] = s.fromDBItem(dbSession)
141	}
142	return sessions, nil
143}
144
145func (s service) fromDBItem(item db.Session) Session {
146	return Session{
147		ID:               item.ID,
148		ParentSessionID:  item.ParentSessionID.String,
149		Title:            item.Title,
150		MessageCount:     item.MessageCount,
151		PromptTokens:     item.PromptTokens,
152		CompletionTokens: item.CompletionTokens,
153		SummaryMessageID: item.SummaryMessageID.String,
154		Cost:             item.Cost,
155		CreatedAt:        item.CreatedAt,
156		UpdatedAt:        item.UpdatedAt,
157	}
158}
159
160func NewService(q db.Querier) Service {
161	broker := pubsub.NewBroker[Session]()
162	return &service{
163		broker,
164		q,
165	}
166}
167
168// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
169func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
170	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
171}
172
173// ParseAgentToolSessionID parses an agent tool session ID into its components
174func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
175	parts := strings.Split(sessionID, "$$")
176	if len(parts) != 2 {
177		return "", "", false
178	}
179	return parts[0], parts[1], true
180}
181
182// IsAgentToolSession checks if a session ID follows the agent tool session format
183func (s *service) IsAgentToolSession(sessionID string) bool {
184	_, _, ok := s.ParseAgentToolSessionID(sessionID)
185	return ok
186}