openai_test.go

 1package provider
 2
 3import (
 4	"context"
 5	"encoding/json"
 6	"net/http"
 7	"net/http/httptest"
 8	"os"
 9	"testing"
10	"time"
11
12	"github.com/charmbracelet/catwalk/pkg/catwalk"
13	"github.com/charmbracelet/crush/internal/config"
14	"github.com/charmbracelet/crush/internal/message"
15	"github.com/openai/openai-go"
16	"github.com/openai/openai-go/option"
17)
18
19func TestMain(m *testing.M) {
20	_, err := config.Init(".", "", true, os.Environ())
21	if err != nil {
22		panic("Failed to initialize config: " + err.Error())
23	}
24
25	os.Exit(m.Run())
26}
27
28func TestOpenAIClientStreamChoices(t *testing.T) {
29	// Create a mock server that returns Server-Sent Events with empty choices
30	// This simulates the 🤡 behavior when a server returns 200 instead of 404
31	server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
32		w.Header().Set("Content-Type", "text/event-stream")
33		w.Header().Set("Cache-Control", "no-cache")
34		w.Header().Set("Connection", "keep-alive")
35		w.WriteHeader(http.StatusOK)
36
37		emptyChoicesChunk := map[string]any{
38			"id":      "chat-completion-test",
39			"object":  "chat.completion.chunk",
40			"created": time.Now().Unix(),
41			"model":   "test-model",
42			"choices": []any{}, // Empty choices array that causes panic
43		}
44
45		jsonData, _ := json.Marshal(emptyChoicesChunk)
46		w.Write([]byte("data: " + string(jsonData) + "\n\n"))
47		w.Write([]byte("data: [DONE]\n\n"))
48	}))
49	defer server.Close()
50
51	// Create OpenAI client pointing to our mock server
52	client := &openaiClient{
53		providerOptions: providerClientOptions{
54			modelType:     config.SelectedModelTypeLarge,
55			apiKey:        "test-key",
56			systemMessage: "test",
57			model: func(config.SelectedModelType) catwalk.Model {
58				return catwalk.Model{
59					ID:   "test-model",
60					Name: "test-model",
61				}
62			},
63		},
64		client: openai.NewClient(
65			option.WithAPIKey("test-key"),
66			option.WithBaseURL(server.URL),
67		),
68	}
69
70	// Create test messages
71	messages := []message.Message{
72		{
73			Role:  message.User,
74			Parts: []message.ContentPart{message.TextContent{Text: "Hello"}},
75		},
76	}
77
78	ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second)
79	defer cancel()
80
81	eventsChan := client.stream(ctx, messages, nil)
82
83	// Collect events - this will panic without the bounds check
84	for event := range eventsChan {
85		t.Logf("Received event: %+v", event)
86		if event.Type == EventError || event.Type == EventComplete {
87			break
88		}
89	}
90}