session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15)
 16
 17type TodoStatus string
 18
 19const (
 20	TodoStatusPending    TodoStatus = "pending"
 21	TodoStatusInProgress TodoStatus = "in_progress"
 22	TodoStatusCompleted  TodoStatus = "completed"
 23)
 24
 25type Todo struct {
 26	Content    string     `json:"content"`
 27	Status     TodoStatus `json:"status"`
 28	ActiveForm string     `json:"active_form"`
 29}
 30
 31// HasIncompleteTodos returns true if there are any non-completed todos.
 32func HasIncompleteTodos(todos []Todo) bool {
 33	for _, todo := range todos {
 34		if todo.Status != TodoStatusCompleted {
 35			return true
 36		}
 37	}
 38	return false
 39}
 40
 41type Session struct {
 42	ID               string
 43	ParentSessionID  string
 44	Title            string
 45	MessageCount     int64
 46	PromptTokens     int64
 47	CompletionTokens int64
 48	SummaryMessageID string
 49	Cost             float64
 50	Todos            []Todo
 51	CreatedAt        int64
 52	UpdatedAt        int64
 53}
 54
 55type Service interface {
 56	pubsub.Subscriber[Session]
 57	Create(ctx context.Context, title string) (Session, error)
 58	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 59	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 60	Get(ctx context.Context, id string) (Session, error)
 61	List(ctx context.Context) ([]Session, error)
 62	Save(ctx context.Context, session Session) (Session, error)
 63	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 64	Delete(ctx context.Context, id string) error
 65
 66	// Agent tool session management
 67	CreateAgentToolSessionID(messageID, toolCallID string) string
 68	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 69	IsAgentToolSession(sessionID string) bool
 70}
 71
 72type service struct {
 73	*pubsub.Broker[Session]
 74	db *sql.DB
 75	q  *db.Queries
 76}
 77
 78func (s *service) Create(ctx context.Context, title string) (Session, error) {
 79	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 80		ID:    uuid.New().String(),
 81		Title: title,
 82	})
 83	if err != nil {
 84		return Session{}, err
 85	}
 86	session := s.fromDBItem(dbSession)
 87	s.Publish(pubsub.CreatedEvent, session)
 88	event.SessionCreated()
 89	return session, nil
 90}
 91
 92func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 93	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 94		ID:              toolCallID,
 95		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 96		Title:           title,
 97	})
 98	if err != nil {
 99		return Session{}, err
100	}
101	session := s.fromDBItem(dbSession)
102	s.Publish(pubsub.CreatedEvent, session)
103	return session, nil
104}
105
106func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
107	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
108		ID:              "title-" + parentSessionID,
109		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
110		Title:           "Generate a title",
111	})
112	if err != nil {
113		return Session{}, err
114	}
115	session := s.fromDBItem(dbSession)
116	s.Publish(pubsub.CreatedEvent, session)
117	return session, nil
118}
119
120func (s *service) Delete(ctx context.Context, id string) error {
121	tx, err := s.db.BeginTx(ctx, nil)
122	if err != nil {
123		return fmt.Errorf("beginning transaction: %w", err)
124	}
125	defer tx.Rollback() //nolint:errcheck
126
127	qtx := s.q.WithTx(tx)
128
129	dbSession, err := qtx.GetSessionByID(ctx, id)
130	if err != nil {
131		return err
132	}
133	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
134		return fmt.Errorf("deleting session messages: %w", err)
135	}
136	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
137		return fmt.Errorf("deleting session files: %w", err)
138	}
139	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
140		return fmt.Errorf("deleting session: %w", err)
141	}
142	if err = tx.Commit(); err != nil {
143		return fmt.Errorf("committing transaction: %w", err)
144	}
145
146	session := s.fromDBItem(dbSession)
147	s.Publish(pubsub.DeletedEvent, session)
148	event.SessionDeleted()
149	return nil
150}
151
152func (s *service) Get(ctx context.Context, id string) (Session, error) {
153	dbSession, err := s.q.GetSessionByID(ctx, id)
154	if err != nil {
155		return Session{}, err
156	}
157	return s.fromDBItem(dbSession), nil
158}
159
160func (s *service) Save(ctx context.Context, session Session) (Session, error) {
161	todosJSON, err := marshalTodos(session.Todos)
162	if err != nil {
163		return Session{}, err
164	}
165
166	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
167		ID:               session.ID,
168		Title:            session.Title,
169		PromptTokens:     session.PromptTokens,
170		CompletionTokens: session.CompletionTokens,
171		SummaryMessageID: sql.NullString{
172			String: session.SummaryMessageID,
173			Valid:  session.SummaryMessageID != "",
174		},
175		Cost: session.Cost,
176		Todos: sql.NullString{
177			String: todosJSON,
178			Valid:  todosJSON != "",
179		},
180	})
181	if err != nil {
182		return Session{}, err
183	}
184	session = s.fromDBItem(dbSession)
185	s.Publish(pubsub.UpdatedEvent, session)
186	return session, nil
187}
188
189// UpdateTitleAndUsage updates only the title and usage fields atomically.
190// This is safer than fetching, modifying, and saving the entire session.
191func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
192	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
193		ID:               sessionID,
194		Title:            title,
195		PromptTokens:     promptTokens,
196		CompletionTokens: completionTokens,
197		Cost:             cost,
198	})
199}
200
201func (s *service) List(ctx context.Context) ([]Session, error) {
202	dbSessions, err := s.q.ListSessions(ctx)
203	if err != nil {
204		return nil, err
205	}
206	sessions := make([]Session, len(dbSessions))
207	for i, dbSession := range dbSessions {
208		sessions[i] = s.fromDBItem(dbSession)
209	}
210	return sessions, nil
211}
212
213func (s service) fromDBItem(item db.Session) Session {
214	todos, err := unmarshalTodos(item.Todos.String)
215	if err != nil {
216		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
217	}
218	return Session{
219		ID:               item.ID,
220		ParentSessionID:  item.ParentSessionID.String,
221		Title:            item.Title,
222		MessageCount:     item.MessageCount,
223		PromptTokens:     item.PromptTokens,
224		CompletionTokens: item.CompletionTokens,
225		SummaryMessageID: item.SummaryMessageID.String,
226		Cost:             item.Cost,
227		Todos:            todos,
228		CreatedAt:        item.CreatedAt,
229		UpdatedAt:        item.UpdatedAt,
230	}
231}
232
233func marshalTodos(todos []Todo) (string, error) {
234	if len(todos) == 0 {
235		return "", nil
236	}
237	data, err := json.Marshal(todos)
238	if err != nil {
239		return "", err
240	}
241	return string(data), nil
242}
243
244func unmarshalTodos(data string) ([]Todo, error) {
245	if data == "" {
246		return []Todo{}, nil
247	}
248	var todos []Todo
249	if err := json.Unmarshal([]byte(data), &todos); err != nil {
250		return []Todo{}, err
251	}
252	return todos, nil
253}
254
255func NewService(q *db.Queries, conn *sql.DB) Service {
256	broker := pubsub.NewBroker[Session]()
257	return &service{
258		Broker: broker,
259		db:     conn,
260		q:      q,
261	}
262}
263
264// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
265func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
266	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
267}
268
269// ParseAgentToolSessionID parses an agent tool session ID into its components
270func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
271	parts := strings.Split(sessionID, "$$")
272	if len(parts) != 2 {
273		return "", "", false
274	}
275	return parts[0], parts[1], true
276}
277
278// IsAgentToolSession checks if a session ID follows the agent tool session format
279func (s *service) IsAgentToolSession(sessionID string) bool {
280	_, _, ok := s.ParseAgentToolSessionID(sessionID)
281	return ok
282}