1package app
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/charmbracelet/crush/internal/pubsub"
12 "github.com/charmbracelet/crush/internal/session"
13 "github.com/stretchr/testify/require"
14)
15
16// mockSessionService is a minimal mock of session.Service for testing resolveSession.
17type mockSessionService struct {
18 sessions []session.Session
19 created []session.Session
20}
21
22func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
23 return make(chan pubsub.Event[session.Session])
24}
25
26func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
27 s := session.Session{ID: "new-session-id", Title: title}
28 m.created = append(m.created, s)
29 return s, nil
30}
31
32func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
33 return session.Session{}, nil
34}
35
36func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
37 return session.Session{}, nil
38}
39
40func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
41 for _, s := range m.sessions {
42 if s.ID == id {
43 return s, nil
44 }
45 }
46 return session.Session{}, sql.ErrNoRows
47}
48
49func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
50 if len(m.sessions) > 0 {
51 return m.sessions[0], nil
52 }
53 return session.Session{}, sql.ErrNoRows
54}
55
56func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
57 return m.sessions, nil
58}
59
60func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
61 return s, nil
62}
63
64func (m *mockSessionService) SaveWithModels(_ context.Context, s session.Session, _ map[config.SelectedModelType]config.SelectedModel) (session.Session, error) {
65 return s, nil
66}
67
68func (m *mockSessionService) UpdateSessionModels(context.Context, string, map[config.SelectedModelType]config.SelectedModel) error {
69 return nil
70}
71
72func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
73 return nil
74}
75
76func (m *mockSessionService) Rename(context.Context, string, string) error {
77 return nil
78}
79
80func (m *mockSessionService) Delete(context.Context, string) error {
81 return nil
82}
83
84func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
85 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
86}
87
88func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
89 parts := strings.Split(sessionID, "$$")
90 if len(parts) != 2 {
91 return "", "", false
92 }
93 return parts[0], parts[1], true
94}
95
96func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
97 _, _, ok := m.ParseAgentToolSessionID(sessionID)
98 return ok
99}
100
101func newTestApp(sessions session.Service) *App {
102 return &App{Sessions: sessions}
103}
104
105func TestResolveSession_NewSession(t *testing.T) {
106 mock := &mockSessionService{}
107 app := newTestApp(mock)
108
109 sess, err := app.resolveSession(t.Context(), "", false)
110 require.NoError(t, err)
111 require.Equal(t, "new-session-id", sess.ID)
112 require.Len(t, mock.created, 1)
113}
114
115func TestResolveSession_ContinueByID(t *testing.T) {
116 mock := &mockSessionService{
117 sessions: []session.Session{
118 {ID: "existing-id", Title: "Old session"},
119 },
120 }
121 app := newTestApp(mock)
122
123 sess, err := app.resolveSession(t.Context(), "existing-id", false)
124 require.NoError(t, err)
125 require.Equal(t, "existing-id", sess.ID)
126 require.Equal(t, "Old session", sess.Title)
127 require.Empty(t, mock.created)
128}
129
130func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
131 mock := &mockSessionService{}
132 app := newTestApp(mock)
133
134 _, err := app.resolveSession(t.Context(), "nonexistent", false)
135 require.Error(t, err)
136 require.Contains(t, err.Error(), "session not found")
137}
138
139func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
140 mock := &mockSessionService{
141 sessions: []session.Session{
142 {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
143 },
144 }
145 app := newTestApp(mock)
146
147 _, err := app.resolveSession(t.Context(), "child-id", false)
148 require.Error(t, err)
149 require.Contains(t, err.Error(), "cannot continue a child session")
150}
151
152func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
153 mock := &mockSessionService{}
154 app := newTestApp(mock)
155
156 _, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
157 require.Error(t, err)
158 require.Contains(t, err.Error(), "cannot continue an agent tool session")
159}
160
161func TestResolveSession_Last(t *testing.T) {
162 mock := &mockSessionService{
163 sessions: []session.Session{
164 {ID: "most-recent", Title: "Latest session"},
165 {ID: "older", Title: "Older session"},
166 },
167 }
168 app := newTestApp(mock)
169
170 sess, err := app.resolveSession(t.Context(), "", true)
171 require.NoError(t, err)
172 require.Equal(t, "most-recent", sess.ID)
173 require.Empty(t, mock.created)
174}
175
176func TestResolveSession_Last_NoSessions(t *testing.T) {
177 mock := &mockSessionService{}
178 app := newTestApp(mock)
179
180 _, err := app.resolveSession(t.Context(), "", true)
181 require.Error(t, err)
182 require.Contains(t, err.Error(), "no sessions found")
183}