1package provider
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"net/http"
  7	"net/http/httptest"
  8	"os"
  9	"strings"
 10	"testing"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/message"
 16	"github.com/openai/openai-go"
 17	"github.com/openai/openai-go/option"
 18)
 19
 20func TestMain(m *testing.M) {
 21	_, err := config.Init(".", "", true)
 22	if err != nil {
 23		panic("Failed to initialize config: " + err.Error())
 24	}
 25
 26	os.Exit(m.Run())
 27}
 28
 29func TestOpenAIClientStreamChoices(t *testing.T) {
 30	// Create a mock server that returns Server-Sent Events with empty choices
 31	// This simulates the 🤡 behavior when a server returns 200 instead of 404
 32	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
 33		w.Header().Set("Content-Type", "text/event-stream")
 34		w.Header().Set("Cache-Control", "no-cache")
 35		w.Header().Set("Connection", "keep-alive")
 36		w.WriteHeader(http.StatusOK)
 37
 38		emptyChoicesChunk := map[string]any{
 39			"id":      "chat-completion-test",
 40			"object":  "chat.completion.chunk",
 41			"created": time.Now().Unix(),
 42			"model":   "test-model",
 43			"choices": []any{}, // Empty choices array that causes panic
 44		}
 45
 46		jsonData, _ := json.Marshal(emptyChoicesChunk)
 47		w.Write([]byte("data: " + string(jsonData) + "\n\n"))
 48		w.Write([]byte("data: [DONE]\n\n"))
 49	}))
 50	defer server.Close()
 51
 52	// Create OpenAI client pointing to our mock server
 53	client := &openaiClient{
 54		providerOptions: providerClientOptions{
 55			modelType:     config.SelectedModelTypeLarge,
 56			apiKey:        "test-key",
 57			systemMessage: "test",
 58			model: func(config.SelectedModelType) catwalk.Model {
 59				return catwalk.Model{
 60					ID:   "test-model",
 61					Name: "test-model",
 62				}
 63			},
 64		},
 65		client: openai.NewClient(
 66			option.WithAPIKey("test-key"),
 67			option.WithBaseURL(server.URL),
 68		),
 69	}
 70
 71	// Create test messages
 72	messages := []message.Message{
 73		{
 74			Role:  message.User,
 75			Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
 76		},
 77	}
 78
 79	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
 80	defer cancel()
 81
 82	eventsChan := client.stream(ctx, messages, nil)
 83
 84	// Collect events - this will panic without the bounds check
 85	for event := range eventsChan {
 86		t.Logf("Received event: %+v", event)
 87		if event.Type == EventError || event.Type == EventComplete {
 88			break
 89		}
 90	}
 91}
 92
 93func TestOpenAIClient429InsufficientQuotaError(t *testing.T) {
 94	client := &openaiClient{
 95		providerOptions: providerClientOptions{
 96			modelType:     config.SelectedModelTypeLarge,
 97			apiKey:        "test-key",
 98			systemMessage: "test",
 99			config: config.ProviderConfig{
100				ID:     "test-openai",
101				APIKey: "test-key",
102			},
103			model: func(config.SelectedModelType) catwalk.Model {
104				return catwalk.Model{
105					ID:   "test-model",
106					Name: "test-model",
107				}
108			},
109		},
110	}
111
112	// Test insufficient_quota error should not retry
113	apiErr := &openai.Error{
114		StatusCode: 429,
115		Message:    "You exceeded your current quota, please check your plan and billing details. For more information on this error, read the docs: https://platform.openai.com/docs/guides/error-codes/api-errors.",
116		Type:       "insufficient_quota",
117		Code:       "insufficient_quota",
118	}
119
120	retry, _, err := client.shouldRetry(1, apiErr)
121	if retry {
122		t.Error("Expected shouldRetry to return false for insufficient_quota error, but got true")
123	}
124	if err == nil {
125		t.Error("Expected shouldRetry to return an error for insufficient_quota, but got nil")
126	}
127	if err != nil && !strings.Contains(err.Error(), "quota") {
128		t.Errorf("Expected error message to mention quota, got: %v", err)
129	}
130}
131
132func TestOpenAIClient429RateLimitError(t *testing.T) {
133	client := &openaiClient{
134		providerOptions: providerClientOptions{
135			modelType:     config.SelectedModelTypeLarge,
136			apiKey:        "test-key",
137			systemMessage: "test",
138			config: config.ProviderConfig{
139				ID:     "test-openai",
140				APIKey: "test-key",
141			},
142			model: func(config.SelectedModelType) catwalk.Model {
143				return catwalk.Model{
144					ID:   "test-model",
145					Name: "test-model",
146				}
147			},
148		},
149	}
150
151	// Test regular rate limit error should retry
152	apiErr := &openai.Error{
153		StatusCode: 429,
154		Message:    "Rate limit reached for requests",
155		Type:       "rate_limit_exceeded",
156		Code:       "rate_limit_exceeded",
157	}
158
159	retry, _, err := client.shouldRetry(1, apiErr)
160	if !retry {
161		t.Error("Expected shouldRetry to return true for rate_limit_exceeded error, but got false")
162	}
163	if err != nil {
164		t.Errorf("Expected shouldRetry to return nil error for rate_limit_exceeded, but got: %v", err)
165	}
166}