session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/message"
 14	"github.com/charmbracelet/crush/internal/pubsub"
 15	"github.com/google/uuid"
 16)
 17
 18type TodoStatus string
 19
 20const (
 21	TodoStatusPending    TodoStatus = "pending"
 22	TodoStatusInProgress TodoStatus = "in_progress"
 23	TodoStatusCompleted  TodoStatus = "completed"
 24)
 25
 26type Todo struct {
 27	Content    string     `json:"content"`
 28	Status     TodoStatus `json:"status"`
 29	ActiveForm string     `json:"active_form"`
 30}
 31
 32type Session struct {
 33	ID               string
 34	ParentSessionID  string
 35	Title            string
 36	MessageCount     int64
 37	PromptTokens     int64
 38	CompletionTokens int64
 39	SummaryMessageID string
 40	Cost             float64
 41	Todos            []Todo
 42	CreatedAt        int64
 43	UpdatedAt        int64
 44}
 45
 46type ForkMessageService interface {
 47	List(ctx context.Context, sessionID string) ([]message.Message, error)
 48	Copy(ctx context.Context, sessionID string, message message.Message) (message.Message, error)
 49}
 50
 51type Service interface {
 52	pubsub.Subscriber[Session]
 53	Create(ctx context.Context, title string) (Session, error)
 54	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 55	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 56	Get(ctx context.Context, id string) (Session, error)
 57	List(ctx context.Context) ([]Session, error)
 58	Save(ctx context.Context, session Session) (Session, error)
 59	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 60	Delete(ctx context.Context, id string) error
 61	Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error)
 62
 63	// Agent tool session management
 64	CreateAgentToolSessionID(messageID, toolCallID string) string
 65	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 66	IsAgentToolSession(sessionID string) bool
 67}
 68
 69type service struct {
 70	*pubsub.Broker[Session]
 71	db *sql.DB
 72	q  *db.Queries
 73}
 74
 75func (s *service) Create(ctx context.Context, title string) (Session, error) {
 76	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 77		ID:    uuid.New().String(),
 78		Title: title,
 79	})
 80	if err != nil {
 81		return Session{}, err
 82	}
 83	session := s.fromDBItem(dbSession)
 84	s.Publish(pubsub.CreatedEvent, session)
 85	event.SessionCreated()
 86	return session, nil
 87}
 88
 89func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 90	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 91		ID:              toolCallID,
 92		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 93		Title:           title,
 94	})
 95	if err != nil {
 96		return Session{}, err
 97	}
 98	session := s.fromDBItem(dbSession)
 99	s.Publish(pubsub.CreatedEvent, session)
100	return session, nil
101}
102
103func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
104	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
105		ID:              "title-" + parentSessionID,
106		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
107		Title:           "Generate a title",
108	})
109	if err != nil {
110		return Session{}, err
111	}
112	session := s.fromDBItem(dbSession)
113	s.Publish(pubsub.CreatedEvent, session)
114	return session, nil
115}
116
117func (s *service) Delete(ctx context.Context, id string) error {
118	tx, err := s.db.BeginTx(ctx, nil)
119	if err != nil {
120		return fmt.Errorf("beginning transaction: %w", err)
121	}
122	defer tx.Rollback() //nolint:errcheck
123
124	qtx := s.q.WithTx(tx)
125
126	dbSession, err := qtx.GetSessionByID(ctx, id)
127	if err != nil {
128		return err
129	}
130	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
131		return fmt.Errorf("deleting session messages: %w", err)
132	}
133	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
134		return fmt.Errorf("deleting session files: %w", err)
135	}
136	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
137		return fmt.Errorf("deleting session: %w", err)
138	}
139	if err = tx.Commit(); err != nil {
140		return fmt.Errorf("committing transaction: %w", err)
141	}
142
143	session := s.fromDBItem(dbSession)
144	s.Publish(pubsub.DeletedEvent, session)
145	event.SessionDeleted()
146	return nil
147}
148
149func (s *service) Fork(ctx context.Context, sourceSessionID, upToMessageID string, messageSvc ForkMessageService) (Session, error) {
150	messages, err := messageSvc.List(ctx, sourceSessionID)
151	if err != nil {
152		return Session{}, fmt.Errorf("listing messages: %w", err)
153	}
154
155	targetIndex := -1
156	for i, msg := range messages {
157		if msg.ID == upToMessageID {
158			targetIndex = i
159			break
160		}
161	}
162	if targetIndex == -1 {
163		return Session{}, fmt.Errorf("message not found: %s", upToMessageID)
164	}
165
166	sourceSession, err := s.Get(ctx, sourceSessionID)
167	if err != nil {
168		return Session{}, fmt.Errorf("getting source session: %w", err)
169	}
170
171	newSession, err := s.Create(ctx, "Forked: "+sourceSession.Title)
172	if err != nil {
173		return Session{}, fmt.Errorf("creating session: %w", err)
174	}
175
176	for i := 0; i < targetIndex; i++ {
177		_, err = messageSvc.Copy(ctx, newSession.ID, messages[i])
178		if err != nil {
179			_ = s.Delete(ctx, newSession.ID)
180			return Session{}, fmt.Errorf("copying message: %w", err)
181		}
182	}
183
184	return newSession, nil
185}
186
187func (s *service) Get(ctx context.Context, id string) (Session, error) {
188	dbSession, err := s.q.GetSessionByID(ctx, id)
189	if err != nil {
190		return Session{}, err
191	}
192	return s.fromDBItem(dbSession), nil
193}
194
195func (s *service) Save(ctx context.Context, session Session) (Session, error) {
196	todosJSON, err := marshalTodos(session.Todos)
197	if err != nil {
198		return Session{}, err
199	}
200
201	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
202		ID:               session.ID,
203		Title:            session.Title,
204		PromptTokens:     session.PromptTokens,
205		CompletionTokens: session.CompletionTokens,
206		SummaryMessageID: sql.NullString{
207			String: session.SummaryMessageID,
208			Valid:  session.SummaryMessageID != "",
209		},
210		Cost: session.Cost,
211		Todos: sql.NullString{
212			String: todosJSON,
213			Valid:  todosJSON != "",
214		},
215	})
216	if err != nil {
217		return Session{}, err
218	}
219	session = s.fromDBItem(dbSession)
220	s.Publish(pubsub.UpdatedEvent, session)
221	return session, nil
222}
223
224// UpdateTitleAndUsage updates only the title and usage fields atomically.
225// This is safer than fetching, modifying, and saving the entire session.
226func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
227	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
228		ID:               sessionID,
229		Title:            title,
230		PromptTokens:     promptTokens,
231		CompletionTokens: completionTokens,
232		Cost:             cost,
233	})
234}
235
236func (s *service) List(ctx context.Context) ([]Session, error) {
237	dbSessions, err := s.q.ListSessions(ctx)
238	if err != nil {
239		return nil, err
240	}
241	sessions := make([]Session, len(dbSessions))
242	for i, dbSession := range dbSessions {
243		sessions[i] = s.fromDBItem(dbSession)
244	}
245	return sessions, nil
246}
247
248func (s service) fromDBItem(item db.Session) Session {
249	todos, err := unmarshalTodos(item.Todos.String)
250	if err != nil {
251		slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
252	}
253	return Session{
254		ID:               item.ID,
255		ParentSessionID:  item.ParentSessionID.String,
256		Title:            item.Title,
257		MessageCount:     item.MessageCount,
258		PromptTokens:     item.PromptTokens,
259		CompletionTokens: item.CompletionTokens,
260		SummaryMessageID: item.SummaryMessageID.String,
261		Cost:             item.Cost,
262		Todos:            todos,
263		CreatedAt:        item.CreatedAt,
264		UpdatedAt:        item.UpdatedAt,
265	}
266}
267
268func marshalTodos(todos []Todo) (string, error) {
269	if len(todos) == 0 {
270		return "", nil
271	}
272	data, err := json.Marshal(todos)
273	if err != nil {
274		return "", err
275	}
276	return string(data), nil
277}
278
279func unmarshalTodos(data string) ([]Todo, error) {
280	if data == "" {
281		return []Todo{}, nil
282	}
283	var todos []Todo
284	if err := json.Unmarshal([]byte(data), &todos); err != nil {
285		return []Todo{}, err
286	}
287	return todos, nil
288}
289
290func NewService(q *db.Queries, conn *sql.DB) Service {
291	broker := pubsub.NewBroker[Session]()
292	return &service{
293		Broker: broker,
294		db:     conn,
295		q:      q,
296	}
297}
298
299// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
300func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
301	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
302}
303
304// ParseAgentToolSessionID parses an agent tool session ID into its components
305func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
306	parts := strings.Split(sessionID, "$$")
307	if len(parts) != 2 {
308		return "", "", false
309	}
310	return parts[0], parts[1], true
311}
312
313// IsAgentToolSession checks if a session ID follows the agent tool session format
314func (s *service) IsAgentToolSession(sessionID string) bool {
315	_, _, ok := s.ParseAgentToolSessionID(sessionID)
316	return ok
317}