session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15)
 16
 17type TodoStatus string
 18
 19const (
 20	TodoStatusPending    TodoStatus = "pending"
 21	TodoStatusInProgress TodoStatus = "in_progress"
 22	TodoStatusCompleted  TodoStatus = "completed"
 23)
 24
 25type Todo struct {
 26	Content    string     `json:"content"`
 27	Status     TodoStatus `json:"status"`
 28	ActiveForm string     `json:"active_form"`
 29}
 30
 31type Session struct {
 32	ID               string
 33	ParentSessionID  string
 34	Title            string
 35	MessageCount     int64
 36	PromptTokens     int64
 37	CompletionTokens int64
 38	SummaryMessageID string
 39	Cost             float64
 40	Todos            []Todo
 41	CreatedAt        int64
 42	UpdatedAt        int64
 43}
 44
 45type Service interface {
 46	pubsub.Subscriber[Session]
 47	Create(ctx context.Context, title string) (Session, error)
 48	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 49	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 50	Get(ctx context.Context, id string) (Session, error)
 51	List(ctx context.Context) ([]Session, error)
 52	Save(ctx context.Context, session Session) (Session, error)
 53	Delete(ctx context.Context, id string) error
 54
 55	// Agent tool session management
 56	CreateAgentToolSessionID(messageID, toolCallID string) string
 57	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 58	IsAgentToolSession(sessionID string) bool
 59}
 60
 61type service struct {
 62	*pubsub.Broker[Session]
 63	q db.Querier
 64}
 65
 66func (s *service) Create(ctx context.Context, title string) (Session, error) {
 67	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 68		ID:    uuid.New().String(),
 69		Title: title,
 70	})
 71	if err != nil {
 72		return Session{}, err
 73	}
 74	session := s.fromDBItem(dbSession)
 75	s.Publish(pubsub.CreatedEvent, session)
 76	event.SessionCreated()
 77	return session, nil
 78}
 79
 80func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 81	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 82		ID:              toolCallID,
 83		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 84		Title:           title,
 85	})
 86	if err != nil {
 87		return Session{}, err
 88	}
 89	session := s.fromDBItem(dbSession)
 90	s.Publish(pubsub.CreatedEvent, session)
 91	return session, nil
 92}
 93
 94func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 95	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 96		ID:              "title-" + parentSessionID,
 97		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 98		Title:           "Generate a title",
 99	})
100	if err != nil {
101		return Session{}, err
102	}
103	session := s.fromDBItem(dbSession)
104	s.Publish(pubsub.CreatedEvent, session)
105	return session, nil
106}
107
108func (s *service) Delete(ctx context.Context, id string) error {
109	session, err := s.Get(ctx, id)
110	if err != nil {
111		return err
112	}
113	err = s.q.DeleteSession(ctx, session.ID)
114	if err != nil {
115		return err
116	}
117	s.Publish(pubsub.DeletedEvent, session)
118	event.SessionDeleted()
119	return nil
120}
121
122func (s *service) Get(ctx context.Context, id string) (Session, error) {
123	dbSession, err := s.q.GetSessionByID(ctx, id)
124	if err != nil {
125		return Session{}, err
126	}
127	return s.fromDBItem(dbSession), nil
128}
129
130func (s *service) Save(ctx context.Context, session Session) (Session, error) {
131	todosJSON, err := marshalTodos(session.Todos)
132	if err != nil {
133		return Session{}, err
134	}
135
136	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
137		ID:               session.ID,
138		Title:            session.Title,
139		PromptTokens:     session.PromptTokens,
140		CompletionTokens: session.CompletionTokens,
141		SummaryMessageID: sql.NullString{
142			String: session.SummaryMessageID,
143			Valid:  session.SummaryMessageID != "",
144		},
145		Cost: session.Cost,
146		Todos: sql.NullString{
147			String: todosJSON,
148			Valid:  todosJSON != "",
149		},
150	})
151	if err != nil {
152		return Session{}, err
153	}
154	session = s.fromDBItem(dbSession)
155	s.Publish(pubsub.UpdatedEvent, session)
156	return session, nil
157}
158
159func (s *service) List(ctx context.Context) ([]Session, error) {
160	dbSessions, err := s.q.ListSessions(ctx)
161	if err != nil {
162		return nil, err
163	}
164	sessions := make([]Session, len(dbSessions))
165	for i, dbSession := range dbSessions {
166		sessions[i] = s.fromDBItem(dbSession)
167	}
168	return sessions, nil
169}
170
171func (s service) fromDBItem(item db.Session) Session {
172	todos, err := unmarshalTodos(item.Todos.String)
173	if err != nil {
174		slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
175	}
176	return Session{
177		ID:               item.ID,
178		ParentSessionID:  item.ParentSessionID.String,
179		Title:            item.Title,
180		MessageCount:     item.MessageCount,
181		PromptTokens:     item.PromptTokens,
182		CompletionTokens: item.CompletionTokens,
183		SummaryMessageID: item.SummaryMessageID.String,
184		Cost:             item.Cost,
185		Todos:            todos,
186		CreatedAt:        item.CreatedAt,
187		UpdatedAt:        item.UpdatedAt,
188	}
189}
190
191func marshalTodos(todos []Todo) (string, error) {
192	if len(todos) == 0 {
193		return "", nil
194	}
195	data, err := json.Marshal(todos)
196	if err != nil {
197		return "", err
198	}
199	return string(data), nil
200}
201
202func unmarshalTodos(data string) ([]Todo, error) {
203	if data == "" {
204		return []Todo{}, nil
205	}
206	var todos []Todo
207	if err := json.Unmarshal([]byte(data), &todos); err != nil {
208		return []Todo{}, err
209	}
210	return todos, nil
211}
212
213func NewService(q db.Querier) Service {
214	broker := pubsub.NewBroker[Session]()
215	return &service{
216		broker,
217		q,
218	}
219}
220
221// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
222func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
223	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
224}
225
226// ParseAgentToolSessionID parses an agent tool session ID into its components
227func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
228	parts := strings.Split(sessionID, "$$")
229	if len(parts) != 2 {
230		return "", "", false
231	}
232	return parts[0], parts[1], true
233}
234
235// IsAgentToolSession checks if a session ID follows the agent tool session format
236func (s *service) IsAgentToolSession(sessionID string) bool {
237	_, _, ok := s.ParseAgentToolSessionID(sessionID)
238	return ok
239}