1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "net/http"
7 "net/http/httptest"
8 "os"
9 "testing"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/message"
15 "github.com/openai/openai-go"
16 "github.com/openai/openai-go/option"
17)
18
19func TestMain(m *testing.M) {
20 _, err := config.Init(".", "", true, os.Environ())
21 if err != nil {
22 panic("Failed to initialize config: " + err.Error())
23 }
24
25 os.Exit(m.Run())
26}
27
28func TestOpenAIClientStreamChoices(t *testing.T) {
29 // Create a mock server that returns Server-Sent Events with empty choices
30 // This simulates the 🤡 behavior when a server returns 200 instead of 404
31 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32 w.Header().Set("Content-Type", "text/event-stream")
33 w.Header().Set("Cache-Control", "no-cache")
34 w.Header().Set("Connection", "keep-alive")
35 w.WriteHeader(http.StatusOK)
36
37 emptyChoicesChunk := map[string]any{
38 "id": "chat-completion-test",
39 "object": "chat.completion.chunk",
40 "created": time.Now().Unix(),
41 "model": "test-model",
42 "choices": []any{}, // Empty choices array that causes panic
43 }
44
45 jsonData, _ := json.Marshal(emptyChoicesChunk)
46 w.Write([]byte("data: " + string(jsonData) + "\n\n"))
47 w.Write([]byte("data: [DONE]\n\n"))
48 }))
49 defer server.Close()
50
51 // Create OpenAI client pointing to our mock server
52 client := &openaiClient{
53 providerOptions: providerClientOptions{
54 modelType: config.SelectedModelTypeLarge,
55 apiKey: "test-key",
56 systemMessage: "test",
57 model: func(config.SelectedModelType) catwalk.Model {
58 return catwalk.Model{
59 ID: "test-model",
60 Name: "test-model",
61 }
62 },
63 },
64 client: openai.NewClient(
65 option.WithAPIKey("test-key"),
66 option.WithBaseURL(server.URL),
67 ),
68 }
69
70 // Create test messages
71 messages := []message.Message{
72 {
73 Role: message.User,
74 Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
75 },
76 }
77
78 ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
79 defer cancel()
80
81 eventsChan := client.stream(ctx, messages, nil)
82
83 // Collect events - this will panic without the bounds check
84 for event := range eventsChan {
85 t.Logf("Received event: %+v", event)
86 if event.Type == EventError || event.Type == EventComplete {
87 break
88 }
89 }
90}