session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15)
 16
 17type TodoStatus string
 18
 19const (
 20	TodoStatusPending    TodoStatus = "pending"
 21	TodoStatusInProgress TodoStatus = "in_progress"
 22	TodoStatusCompleted  TodoStatus = "completed"
 23)
 24
 25type Todo struct {
 26	Content    string     `json:"content"`
 27	Status     TodoStatus `json:"status"`
 28	ActiveForm string     `json:"active_form"`
 29}
 30
 31type Session struct {
 32	ID               string
 33	ParentSessionID  string
 34	Title            string
 35	MessageCount     int64
 36	PromptTokens     int64
 37	CompletionTokens int64
 38	SummaryMessageID string
 39	Cost             float64
 40	Todos            []Todo
 41	CreatedAt        int64
 42	UpdatedAt        int64
 43}
 44
 45type Service interface {
 46	pubsub.Subscriber[Session]
 47	Create(ctx context.Context, title string) (Session, error)
 48	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 49	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 50	Get(ctx context.Context, id string) (Session, error)
 51	List(ctx context.Context) ([]Session, error)
 52	Save(ctx context.Context, session Session) (Session, error)
 53	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 54	Delete(ctx context.Context, id string) error
 55
 56	// Agent tool session management
 57	CreateAgentToolSessionID(messageID, toolCallID string) string
 58	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 59	IsAgentToolSession(sessionID string) bool
 60}
 61
 62type service struct {
 63	*pubsub.Broker[Session]
 64	db *sql.DB
 65	q  *db.Queries
 66}
 67
 68func (s *service) Create(ctx context.Context, title string) (Session, error) {
 69	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 70		ID:    uuid.New().String(),
 71		Title: title,
 72	})
 73	if err != nil {
 74		return Session{}, err
 75	}
 76	session := s.fromDBItem(dbSession)
 77	s.Publish(pubsub.CreatedEvent, session)
 78	event.SessionCreated()
 79	return session, nil
 80}
 81
 82func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 83	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 84		ID:              toolCallID,
 85		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 86		Title:           title,
 87	})
 88	if err != nil {
 89		return Session{}, err
 90	}
 91	session := s.fromDBItem(dbSession)
 92	s.Publish(pubsub.CreatedEvent, session)
 93	return session, nil
 94}
 95
 96func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 97	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 98		ID:              "title-" + parentSessionID,
 99		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
100		Title:           "Generate a title",
101	})
102	if err != nil {
103		return Session{}, err
104	}
105	session := s.fromDBItem(dbSession)
106	s.Publish(pubsub.CreatedEvent, session)
107	return session, nil
108}
109
110func (s *service) Delete(ctx context.Context, id string) error {
111	tx, err := s.db.BeginTx(ctx, nil)
112	if err != nil {
113		return fmt.Errorf("beginning transaction: %w", err)
114	}
115	defer tx.Rollback() //nolint:errcheck
116
117	qtx := s.q.WithTx(tx)
118
119	dbSession, err := qtx.GetSessionByID(ctx, id)
120	if err != nil {
121		return err
122	}
123	if err = qtx.DeleteSessionMessages(ctx, dbSession.ID); err != nil {
124		return fmt.Errorf("deleting session messages: %w", err)
125	}
126	if err = qtx.DeleteSessionFiles(ctx, dbSession.ID); err != nil {
127		return fmt.Errorf("deleting session files: %w", err)
128	}
129	if err = qtx.DeleteSession(ctx, dbSession.ID); err != nil {
130		return fmt.Errorf("deleting session: %w", err)
131	}
132	if err = tx.Commit(); err != nil {
133		return fmt.Errorf("committing transaction: %w", err)
134	}
135
136	session := s.fromDBItem(dbSession)
137	s.Publish(pubsub.DeletedEvent, session)
138	event.SessionDeleted()
139	return nil
140}
141
142func (s *service) Get(ctx context.Context, id string) (Session, error) {
143	dbSession, err := s.q.GetSessionByID(ctx, id)
144	if err != nil {
145		return Session{}, err
146	}
147	return s.fromDBItem(dbSession), nil
148}
149
150func (s *service) Save(ctx context.Context, session Session) (Session, error) {
151	todosJSON, err := marshalTodos(session.Todos)
152	if err != nil {
153		return Session{}, err
154	}
155
156	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
157		ID:               session.ID,
158		Title:            session.Title,
159		PromptTokens:     session.PromptTokens,
160		CompletionTokens: session.CompletionTokens,
161		SummaryMessageID: sql.NullString{
162			String: session.SummaryMessageID,
163			Valid:  session.SummaryMessageID != "",
164		},
165		Cost: session.Cost,
166		Todos: sql.NullString{
167			String: todosJSON,
168			Valid:  todosJSON != "",
169		},
170	})
171	if err != nil {
172		return Session{}, err
173	}
174	session = s.fromDBItem(dbSession)
175	s.Publish(pubsub.UpdatedEvent, session)
176	return session, nil
177}
178
179// UpdateTitleAndUsage updates only the title and usage fields atomically.
180// This is safer than fetching, modifying, and saving the entire session.
181func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
182	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
183		ID:               sessionID,
184		Title:            title,
185		PromptTokens:     promptTokens,
186		CompletionTokens: completionTokens,
187		Cost:             cost,
188	})
189}
190
191func (s *service) List(ctx context.Context) ([]Session, error) {
192	dbSessions, err := s.q.ListSessions(ctx)
193	if err != nil {
194		return nil, err
195	}
196	sessions := make([]Session, len(dbSessions))
197	for i, dbSession := range dbSessions {
198		sessions[i] = s.fromDBItem(dbSession)
199	}
200	return sessions, nil
201}
202
203func (s service) fromDBItem(item db.Session) Session {
204	todos, err := unmarshalTodos(item.Todos.String)
205	if err != nil {
206		slog.Error("Failed to unmarshal todos", "session_id", item.ID, "error", err)
207	}
208	return Session{
209		ID:               item.ID,
210		ParentSessionID:  item.ParentSessionID.String,
211		Title:            item.Title,
212		MessageCount:     item.MessageCount,
213		PromptTokens:     item.PromptTokens,
214		CompletionTokens: item.CompletionTokens,
215		SummaryMessageID: item.SummaryMessageID.String,
216		Cost:             item.Cost,
217		Todos:            todos,
218		CreatedAt:        item.CreatedAt,
219		UpdatedAt:        item.UpdatedAt,
220	}
221}
222
223func marshalTodos(todos []Todo) (string, error) {
224	if len(todos) == 0 {
225		return "", nil
226	}
227	data, err := json.Marshal(todos)
228	if err != nil {
229		return "", err
230	}
231	return string(data), nil
232}
233
234func unmarshalTodos(data string) ([]Todo, error) {
235	if data == "" {
236		return []Todo{}, nil
237	}
238	var todos []Todo
239	if err := json.Unmarshal([]byte(data), &todos); err != nil {
240		return []Todo{}, err
241	}
242	return todos, nil
243}
244
245func NewService(q *db.Queries, conn *sql.DB) Service {
246	broker := pubsub.NewBroker[Session]()
247	return &service{
248		Broker: broker,
249		db:     conn,
250		q:      q,
251	}
252}
253
254// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
255func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
256	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
257}
258
259// ParseAgentToolSessionID parses an agent tool session ID into its components
260func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
261	parts := strings.Split(sessionID, "$$")
262	if len(parts) != 2 {
263		return "", "", false
264	}
265	return parts[0], parts[1], true
266}
267
268// IsAgentToolSession checks if a session ID follows the agent tool session format
269func (s *service) IsAgentToolSession(sessionID string) bool {
270	_, _, ok := s.ParseAgentToolSessionID(sessionID)
271	return ok
272}