1package provider
2
3import (
4 "context"
5 "encoding/json"
6 "net/http"
7 "net/http/httptest"
8 "os"
9 "strings"
10 "testing"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/message"
16 "github.com/openai/openai-go"
17 "github.com/openai/openai-go/option"
18)
19
20func TestMain(m *testing.M) {
21 _, err := config.Init(".", "", true)
22 if err != nil {
23 panic("Failed to initialize config: " + err.Error())
24 }
25
26 os.Exit(m.Run())
27}
28
29func TestOpenAIClientStreamChoices(t *testing.T) {
30 // Create a mock server that returns Server-Sent Events with empty choices
31 // This simulates the 🤡 behavior when a server returns 200 instead of 404
32 server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
33 w.Header().Set("Content-Type", "text/event-stream")
34 w.Header().Set("Cache-Control", "no-cache")
35 w.Header().Set("Connection", "keep-alive")
36 w.WriteHeader(http.StatusOK)
37
38 emptyChoicesChunk := map[string]any{
39 "id": "chat-completion-test",
40 "object": "chat.completion.chunk",
41 "created": time.Now().Unix(),
42 "model": "test-model",
43 "choices": []any{}, // Empty choices array that causes panic
44 }
45
46 jsonData, _ := json.Marshal(emptyChoicesChunk)
47 w.Write([]byte("data: " + string(jsonData) + "\n\n"))
48 w.Write([]byte("data: [DONE]\n\n"))
49 }))
50 defer server.Close()
51
52 // Create OpenAI client pointing to our mock server
53 client := &openaiClient{
54 providerOptions: providerClientOptions{
55 modelType: config.SelectedModelTypeLarge,
56 apiKey: "test-key",
57 systemMessage: "test",
58 model: func(config.SelectedModelType) catwalk.Model {
59 return catwalk.Model{
60 ID: "test-model",
61 Name: "test-model",
62 }
63 },
64 },
65 client: openai.NewClient(
66 option.WithAPIKey("test-key"),
67 option.WithBaseURL(server.URL),
68 ),
69 }
70
71 // Create test messages
72 messages := []message.Message{
73 {
74 Role: message.User,
75 Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
76 },
77 }
78
79 ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
80 defer cancel()
81
82 eventsChan := client.stream(ctx, messages, nil)
83
84 // Collect events - this will panic without the bounds check
85 for event := range eventsChan {
86 t.Logf("Received event: %+v", event)
87 if event.Type == EventError || event.Type == EventComplete {
88 break
89 }
90 }
91}
92
93func TestOpenAIClient429InsufficientQuotaError(t *testing.T) {
94 client := &openaiClient{
95 providerOptions: providerClientOptions{
96 modelType: config.SelectedModelTypeLarge,
97 apiKey: "test-key",
98 systemMessage: "test",
99 config: config.ProviderConfig{
100 ID: "test-openai",
101 APIKey: "test-key",
102 },
103 model: func(config.SelectedModelType) catwalk.Model {
104 return catwalk.Model{
105 ID: "test-model",
106 Name: "test-model",
107 }
108 },
109 },
110 }
111
112 // Test insufficient_quota error should not retry
113 apiErr := &openai.Error{
114 StatusCode: 429,
115 Message: "You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.",
116 Type: "insufficient_quota",
117 Code: "insufficient_quota",
118 }
119
120 retry, _, err := client.shouldRetry(1, apiErr)
121 if retry {
122 t.Error("Expected shouldRetry to return false for insufficient_quota error, but got true")
123 }
124 if err == nil {
125 t.Error("Expected shouldRetry to return an error for insufficient_quota, but got nil")
126 }
127 if err != nil && !strings.Contains(err.Error(), "quota") {
128 t.Errorf("Expected error message to mention quota, got: %v", err)
129 }
130}
131
132func TestOpenAIClient429RateLimitError(t *testing.T) {
133 client := &openaiClient{
134 providerOptions: providerClientOptions{
135 modelType: config.SelectedModelTypeLarge,
136 apiKey: "test-key",
137 systemMessage: "test",
138 config: config.ProviderConfig{
139 ID: "test-openai",
140 APIKey: "test-key",
141 },
142 model: func(config.SelectedModelType) catwalk.Model {
143 return catwalk.Model{
144 ID: "test-model",
145 Name: "test-model",
146 }
147 },
148 },
149 }
150
151 // Test regular rate limit error should retry
152 apiErr := &openai.Error{
153 StatusCode: 429,
154 Message: "Rate limit reached for requests",
155 Type: "rate_limit_exceeded",
156 Code: "rate_limit_exceeded",
157 }
158
159 retry, _, err := client.shouldRetry(1, apiErr)
160 if !retry {
161 t.Error("Expected shouldRetry to return true for rate_limit_exceeded error, but got false")
162 }
163 if err != nil {
164 t.Errorf("Expected shouldRetry to return nil error for rate_limit_exceeded, but got: %v", err)
165 }
166}