1package app
2
3import (
4 "context"
5 "database/sql"
6 "fmt"
7 "strings"
8 "testing"
9
10 "github.com/charmbracelet/crush/internal/pubsub"
11 "github.com/charmbracelet/crush/internal/session"
12 "github.com/stretchr/testify/require"
13)
14
15// mockSessionService is a minimal mock of session.Service for testing resolveSession.
16type mockSessionService struct {
17 sessions []session.Session
18 created []session.Session
19}
20
21func (m *mockSessionService) Subscribe(context.Context) <-chan pubsub.Event[session.Session] {
22 return make(chan pubsub.Event[session.Session])
23}
24
25func (m *mockSessionService) Create(_ context.Context, title string) (session.Session, error) {
26 s := session.Session{ID: "new-session-id", Title: title}
27 m.created = append(m.created, s)
28 return s, nil
29}
30
31func (m *mockSessionService) CreateTitleSession(context.Context, string) (session.Session, error) {
32 return session.Session{}, nil
33}
34
35func (m *mockSessionService) CreateTaskSession(context.Context, string, string, string) (session.Session, error) {
36 return session.Session{}, nil
37}
38
39func (m *mockSessionService) Get(_ context.Context, id string) (session.Session, error) {
40 for _, s := range m.sessions {
41 if s.ID == id {
42 return s, nil
43 }
44 }
45 return session.Session{}, sql.ErrNoRows
46}
47
48func (m *mockSessionService) GetLast(_ context.Context) (session.Session, error) {
49 if len(m.sessions) > 0 {
50 return m.sessions[0], nil
51 }
52 return session.Session{}, sql.ErrNoRows
53}
54
55func (m *mockSessionService) List(context.Context) ([]session.Session, error) {
56 return m.sessions, nil
57}
58
59func (m *mockSessionService) Save(_ context.Context, s session.Session) (session.Session, error) {
60 return s, nil
61}
62
63func (m *mockSessionService) UpdateTitleAndUsage(context.Context, string, string, int64, int64, float64) error {
64 return nil
65}
66
67func (m *mockSessionService) Rename(context.Context, string, string) error {
68 return nil
69}
70
71func (m *mockSessionService) Delete(context.Context, string) error {
72 return nil
73}
74
75func (m *mockSessionService) CreateAgentToolSessionID(messageID, toolCallID string) string {
76 return fmt.Sprintf("%s$$%s", messageID, toolCallID)
77}
78
79func (m *mockSessionService) ParseAgentToolSessionID(sessionID string) (string, string, bool) {
80 parts := strings.Split(sessionID, "$$")
81 if len(parts) != 2 {
82 return "", "", false
83 }
84 return parts[0], parts[1], true
85}
86
87func (m *mockSessionService) IsAgentToolSession(sessionID string) bool {
88 _, _, ok := m.ParseAgentToolSessionID(sessionID)
89 return ok
90}
91
92func newTestApp(sessions session.Service) *App {
93 return &App{Sessions: sessions}
94}
95
96func TestResolveSession_NewSession(t *testing.T) {
97 mock := &mockSessionService{}
98 app := newTestApp(mock)
99
100 sess, err := app.resolveSession(t.Context(), "", false)
101 require.NoError(t, err)
102 require.Equal(t, "new-session-id", sess.ID)
103 require.Len(t, mock.created, 1)
104}
105
106func TestResolveSession_ContinueByID(t *testing.T) {
107 mock := &mockSessionService{
108 sessions: []session.Session{
109 {ID: "existing-id", Title: "Old session"},
110 },
111 }
112 app := newTestApp(mock)
113
114 sess, err := app.resolveSession(t.Context(), "existing-id", false)
115 require.NoError(t, err)
116 require.Equal(t, "existing-id", sess.ID)
117 require.Equal(t, "Old session", sess.Title)
118 require.Empty(t, mock.created)
119}
120
121func TestResolveSession_ContinueByID_NotFound(t *testing.T) {
122 mock := &mockSessionService{}
123 app := newTestApp(mock)
124
125 _, err := app.resolveSession(t.Context(), "nonexistent", false)
126 require.Error(t, err)
127 require.Contains(t, err.Error(), "session not found")
128}
129
130func TestResolveSession_ContinueByID_ChildSession(t *testing.T) {
131 mock := &mockSessionService{
132 sessions: []session.Session{
133 {ID: "child-id", ParentSessionID: "parent-id", Title: "Child session"},
134 },
135 }
136 app := newTestApp(mock)
137
138 _, err := app.resolveSession(t.Context(), "child-id", false)
139 require.Error(t, err)
140 require.Contains(t, err.Error(), "cannot continue a child session")
141}
142
143func TestResolveSession_ContinueByID_AgentToolSession(t *testing.T) {
144 mock := &mockSessionService{}
145 app := newTestApp(mock)
146
147 _, err := app.resolveSession(t.Context(), "msg123$$tool456", false)
148 require.Error(t, err)
149 require.Contains(t, err.Error(), "cannot continue an agent tool session")
150}
151
152func TestResolveSession_Last(t *testing.T) {
153 mock := &mockSessionService{
154 sessions: []session.Session{
155 {ID: "most-recent", Title: "Latest session"},
156 {ID: "older", Title: "Older session"},
157 },
158 }
159 app := newTestApp(mock)
160
161 sess, err := app.resolveSession(t.Context(), "", true)
162 require.NoError(t, err)
163 require.Equal(t, "most-recent", sess.ID)
164 require.Empty(t, mock.created)
165}
166
167func TestResolveSession_Last_NoSessions(t *testing.T) {
168 mock := &mockSessionService{}
169 app := newTestApp(mock)
170
171 _, err := app.resolveSession(t.Context(), "", true)
172 require.Error(t, err)
173 require.Contains(t, err.Error(), "no sessions found")
174}