resolve_session_test.go

  1package app
  2
  3import (
  4	"context"
  5	"database/sql"
  6	"fmt"
  7	"strings"
  8	"testing"
  9
 10	"github.com/charmbracelet/crush/internal/pubsub"
 11	"github.com/charmbracelet/crush/internal/session"
 12	"github.com/stretchr/testify/require"
 13)
 14
// mockSessionService is a minimal mock of session.Service for testing resolveSession.
// It keeps all state in memory: tests preload sessions and inspect created
// afterwards to see whether Create was called.
type mockSessionService struct {
	sessions []session.Session // fixture data served by Get, GetLast and List
	created  []session.Session // sessions returned by Create, in call order
}
 20
// Subscribe returns a new unbuffered channel on every call and ignores its
// context argument. The mock keeps no reference to the channel, so it never
// sends an event on it and never closes it.
func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
	return make(chan pubsub.Event[session.Session])
}
 24
 25func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
 26	s := session.Session{ID: "new-session-id", Title: title}
 27	m.created = append(m.created, s)
 28	return s, nil
 29}
 30
// CreateTitleSession is a stub: it ignores its arguments and returns a
// zero-value session with a nil error.
func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
	return session.Session{}, nil
}
 34
// CreateTaskSession is a stub: it ignores its arguments and returns a
// zero-value session with a nil error.
func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
	return session.Session{}, nil
}
 38
 39func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
 40	for _, s := range m.sessions {
 41		if s.ID == id {
 42			return s, nil
 43		}
 44	}
 45	return session.Session{}, sql.ErrNoRows
 46}
 47
 48func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
 49	if len(m.sessions) > 0 {
 50		return m.sessions[0], nil
 51	}
 52	return session.Session{}, sql.ErrNoRows
 53}
 54
// List returns the mock's preloaded sessions slice as-is (not a copy) with a
// nil error.
func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
	return m.sessions, nil
}
 58
// Save echoes the given session back unchanged with a nil error; nothing is
// stored in the mock.
func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
	return s, nil
}
 62
// UpdateTitleAndUsage is a no-op stub: it ignores all arguments and always
// returns nil.
func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
	return nil
}
 66
// Rename is a no-op stub: it ignores all arguments and always returns nil.
func (m *mockSessionService) Rename(context.Context, string, string) error {
	return nil
}
 70
// Delete is a no-op stub: it ignores all arguments and always returns nil.
// The mock's preloaded sessions are left untouched.
func (m *mockSessionService) Delete(context.Context, string) error {
	return nil
}
 74
 75func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
 76	return fmt.Sprintf("%s$$%s", messageID, toolCallID)
 77}
 78
 79func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
 80	parts := strings.Split(sessionID, "$$")
 81	if len(parts) != 2 {
 82		return "", "", false
 83	}
 84	return parts[0], parts[1], true
 85}
 86
// IsAgentToolSession reports whether sessionID parses as an agent tool
// session ID, i.e. whether ParseAgentToolSessionID accepts it.
func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
	_, _, ok := m.ParseAgentToolSessionID(sessionID)
	return ok
}
 91
 92func newTestApp(sessions session.Service) *App {
 93	return &App{Sessions: sessions}
 94}
 95
 96func TestResolveSession_NewSession(t *testing.T) {
 97	mock := &mockSessionService{}
 98	app := newTestApp(mock)
 99
100	sess, err := app.resolveSession(t.Context(), "", false)
101	require.NoError(t, err)
102	require.Equal(t, "new-session-id", sess.ID)
103	require.Len(t, mock.created, 1)
104}
105
106func TestResolveSession_ContinueByID(t *testing.T) {
107	mock := &mockSessionService{
108		sessions: []session.Session{
109			{ID: "existing-id", Title: "Old session"},
110		},
111	}
112	app := newTestApp(mock)
113
114	sess, err := app.resolveSession(t.Context(), "existing-id", false)
115	require.NoError(t, err)
116	require.Equal(t, "existing-id", sess.ID)
117	require.Equal(t, "Old session", sess.Title)
118	require.Empty(t, mock.created)
119}
120
121func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
122	mock := &mockSessionService{}
123	app := newTestApp(mock)
124
125	_, err := app.resolveSession(t.Context(), "nonexistent", false)
126	require.Error(t, err)
127	require.Contains(t, err.Error(), "session not found")
128}
129
130func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
131	mock := &mockSessionService{
132		sessions: []session.Session{
133			{ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
134		},
135	}
136	app := newTestApp(mock)
137
138	_, err := app.resolveSession(t.Context(), "child-id", false)
139	require.Error(t, err)
140	require.Contains(t, err.Error(), "cannot continue a child session")
141}
142
143func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
144	mock := &mockSessionService{}
145	app := newTestApp(mock)
146
147	_, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
148	require.Error(t, err)
149	require.Contains(t, err.Error(), "cannot continue an agent tool session")
150}
151
152func TestResolveSession_Last(t *testing.T) {
153	mock := &mockSessionService{
154		sessions: []session.Session{
155			{ID: "most-recent", Title: "Latest session"},
156			{ID: "older", Title: "Older session"},
157		},
158	}
159	app := newTestApp(mock)
160
161	sess, err := app.resolveSession(t.Context(), "", true)
162	require.NoError(t, err)
163	require.Equal(t, "most-recent", sess.ID)
164	require.Empty(t, mock.created)
165}
166
167func TestResolveSession_Last_NoSessions(t *testing.T) {
168	mock := &mockSessionService{}
169	app := newTestApp(mock)
170
171	_, err := app.resolveSession(t.Context(), "", true)
172	require.Error(t, err)
173	require.Contains(t, err.Error(), "no sessions found")
174}