// testharness_test.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"log/slog"
  7	"net/http"
  8	"net/http/httptest"
  9	"os"
 10	"strings"
 11	"testing"
 12	"time"
 13
 14	"shelley.exe.dev/claudetool"
 15	"shelley.exe.dev/db"
 16	"shelley.exe.dev/db/generated"
 17	"shelley.exe.dev/llm"
 18	"shelley.exe.dev/loop"
 19)
 20
// TestHarness provides a DSL-like interface for testing conversations.
//
// Typical usage chains the helpers fluently:
//
//	h := NewTestHarness(t)
//	defer h.Close()
//	h.NewConversation("hello", "/tmp").WaitResponse()
type TestHarness struct {
	t              *testing.T
	db             *db.DB                   // test database; released via cleanup
	server         *Server                  // in-process server under test
	cleanup        func()                   // tears down the test database
	llm            *loop.PredictableService // scripted LLM backing the "predictable" model
	convID         string                   // current conversation ID; empty until NewConversation
	timeout        time.Duration            // max wait used by the Wait* helpers
	responsesCount int // Number of agent responses seen so far
}
 32
 33// NewTestHarness creates a new test harness with a predictable LLM and bash tool.
 34func NewTestHarness(t *testing.T) *TestHarness {
 35	t.Helper()
 36
 37	database, cleanup := setupTestDB(t)
 38
 39	predictableService := loop.NewPredictableService()
 40	llmManager := &testLLMManager{service: predictableService}
 41	logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelWarn}))
 42
 43	toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false}
 44	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
 45
 46	return &TestHarness{
 47		t:       t,
 48		db:      database,
 49		server:  server,
 50		cleanup: cleanup,
 51		llm:     predictableService,
 52		timeout: 5 * time.Second,
 53	}
 54}
 55
 56// Close cleans up the test harness resources.
 57func (h *TestHarness) Close() {
 58	h.cleanup()
 59}
 60
 61// NewConversation starts a new conversation with the given message and options.
 62func (h *TestHarness) NewConversation(msg, cwd string) *TestHarness {
 63	h.t.Helper()
 64
 65	chatReq := ChatRequest{
 66		Message: msg,
 67		Model:   "predictable",
 68		Cwd:     cwd,
 69	}
 70	chatBody, _ := json.Marshal(chatReq)
 71
 72	req := httptest.NewRequest("POST", "/api/conversations/new", strings.NewReader(string(chatBody)))
 73	req.Header.Set("Content-Type", "application/json")
 74	w := httptest.NewRecorder()
 75
 76	h.server.handleNewConversation(w, req)
 77	if w.Code != http.StatusCreated {
 78		h.t.Fatalf("NewConversation: expected status 201, got %d: %s", w.Code, w.Body.String())
 79	}
 80
 81	var resp struct {
 82		ConversationID string `json:"conversation_id"`
 83	}
 84	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
 85		h.t.Fatalf("NewConversation: failed to parse response: %v", err)
 86	}
 87	h.convID = resp.ConversationID
 88	h.responsesCount = 0 // Reset for new conversation
 89	return h
 90}
 91
 92// Chat sends a message to the current conversation.
 93func (h *TestHarness) Chat(msg string) *TestHarness {
 94	h.t.Helper()
 95
 96	if h.convID == "" {
 97		h.t.Fatal("Chat: no conversation started, call NewConversation first")
 98	}
 99
100	chatReq := ChatRequest{
101		Message: msg,
102		Model:   "predictable",
103	}
104	chatBody, _ := json.Marshal(chatReq)
105
106	req := httptest.NewRequest("POST", "/api/conversation/"+h.convID+"/chat", strings.NewReader(string(chatBody)))
107	req.Header.Set("Content-Type", "application/json")
108	w := httptest.NewRecorder()
109
110	h.server.handleChatConversation(w, req, h.convID)
111	if w.Code != http.StatusAccepted {
112		h.t.Fatalf("Chat: expected status 202, got %d: %s", w.Code, w.Body.String())
113	}
114	return h
115}
116
117// WaitToolResult waits for a tool result and returns its text content.
118func (h *TestHarness) WaitToolResult() string {
119	h.t.Helper()
120
121	if h.convID == "" {
122		h.t.Fatal("WaitToolResult: no conversation started")
123	}
124
125	deadline := time.Now().Add(h.timeout)
126	for time.Now().Before(deadline) {
127		var messages []generated.Message
128		err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
129			var qerr error
130			messages, qerr = q.ListMessages(context.Background(), h.convID)
131			return qerr
132		})
133		if err != nil {
134			h.t.Fatalf("WaitToolResult: failed to get messages: %v", err)
135		}
136
137		for _, msg := range messages {
138			if msg.Type != string(db.MessageTypeUser) || msg.LlmData == nil {
139				continue
140			}
141
142			var llmMsg llm.Message
143			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
144				continue
145			}
146
147			for _, content := range llmMsg.Content {
148				if content.Type == llm.ContentTypeToolResult {
149					for _, result := range content.ToolResult {
150						if result.Type == llm.ContentTypeText && result.Text != "" {
151							return result.Text
152						}
153					}
154				}
155			}
156		}
157
158		time.Sleep(100 * time.Millisecond)
159	}
160
161	h.t.Fatalf("WaitToolResult: timed out waiting for tool result")
162	return ""
163}
164
165// WaitResponse waits for the assistant's text response (end of turn).
166// It waits for a NEW response that hasn't been seen before.
167func (h *TestHarness) WaitResponse() string {
168	h.t.Helper()
169
170	if h.convID == "" {
171		h.t.Fatal("WaitResponse: no conversation started")
172	}
173
174	targetCount := h.responsesCount + 1
175
176	deadline := time.Now().Add(h.timeout)
177	for time.Now().Before(deadline) {
178		var messages []generated.Message
179		err := h.db.Queries(context.Background(), func(q *generated.Queries) error {
180			var qerr error
181			messages, qerr = q.ListMessages(context.Background(), h.convID)
182			return qerr
183		})
184		if err != nil {
185			h.t.Fatalf("WaitResponse: failed to get messages: %v", err)
186		}
187
188		// Count assistant messages with end_of_turn
189		count := 0
190		var lastText string
191		for _, msg := range messages {
192			if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
193				continue
194			}
195
196			var llmMsg llm.Message
197			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
198				continue
199			}
200
201			if llmMsg.EndOfTurn {
202				count++
203				for _, content := range llmMsg.Content {
204					if content.Type == llm.ContentTypeText {
205						lastText = content.Text
206						break
207					}
208				}
209			}
210		}
211
212		if count >= targetCount {
213			h.responsesCount = count
214			return lastText
215		}
216
217		time.Sleep(100 * time.Millisecond)
218	}
219
220	h.t.Fatalf("WaitResponse: timed out waiting for response (seen %d, need %d)", h.responsesCount, targetCount)
221	return ""
222}
223
224// ConversationID returns the current conversation ID.
225func (h *TestHarness) ConversationID() string {
226	return h.convID
227}
228
229// GetContextWindowSize retrieves the current context window size from the server.
230func (h *TestHarness) GetContextWindowSize() uint64 {
231	h.t.Helper()
232
233	if h.convID == "" {
234		h.t.Fatal("GetContextWindowSize: no conversation started")
235	}
236
237	// Use handleGetConversation (GET /conversation/<id>) instead of stream endpoint
238	req := httptest.NewRequest("GET", "/api/conversation/"+h.convID, nil)
239	w := httptest.NewRecorder()
240
241	h.server.handleGetConversation(w, req, h.convID)
242	if w.Code != http.StatusOK {
243		h.t.Fatalf("GetContextWindowSize: expected status 200, got %d: %s", w.Code, w.Body.String())
244	}
245
246	var resp StreamResponse
247	if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
248		h.t.Fatalf("GetContextWindowSize: failed to parse response: %v", err)
249	}
250
251	return resp.ContextWindowSize
252}