diff --git a/server/cancel_test.go b/server/cancel_test.go
index 4c660bdf98ee062e0040cbe0c4a45105d488d49f..d25f03f4bec33f98aa7c30d87319679f9c75af00 100644
--- a/server/cancel_test.go
+++ b/server/cancel_test.go
@@ -39,6 +39,19 @@ func setupTestDB(t *testing.T) (*db.DB, func()) {
 	}
 }
 
+// waitFor polls a condition until it returns true or the timeout is reached.
+func waitFor(t *testing.T, timeout time.Duration, condition func() bool) {
+	t.Helper()
+	deadline := time.Now().Add(timeout)
+	for time.Now().Before(deadline) {
+		if condition() {
+			return
+		}
+		time.Sleep(10 * time.Millisecond)
+	}
+	t.Fatal("timed out waiting for condition")
+}
+
 // TestCancelWithPredictableModel tests cancellation with the predictable model
 func TestCancelWithPredictableModel(t *testing.T) {
 	// Create test database
@@ -77,24 +90,34 @@ func TestCancelWithPredictableModel(t *testing.T) {
 		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
 	}
 
-	// Wait for the tool to start executing
-	time.Sleep(300 * time.Millisecond)
-
-	// Verify agent is working
-	var messages []generated.Message
-	err = database.Queries(context.Background(), func(q *generated.Queries) error {
-		var qerr error
-		messages, qerr = q.ListMessages(context.Background(), conversationID)
-		return qerr
+	// Wait for agent to record an assistant message with tool use
+	waitFor(t, 5*time.Second, func() bool {
+		var messages []generated.Message
+		err := database.Queries(context.Background(), func(q *generated.Queries) error {
+			var qerr error
+			messages, qerr = q.ListMessages(context.Background(), conversationID)
+			return qerr
+		})
+		if err != nil || len(messages) < 2 {
+			return false
+		}
+		// Check for assistant message with tool use
+		for _, msg := range messages {
+			if msg.Type != string(db.MessageTypeAgent) || msg.LlmData == nil {
+				continue
+			}
+			var llmMsg llm.Message
+			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err != nil {
+				continue
+			}
+			for _, content := range llmMsg.Content {
+				if content.Type == llm.ContentTypeToolUse {
+					return true
+				}
+			}
+		}
+		return false
 	})
-	if err != nil {
-		t.Fatalf("failed to get messages: %v", err)
-	}
-
-	// Should have user message and assistant message with tool use
-	if len(messages) < 2 {
-		t.Fatalf("expected at least 2 messages, got %d", len(messages))
-	}
 
 	// Cancel the conversation
 	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
@@ -115,10 +138,13 @@ func TestCancelWithPredictableModel(t *testing.T) {
 		t.Errorf("expected status 'cancelled', got '%s'", cancelResp["status"])
 	}
 
-	// Wait for cancellation to complete and cancelled message to be recorded
-	time.Sleep(300 * time.Millisecond)
+	// Wait for agent to stop working (cancellation complete)
+	waitFor(t, 5*time.Second, func() bool {
+		return !server.IsAgentWorking(conversationID)
+	})
 
 	// Verify that a cancelled tool result was recorded
+	var messages []generated.Message
 	err = database.Queries(context.Background(), func(q *generated.Queries) error {
 		var qerr error
 		messages, qerr = q.ListMessages(context.Background(), conversationID)
@@ -195,8 +221,10 @@ func TestCancelWithPredictableModel(t *testing.T) {
 		t.Fatalf("expected status 202 for resume, got %d: %s", resumeW.Code, resumeW.Body.String())
 	}
 
-	// Wait for the response
-	time.Sleep(300 * time.Millisecond)
+	// Wait for agent to finish processing the resumed conversation
+	waitFor(t, 5*time.Second, func() bool {
+		return !server.IsAgentWorking(conversationID)
+	})
 
 	// Verify conversation continued
 	err = database.Queries(context.Background(), func(q *generated.Queries) error {
@@ -312,8 +340,10 @@ func TestCancelDuringTextGeneration(t *testing.T) {
 		t.Fatalf("expected status 202, got %d: %s", w.Code, w.Body.String())
 	}
 
-	// Wait briefly for processing to start
-	time.Sleep(100 * time.Millisecond)
+	// Wait for agent to start working
+	waitFor(t, 5*time.Second, func() bool {
+		return server.IsAgentWorking(conversationID)
+	})
 
 	// Cancel during text generation
 	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
@@ -325,8 +355,10 @@ func TestCancelDuringTextGeneration(t *testing.T) {
 		t.Fatalf("expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
 	}
 
-	// Wait for cancellation
-	time.Sleep(200 * time.Millisecond)
+	// Wait for agent to stop working (cancellation complete)
+	waitFor(t, 5*time.Second, func() bool {
+		return !server.IsAgentWorking(conversationID)
+	})
 
 	// Verify that no cancelled tool result was added (since there was no tool call)
 	var messages []generated.Message
diff --git a/server/duplicate_tool_result_test.go b/server/duplicate_tool_result_test.go
index 6b4556dc02ffc206e8730ea984018de6b8c03c77..83c449e3b42586c004c15bc74d510c4f2db1db24 100644
--- a/server/duplicate_tool_result_test.go
+++ b/server/duplicate_tool_result_test.go
@@ -108,9 +108,6 @@ func TestCancelAfterToolCompletesCreatesDuplicateToolResult(t *testing.T) {
 		t.Fatal("tool result was not found - tool didn't complete")
 	}
 
-	// Give a tiny bit more time for the loop to stabilize
-	time.Sleep(100 * time.Millisecond)
-
 	// Now cancel the conversation AFTER the tool has completed
 	// This should NOT create a new tool_result because the tool already finished
 	cancelReq := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/cancel", nil)
@@ -121,8 +118,10 @@ func TestCancelAfterToolCompletesCreatesDuplicateToolResult(t *testing.T) {
 		t.Fatalf("cancel: expected status 200, got %d: %s", cancelW.Code, cancelW.Body.String())
 	}
 
-	// Wait for cancel to process
-	time.Sleep(200 * time.Millisecond)
+	// Wait for agent to stop working (cancel to process); fail the test on timeout
+	waitFor(t, 5*time.Second, func() bool {
+		return !server.IsAgentWorking(conversationID)
+	})
 
 	// Check the messages to see if there are duplicate tool_results for the same tool_use_id
 	var messages []generated.Message
@@ -179,8 +178,10 @@ func TestCancelAfterToolCompletesCreatesDuplicateToolResult(t *testing.T) {
 		t.Fatalf("resume: expected status 202, got %d: %s", resumeW.Code, resumeW.Body.String())
 	}
 
-	// Wait for the request to be processed
-	time.Sleep(300 * time.Millisecond)
+	// Wait for agent to stop working; fail the test on timeout
+	waitFor(t, 5*time.Second, func() bool {
+		return !server.IsAgentWorking(conversationID)
+	})
 
 	// Check the last request sent to the LLM for duplicate tool_results
 	lastRequest := predictableService.GetLastRequest()
diff --git a/server/server.go b/server/server.go
index df26fa0805e7cc9771614c1b85743c91ffd7a7c1..4a27754bba6b9a24b82f827d6db6484afaa3b97f 100644
--- a/server/server.go
+++ b/server/server.go
@@ -857,6 +857,18 @@ func (s *Server) getWorkingConversations() map[string]bool {
 	return working
 }
 
+// IsAgentWorking returns whether the agent is currently working on the given conversation.
+// Returns false if the conversation doesn't have an active manager.
+func (s *Server) IsAgentWorking(conversationID string) bool {
+	s.mu.Lock()
+	manager, exists := s.activeConversations[conversationID]
+	s.mu.Unlock()
+	if !exists {
+		return false
+	}
+	return manager.IsAgentWorking()
+}
+
 // Cleanup removes inactive conversation managers
 func (s *Server) Cleanup() {
 	s.mu.Lock()