1package server
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "net/http"
8 "net/http/httptest"
9 "os"
10 "strings"
11 "testing"
12 "time"
13
14 "shelley.exe.dev/claudetool"
15 "shelley.exe.dev/db"
16 "shelley.exe.dev/db/generated"
17 "shelley.exe.dev/llm"
18 "shelley.exe.dev/loop"
19)
20
// TestHarness provides a DSL-like interface for testing conversations.
// It wraps a Server wired to a scripted ("predictable") LLM service and a
// test database, and exposes helpers that drive the HTTP handlers directly
// and poll the database for results.
type TestHarness struct {
	t       *testing.T
	db      *db.DB                   // test database; Wait* helpers read messages back from here
	server  *Server                  // server under test, backed by the predictable LLM
	cleanup func()                   // tears down the test database; called by Close
	llm     *loop.PredictableService // scripted LLM service driving agent responses
	convID  string                   // current conversation ID; empty until NewConversation succeeds
	timeout time.Duration            // deadline applied by WaitToolResult / WaitResponse
	responsesCount int // Number of agent responses seen so far
}
32
33// NewTestHarness creates a new test harness with a predictable LLM and bash tool.
34func NewTestHarness(t *testing.T) *TestHarness {
35 t.Helper()
36
37 database, cleanup := setupTestDB(t)
38
39 predictableService := loop.NewPredictableService()
40 llmManager := &testLLMManager{service: predictableService}
41 logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
42
43 toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
44 server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
45
46 return &TestHarness{
47 t: t,
48 db: database,
49 server: server,
50 cleanup: cleanup,
51 llm: predictableService,
52 timeout: 5 * time.Second,
53 }
54}
55
56// Close cleans up the test harness resources.
57func (h *TestHarness) Close() {
58 h.cleanup()
59}
60
61// NewConversation starts a new conversation with the given message and options.
62func (h *TestHarness) NewConversation(msg, cwd string) *TestHarness {
63 h.t.Helper()
64
65 chatReq := ChatRequest{
66 Message: msg,
67 Model: "predictable",
68 Cwd: cwd,
69 }
70 chatBody, _ := json.Marshal(chatReq)
71
72 req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
73 req.Header.Set("Content-Type", "application/json")
74 w := httptest.NewRecorder()
75
76 h.server.handleNewConversation(w, req)
77 if w.Code != http.StatusCreated {
78 h.t.Fatalf("NewConversation: expected status 201, got %d: %s", w.Code, w.Body.String())
79 }
80
81 var resp struct {
82 ConversationID string `json:"conversation_id"`
83 }
84 if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
85 h.t.Fatalf("NewConversation: failed to parse response: %v", err)
86 }
87 h.convID = resp.ConversationID
88 h.responsesCount = 0 // Reset for new conversation
89 return h
90}
91
92// Chat sends a message to the current conversation.
93func (h *TestHarness) Chat(msg string) *TestHarness {
94 h.t.Helper()
95
96 if h.convID == "" {
97 h.t.Fatal("Chat: no conversation started, call NewConversation first")
98 }
99
100 chatReq := ChatRequest{
101 Message: msg,
102 Model: "predictable",
103 }
104 chatBody, _ := json.Marshal(chatReq)
105
106 req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/chat", strings.NewReader(string(chatBody)))
107 req.Header.Set("Content-Type", "application/json")
108 w := httptest.NewRecorder()
109
110 h.server.handleChatConversation(w, req, h.convID)
111 if w.Code != http.StatusAccepted {
112 h.t.Fatalf("Chat: expected status 202, got %d: %s", w.Code, w.Body.String())
113 }
114 return h
115}
116
117// WaitToolResult waits for a tool result and returns its text content.
118func (h *TestHarness) WaitToolResult() string {
119 h.t.Helper()
120
121 if h.convID == "" {
122 h.t.Fatal("WaitToolResult: no conversation started")
123 }
124
125 deadline := time.Now().Add(h.timeout)
126 for time.Now().Before(deadline) {
127 var messages []generated.Message
128 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
129 var qerr error
130 messages, qerr = q.ListMessages(context.Background(), h.convID)
131 return qerr
132 })
133 if err != nil {
134 h.t.Fatalf("WaitToolResult: failed to get messages: %v", err)
135 }
136
137 for _, msg := range messages {
138 if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
139 continue
140 }
141
142 var llmMsg llm.Message
143 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
144 continue
145 }
146
147 for _, content := range llmMsg.Content {
148 if content.Type == llm.ContentTypeToolResult {
149 for _, result := range content.ToolResult {
150 if result.Type == llm.ContentTypeText && result.Text != "" {
151 return result.Text
152 }
153 }
154 }
155 }
156 }
157
158 time.Sleep(100 * time.Millisecond)
159 }
160
161 h.t.Fatalf("WaitToolResult: timed out waiting for tool result")
162 return ""
163}
164
165// WaitResponse waits for the assistant's text response (end of turn).
166// It waits for a NEW response that hasn't been seen before.
167func (h *TestHarness) WaitResponse() string {
168 h.t.Helper()
169
170 if h.convID == "" {
171 h.t.Fatal("WaitResponse: no conversation started")
172 }
173
174 targetCount := h.responsesCount + 1
175
176 deadline := time.Now().Add(h.timeout)
177 for time.Now().Before(deadline) {
178 var messages []generated.Message
179 err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
180 var qerr error
181 messages, qerr = q.ListMessages(context.Background(), h.convID)
182 return qerr
183 })
184 if err != nil {
185 h.t.Fatalf("WaitResponse: failed to get messages: %v", err)
186 }
187
188 // Count assistant messages with end_of_turn
189 count := 0
190 var lastText string
191 for _, msg := range messages {
192 if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
193 continue
194 }
195
196 var llmMsg llm.Message
197 if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
198 continue
199 }
200
201 if llmMsg.EndOfTurn {
202 count++
203 for _, content := range llmMsg.Content {
204 if content.Type == llm.ContentTypeText {
205 lastText = content.Text
206 break
207 }
208 }
209 }
210 }
211
212 if count >= targetCount {
213 h.responsesCount = count
214 return lastText
215 }
216
217 time.Sleep(100 * time.Millisecond)
218 }
219
220 h.t.Fatalf("WaitResponse: timed out waiting for response (seen %d, need %d)", h.responsesCount, targetCount)
221 return ""
222}
223
224// ConversationID returns the current conversation ID.
225func (h *TestHarness) ConversationID() string {
226 return h.convID
227}
228
229// GetContextWindowSize retrieves the current context window size from the server.
230func (h *TestHarness) GetContextWindowSize() uint64 {
231 h.t.Helper()
232
233 if h.convID == "" {
234 h.t.Fatal("GetContextWindowSize: no conversation started")
235 }
236
237 // Use handleGetConversation (GET /conversation/<id>) instead of stream endpoint
238 req := httptest.NewRequest("GET", "/api/conversation/"+h.convID, nil)
239 w := httptest.NewRecorder()
240
241 h.server.handleGetConversation(w, req, h.convID)
242 if w.Code != http.StatusOK {
243 h.t.Fatalf("GetContextWindowSize: expected status 200, got %d: %s", w.Code, w.Body.String())
244 }
245
246 var resp StreamResponse
247 if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
248 h.t.Fatalf("GetContextWindowSize: failed to parse response: %v", err)
249 }
250
251 return resp.ContextWindowSize
252}