session.go

  1package session
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"encoding/json"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10
 11	"github.com/charmbracelet/crush/internal/db"
 12	"github.com/charmbracelet/crush/internal/event"
 13	"github.com/charmbracelet/crush/internal/pubsub"
 14	"github.com/google/uuid"
 15)
 16
 17type TodoStatus string
 18
 19const (
 20	TodoStatusPending    TodoStatus = "pending"
 21	TodoStatusInProgress TodoStatus = "in_progress"
 22	TodoStatusCompleted  TodoStatus = "completed"
 23)
 24
 25type Todo struct {
 26	Content    string     `json:"content"`
 27	Status     TodoStatus `json:"status"`
 28	ActiveForm string     `json:"active_form"`
 29}
 30
 31type Session struct {
 32	ID               string
 33	ParentSessionID  string
 34	Title            string
 35	MessageCount     int64
 36	PromptTokens     int64
 37	CompletionTokens int64
 38	SummaryMessageID string
 39	Cost             float64
 40	Todos            []Todo
 41	CreatedAt        int64
 42	UpdatedAt        int64
 43}
 44
 45type Service interface {
 46	pubsub.Subscriber[Session]
 47	Create(ctx context.Context, title string) (Session, error)
 48	CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error)
 49	CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error)
 50	Get(ctx context.Context, id string) (Session, error)
 51	List(ctx context.Context) ([]Session, error)
 52	Save(ctx context.Context, session Session) (Session, error)
 53	UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error
 54	Delete(ctx context.Context, id string) error
 55
 56	// Agent tool session management
 57	CreateAgentToolSessionID(messageID, toolCallID string) string
 58	ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool)
 59	IsAgentToolSession(sessionID string) bool
 60}
 61
 62type service struct {
 63	*pubsub.Broker[Session]
 64	q db.Querier
 65}
 66
 67func (s *service) Create(ctx context.Context, title string) (Session, error) {
 68	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 69		ID:    uuid.New().String(),
 70		Title: title,
 71	})
 72	if err != nil {
 73		return Session{}, err
 74	}
 75	session := s.fromDBItem(dbSession)
 76	s.Publish(pubsub.CreatedEvent, session)
 77	event.SessionCreated()
 78	return session, nil
 79}
 80
 81func (s *service) CreateTaskSession(ctx context.Context, toolCallID, parentSessionID, title string) (Session, error) {
 82	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 83		ID:              toolCallID,
 84		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 85		Title:           title,
 86	})
 87	if err != nil {
 88		return Session{}, err
 89	}
 90	session := s.fromDBItem(dbSession)
 91	s.Publish(pubsub.CreatedEvent, session)
 92	return session, nil
 93}
 94
 95func (s *service) CreateTitleSession(ctx context.Context, parentSessionID string) (Session, error) {
 96	dbSession, err := s.q.CreateSession(ctx, db.CreateSessionParams{
 97		ID:              "title-" + parentSessionID,
 98		ParentSessionID: sql.NullString{String: parentSessionID, Valid: true},
 99		Title:           "Generate a title",
100	})
101	if err != nil {
102		return Session{}, err
103	}
104	session := s.fromDBItem(dbSession)
105	s.Publish(pubsub.CreatedEvent, session)
106	return session, nil
107}
108
109func (s *service) Delete(ctx context.Context, id string) error {
110	session, err := s.Get(ctx, id)
111	if err != nil {
112		return err
113	}
114	err = s.q.DeleteSession(ctx, session.ID)
115	if err != nil {
116		return err
117	}
118	s.Publish(pubsub.DeletedEvent, session)
119	event.SessionDeleted()
120	return nil
121}
122
123func (s *service) Get(ctx context.Context, id string) (Session, error) {
124	dbSession, err := s.q.GetSessionByID(ctx, id)
125	if err != nil {
126		return Session{}, err
127	}
128	return s.fromDBItem(dbSession), nil
129}
130
131func (s *service) Save(ctx context.Context, session Session) (Session, error) {
132	todosJSON, err := marshalTodos(session.Todos)
133	if err != nil {
134		return Session{}, err
135	}
136
137	dbSession, err := s.q.UpdateSession(ctx, db.UpdateSessionParams{
138		ID:               session.ID,
139		Title:            session.Title,
140		PromptTokens:     session.PromptTokens,
141		CompletionTokens: session.CompletionTokens,
142		SummaryMessageID: sql.NullString{
143			String: session.SummaryMessageID,
144			Valid:  session.SummaryMessageID != "",
145		},
146		Cost: session.Cost,
147		Todos: sql.NullString{
148			String: todosJSON,
149			Valid:  todosJSON != "",
150		},
151	})
152	if err != nil {
153		return Session{}, err
154	}
155	session = s.fromDBItem(dbSession)
156	s.Publish(pubsub.UpdatedEvent, session)
157	return session, nil
158}
159
160// UpdateTitleAndUsage updates only the title and usage fields atomically.
161// This is safer than fetching, modifying, and saving the entire session.
162func (s *service) UpdateTitleAndUsage(ctx context.Context, sessionID, title string, promptTokens, completionTokens int64, cost float64) error {
163	return s.q.UpdateSessionTitleAndUsage(ctx, db.UpdateSessionTitleAndUsageParams{
164		ID:               sessionID,
165		Title:            title,
166		PromptTokens:     promptTokens,
167		CompletionTokens: completionTokens,
168		Cost:             cost,
169	})
170}
171
172func (s *service) List(ctx context.Context) ([]Session, error) {
173	dbSessions, err := s.q.ListSessions(ctx)
174	if err != nil {
175		return nil, err
176	}
177	sessions := make([]Session, len(dbSessions))
178	for i, dbSession := range dbSessions {
179		sessions[i] = s.fromDBItem(dbSession)
180	}
181	return sessions, nil
182}
183
184func (s service) fromDBItem(item db.Session) Session {
185	todos, err := unmarshalTodos(item.Todos.String)
186	if err != nil {
187		slog.Error("failed to unmarshal todos", "session_id", item.ID, "error", err)
188	}
189	return Session{
190		ID:               item.ID,
191		ParentSessionID:  item.ParentSessionID.String,
192		Title:            item.Title,
193		MessageCount:     item.MessageCount,
194		PromptTokens:     item.PromptTokens,
195		CompletionTokens: item.CompletionTokens,
196		SummaryMessageID: item.SummaryMessageID.String,
197		Cost:             item.Cost,
198		Todos:            todos,
199		CreatedAt:        item.CreatedAt,
200		UpdatedAt:        item.UpdatedAt,
201	}
202}
203
204func marshalTodos(todos []Todo) (string, error) {
205	if len(todos) == 0 {
206		return "", nil
207	}
208	data, err := json.Marshal(todos)
209	if err != nil {
210		return "", err
211	}
212	return string(data), nil
213}
214
215func unmarshalTodos(data string) ([]Todo, error) {
216	if data == "" {
217		return []Todo{}, nil
218	}
219	var todos []Todo
220	if err := json.Unmarshal([]byte(data), &todos); err != nil {
221		return []Todo{}, err
222	}
223	return todos, nil
224}
225
226func NewService(q db.Querier) Service {
227	broker := pubsub.NewBroker[Session]()
228	return &service{
229		broker,
230		q,
231	}
232}
233
234// CreateAgentToolSessionID creates a session ID for agent tool sessions using the format "messageID$$toolCallID"
235func (s *service) CreateAgentToolSessionID(messageID, toolCallID string) string {
236	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
237}
238
239// ParseAgentToolSessionID parses an agent tool session ID into its components
240func (s *service) ParseAgentToolSessionID(sessionID string) (messageID string, toolCallID string, ok bool) {
241	parts := strings.Split(sessionID, "$$")
242	if len(parts) != 2 {
243		return "", "", false
244	}
245	return parts[0], parts[1], true
246}
247
248// IsAgentToolSession checks if a session ID follows the agent tool session format
249func (s *service) IsAgentToolSession(sessionID string) bool {
250	_, _, ok := s.ParseAgentToolSessionID(sessionID)
251	return ok
252}