resolve_session_test.go

  1package app
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"fmt"
  7	"strings"
  8	"testing"
  9
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12	"github.com/charmbracelet/crush/internal/session"
 13	"github.com/stretchr/testify/require"
 14)
 15
 16// mockSessionService is a minimal mock of session.Service for testing resolveSession.
 17type mockSessionService struct {
 18	sessions []session.Session
 19	created  []session.Session
 20}
 21
 22func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
 23	return make(chan pubsub.Event[session.Session])
 24}
 25
 26func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
 27	s := session.Session{ID: "new-session-id", Title: title}
 28	m.created = append(m.created, s)
 29	return s, nil
 30}
 31
 32func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
 33	return session.Session{}, nil
 34}
 35
 36func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
 37	return session.Session{}, nil
 38}
 39
 40func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
 41	for _, s := range m.sessions {
 42		if s.ID == id {
 43			return s, nil
 44		}
 45	}
 46	return session.Session{}, sql.ErrNoRows
 47}
 48
 49func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
 50	if len(m.sessions) > 0 {
 51		return m.sessions[0], nil
 52	}
 53	return session.Session{}, sql.ErrNoRows
 54}
 55
 56func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
 57	return m.sessions, nil
 58}
 59
 60func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
 61	return s, nil
 62}
 63
 64func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) {
 65	return s, nil
 66}
 67
 68func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error {
 69	return nil
 70}
 71
 72func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
 73	return nil
 74}
 75
 76func (m *mockSessionService) Rename(context.Context, string, string) error {
 77	return nil
 78}
 79
 80func (m *mockSessionService) Delete(context.Context, string) error {
 81	return nil
 82}
 83
 84func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
 85	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
 86}
 87
 88func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
 89	parts := strings.Split(sessionID, "$$")
 90	if len(parts) != 2 {
 91		return "", "", false
 92	}
 93	return parts[0], parts[1], true
 94}
 95
 96func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
 97	_, _, ok := m.ParseAgentToolSessionID(sessionID)
 98	return ok
 99}
100
101func newTestApp(sessions session.Service) *App {
102	return &App{Sessions: sessions}
103}
104
105func TestResolveSession_NewSession(t *testing.T) {
106	mock := &mockSessionService{}
107	app := newTestApp(mock)
108
109	sess, err := app.resolveSession(t.Context(), "", false)
110	require.NoError(t, err)
111	require.Equal(t, "new-session-id", sess.ID)
112	require.Len(t, mock.created, 1)
113}
114
115func TestResolveSession_ContinueByID(t *testing.T) {
116	mock := &mockSessionService{
117		sessions: []session.Session{
118			{ID: "existing-id", Title: "Old session"},
119		},
120	}
121	app := newTestApp(mock)
122
123	sess, err := app.resolveSession(t.Context(), "existing-id", false)
124	require.NoError(t, err)
125	require.Equal(t, "existing-id", sess.ID)
126	require.Equal(t, "Old session", sess.Title)
127	require.Empty(t, mock.created)
128}
129
130func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
131	mock := &mockSessionService{}
132	app := newTestApp(mock)
133
134	_, err := app.resolveSession(t.Context(), "nonexistent", false)
135	require.Error(t, err)
136	require.Contains(t, err.Error(), "session not found")
137}
138
139func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
140	mock := &mockSessionService{
141		sessions: []session.Session{
142			{ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
143		},
144	}
145	app := newTestApp(mock)
146
147	_, err := app.resolveSession(t.Context(), "child-id", false)
148	require.Error(t, err)
149	require.Contains(t, err.Error(), "cannot continue a child session")
150}
151
152func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
153	mock := &mockSessionService{}
154	app := newTestApp(mock)
155
156	_, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
157	require.Error(t, err)
158	require.Contains(t, err.Error(), "cannot continue an agent tool session")
159}
160
161func TestResolveSession_Last(t *testing.T) {
162	mock := &mockSessionService{
163		sessions: []session.Session{
164			{ID: "most-recent", Title: "Latest session"},
165			{ID: "older", Title: "Older session"},
166		},
167	}
168	app := newTestApp(mock)
169
170	sess, err := app.resolveSession(t.Context(), "", true)
171	require.NoError(t, err)
172	require.Equal(t, "most-recent", sess.ID)
173	require.Empty(t, mock.created)
174}
175
176func TestResolveSession_Last_NoSessions(t *testing.T) {
177	mock := &mockSessionService{}
178	app := newTestApp(mock)
179
180	_, err := app.resolveSession(t.Context(), "", true)
181	require.Error(t, err)
182	require.Contains(t, err.Error(), "no sessions found")
183}