package providers

import (
	"context"
	"encoding/base64"
	"encoding/json"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"

	"github.com/charmbracelet/crush/internal/ai"
	"github.com/openai/openai-go/v2/packages/param"
	"github.com/stretchr/testify/require"
)

func TestToOpenAIPrompt_SystemMessages(t *testing.T) {
	t.Parallel()

	t.Run("should forward system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.", systemMsg.Content.OfString.Value)
	})

	t.Run("should handle empty system messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role:    ai.MessageRoleSystem,
				Content: []ai.MessagePart{},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "system prompt has no text parts")
		require.Empty(t, messages)
	})

	t.Run("should join multiple system text parts", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleSystem,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "You are a helpful assistant."},
					ai.TextPart{Text: "Be concise."},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		systemMsg := messages[0].OfSystem
		require.NotNil(t, systemMsg)
		require.Equal(t, "You are a helpful assistant.\nBe concise.", systemMsg.Content.OfString.Value)
	})
}

func TestToOpenAIPrompt_UserMessages(t *testing.T) {
	t.Parallel()

	t.Run("should convert messages with only a text part to a string content", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)
		require.Equal(t, "Hello", userMsg.Content.OfString.Value)
	})

	t.Run("should convert messages with image parts", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello"},
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 2)

		// Check text part
		textPart := content[0].OfText
		require.NotNil(t, textPart)
		require.Equal(t, "Hello", textPart.Text)

		// Check image part
		imagePart := content[1].OfImageURL
		require.NotNil(t, imagePart)
		expectedURL := "data:image/png;base64," + base64.StdEncoding.EncodeToString(imageData)
		require.Equal(t, expectedURL, imagePart.ImageURL.URL)
	})

	t.Run("should add image detail when specified through provider options", func(t *testing.T) {
		t.Parallel()

		imageData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "image/png",
						Data:      imageData,
						ProviderOptions: ai.ProviderOptions{
							"openai": map[string]any{
								"imageDetail": "low",
							},
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		imagePart := content[0].OfImageURL
		require.NotNil(t, imagePart)
		require.Equal(t, "low", imagePart.ImageURL.Detail)
	})
}

func TestToOpenAIPrompt_FileParts(t *testing.T) {
	t.Parallel()

	t.Run("should warn for unsupported mime types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/something",
						Data:      []byte("test"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Len(t, warnings, 1)
		require.Contains(t, warnings[0].Message, "file part media type application/something not supported")
		require.Len(t, messages, 1) // Message is still created but with empty content array
	})

	t.Run("should add audio content for audio/wav file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/wav",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		require.NotNil(t, userMsg)

		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, base64.StdEncoding.EncodeToString(audioData), audioPart.InputAudio.Data)
		require.Equal(t, "wav", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mpeg file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mpeg",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should add audio content for audio/mp3 file parts", func(t *testing.T) {
		t.Parallel()

		audioData := []byte{0, 1, 2, 3}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "audio/mp3",
						Data:      audioData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		audioPart := content[0].OfInputAudio
		require.NotNil(t, audioPart)
		require.Equal(t, "mp3", audioPart.InputAudio.Format)
	})

	t.Run("should convert messages with PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		require.Len(t, content, 1)

		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "document.pdf", filePart.File.Filename.Value)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with binary PDF file parts", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
						Filename:  "document.pdf",
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)

		expectedData := "data:application/pdf;base64," + base64.StdEncoding.EncodeToString(pdfData)
		require.Equal(t, expectedData, filePart.File.FileData.Value)
	})

	t.Run("should convert messages with PDF file parts using file_id", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      []byte("file-pdf-12345"),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "file-pdf-12345", filePart.File.FileID.Value)
		require.True(t, param.IsOmitted(filePart.File.FileData))
		require.True(t, param.IsOmitted(filePart.File.Filename))
	})

	t.Run("should use default filename for PDF file parts when not provided", func(t *testing.T) {
		t.Parallel()

		pdfData := []byte{1, 2, 3, 4, 5}
		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleUser,
				Content: []ai.MessagePart{
					ai.FilePart{
						MediaType: "application/pdf",
						Data:      pdfData,
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		userMsg := messages[0].OfUser
		content := userMsg.Content.OfArrayOfContentParts
		filePart := content[0].OfFile
		require.NotNil(t, filePart)
		require.Equal(t, "part-0.pdf", filePart.File.Filename.Value)
	})
}

func TestToOpenAIPrompt_ToolCalls(t *testing.T) {
	t.Parallel()

	t.Run("should stringify arguments to tool calls", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"foo": "bar123"}
		inputJSON, _ := json.Marshal(inputArgs)

		outputResult := map[string]any{"oof": "321rab"}
		outputJSON, _ := json.Marshal(outputResult)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.ToolCallPart{
						ToolCallID: "quux",
						ToolName:   "thwomp",
						Input:      string(inputJSON),
					},
				},
			},
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "quux",
						Output: ai.ToolResultOutputContentText{
							Text: string(outputJSON),
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check assistant message with tool call
		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.NotNil(t, toolCall)
		require.Equal(t, "quux", toolCall.ID)
		require.Equal(t, "thwomp", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)

		// Check tool message
		toolMsg := messages[1].OfTool
		require.NotNil(t, toolMsg)
		require.Equal(t, string(outputJSON), toolMsg.Content.OfString.Value)
		require.Equal(t, "quux", toolMsg.ToolCallID)
	})

	t.Run("should handle different tool output types", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleTool,
				Content: []ai.MessagePart{
					ai.ToolResultPart{
						ToolCallID: "text-tool",
						Output: ai.ToolResultOutputContentText{
							Text: "Hello world",
						},
					},
					ai.ToolResultPart{
						ToolCallID: "error-tool",
						Output: ai.ToolResultOutputContentError{
							Error: "Something went wrong",
						},
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 2)

		// Check first tool message (text)
		textToolMsg := messages[0].OfTool
		require.NotNil(t, textToolMsg)
		require.Equal(t, "Hello world", textToolMsg.Content.OfString.Value)
		require.Equal(t, "text-tool", textToolMsg.ToolCallID)

		// Check second tool message (error)
		errorToolMsg := messages[1].OfTool
		require.NotNil(t, errorToolMsg)
		require.Equal(t, "Something went wrong", errorToolMsg.Content.OfString.Value)
		require.Equal(t, "error-tool", errorToolMsg.ToolCallID)
	})
}

func TestToOpenAIPrompt_AssistantMessages(t *testing.T) {
	t.Parallel()

	t.Run("should handle simple text assistant messages", func(t *testing.T) {
		t.Parallel()

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Hello, how can I help you?"},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Hello, how can I help you?", assistantMsg.Content.OfString.Value)
	})

	t.Run("should handle assistant messages with mixed content", func(t *testing.T) {
		t.Parallel()

		inputArgs := map[string]any{"query": "test"}
		inputJSON, _ := json.Marshal(inputArgs)

		prompt := ai.Prompt{
			{
				Role: ai.MessageRoleAssistant,
				Content: []ai.MessagePart{
					ai.TextPart{Text: "Let me search for that."},
					ai.ToolCallPart{
						ToolCallID: "call-123",
						ToolName:   "search",
						Input:      string(inputJSON),
					},
				},
			},
		}

		messages, warnings := toOpenAIPrompt(prompt)

		require.Empty(t, warnings)
		require.Len(t, messages, 1)

		assistantMsg := messages[0].OfAssistant
		require.NotNil(t, assistantMsg)
		require.Equal(t, "Let me search for that.", assistantMsg.Content.OfString.Value)
		require.Len(t, assistantMsg.ToolCalls, 1)

		toolCall := assistantMsg.ToolCalls[0].OfFunction
		require.Equal(t, "call-123", toolCall.ID)
		require.Equal(t, "search", toolCall.Function.Name)
		require.Equal(t, string(inputJSON), toolCall.Function.Arguments)
	})
}

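// testPrompt is the minimal single-turn user prompt shared by the Generate tests below.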
var testPrompt = ai.Prompt{
	{
		Role: ai.MessageRoleUser,
		Content: []ai.MessagePart{
			ai.TextPart{Text: "Hello"},
		},
	},
}

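// testLogprobs mirrors the logprobs shape of a Chat Completions response:
// one entry per generated token with its log probability and top_logprobs candidates.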
var testLogprobs = map[string]any{
	"content": []map[string]any{
		{
			"token":   "Hello",
			"logprob": -0.0009994634,
			"top_logprobs": []map[string]any{
				{
					"token":   "Hello",
					"logprob": -0.0009994634,
				},
			},
		},
		{
			"token":   "!",
			"logprob": -0.13410144,
			"top_logprobs": []map[string]any{
				{
					"token":   "!",
					"logprob": -0.13410144,
				},
			},
		},
		{
			"token":   " How",
			"logprob": -0.0009250381,
			"top_logprobs": []map[string]any{
				{
					"token":   " How",
					"logprob": -0.0009250381,
				},
			},
		},
		{
			"token":   " can",
			"logprob": -0.047709424,
			"top_logprobs": []map[string]any{
				{
					"token":   " can",
					"logprob": -0.047709424,
				},
			},
		},
		{
			"token":   " I",
			"logprob": -0.000009014684,
			"top_logprobs": []map[string]any{
				{
					"token":   " I",
					"logprob": -0.000009014684,
				},
			},
		},
		{
			"token":   " assist",
			"logprob": -0.009125131,
			"top_logprobs": []map[string]any{
				{
					"token":   " assist",
					"logprob": -0.009125131,
				},
			},
		},
		{
			"token":   " you",
			"logprob": -0.0000066306106,
			"top_logprobs": []map[string]any{
				{
					"token":   " you",
					"logprob": -0.0000066306106,
				},
			},
		},
		{
			"token":   " today",
			"logprob": -0.00011093382,
			"top_logprobs": []map[string]any{
				{
					"token":   " today",
					"logprob": -0.00011093382,
				},
			},
		},
		{
			"token":   "?",
			"logprob": -0.00004596782,
			"top_logprobs": []map[string]any{
				{
					"token":   "?",
					"logprob": -0.00004596782,
				},
			},
		},
	},
}

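// mockServer wraps an httptest.Server that records every incoming request
// and replies with a canned JSON response set via prepareJSONResponse.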
type mockServer struct {
	server   *httptest.Server
	response map[string]any
	calls    []mockCall
}

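// mockCall is one recorded request: method, path, first header values, and decoded JSON body.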
type mockCall struct {
	method  string
	path    string
	headers map[string]string
	body    map[string]any
}

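// newMockServer starts the recording test server; callers must release it with close().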
func newMockServer() *mockServer {
	ms := &mockServer{
		calls: make([]mockCall, 0),
	}

	ms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		ms.calls = append(ms.calls, call)

		// Return mock response
		w.Header().Set("Content-Type", "application/json")
		json.NewEncoder(w).Encode(ms.response)
	}))

	return ms
}

func (ms *mockServer) close() {
	ms.server.Close()
}

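// prepareJSONResponse installs a default chat.completion body and applies the
// given overrides (content, tool_calls, usage, finish_reason, logprobs, ...).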
func (ms *mockServer) prepareJSONResponse(opts map[string]any) {
	// Default values
	response := map[string]any{
		"id":      "chatcmpl-95ZTZkhr0mHNKqerQfiwkuox3PHAd",
		"object":  "chat.completion",
		"created": 1711115037,
		"model":   "gpt-3.5-turbo-0125",
		"choices": []map[string]any{
			{
				"index": 0,
				"message": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": "stop",
			},
		},
		"usage": map[string]any{
			"prompt_tokens":     4,
			"total_tokens":      34,
			"completion_tokens": 30,
		},
		"system_fingerprint": "fp_3bc1b5746c",
	}

	// Override with provided options
	for k, v := range opts {
		switch k {
		case "content":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["content"] = v
		case "tool_calls":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["tool_calls"] = v
		case "function_call":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["function_call"] = v
		case "annotations":
			response["choices"].([]map[string]any)[0]["message"].(map[string]any)["annotations"] = v
		case "usage":
			response["usage"] = v
		case "finish_reason":
			response["choices"].([]map[string]any)[0]["finish_reason"] = v
		case "id":
			response["id"] = v
		case "created":
			response["created"] = v
		case "model":
			response["model"] = v
		case "logprobs":
			if v != nil {
				response["choices"].([]map[string]any)[0]["logprobs"] = v
			}
		}
	}

	ms.response = response
}

func TestDoGenerate(t *testing.T) {
	t.Parallel()

	t.Run("should extract text response", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Hello, World!",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Hello, World!", textContent.Text)
	})

	t.Run("should extract usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     20,
				"total_tokens":      25,
				"completion_tokens": 5,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(5), result.Usage.OutputTokens)
		require.Equal(t, int64(25), result.Usage.TotalTokens)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages, ok := call.body["messages"].([]any)
		require.True(t, ok)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should support partial usage", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens": 20,
				"total_tokens":  20,
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(20), result.Usage.InputTokens)
		require.Equal(t, int64(0), result.Usage.OutputTokens)
		require.Equal(t, int64(20), result.Usage.TotalTokens)
	})

	t.Run("should extract logprobs", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logProbs": true,
				},
			},
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)

		logprobs, ok := openaiMeta["logprobs"]
		require.True(t, ok)
		require.NotNil(t, logprobs)
	})

	t.Run("should extract finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "stop",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonStop, result.FinishReason)
	})

	t.Run("should support unknown finish reason", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"finish_reason": "eos",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, ai.FinishReasonUnknown, result.FinishReason)
	})

	t.Run("should pass the model and the messages", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass settings", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"logitBias": map[string]int64{
						"50256": -100,
					},
					"parallelToolCalls": false,
					"user":              "test-user-id",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		logitBias := call.body["logit_bias"].(map[string]any)
		require.Equal(t, float64(-100), logitBias["50256"])
		require.Equal(t, false, call.body["parallel_tool_calls"])
		require.Equal(t, "test-user-id", call.body["user"])
	})

	t.Run("should pass reasoningEffort setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"reasoningEffort": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-mini", call.body["model"])
		require.Equal(t, "low", call.body["reasoning_effort"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass textVerbosity setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"textVerbosity": "low",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o", call.body["model"])
		require.Equal(t, "low", call.body["verbosity"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should pass tools and toolChoice", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		tools := call.body["tools"].([]any)
		require.Len(t, tools, 1)

		tool := tools[0].(map[string]any)
		require.Equal(t, "function", tool["type"])

		function := tool["function"].(map[string]any)
		require.Equal(t, "test-tool", function["name"])
		require.Equal(t, false, function["strict"])

		toolChoice := call.body["tool_choice"].(map[string]any)
		require.Equal(t, "function", toolChoice["type"])

		toolChoiceFunction := toolChoice["function"].(map[string]any)
		require.Equal(t, "test-tool", toolChoiceFunction["name"])
	})

	t.Run("should parse tool results", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"tool_calls": []map[string]any{
				{
					"id":   "call_O17Uplv4lJvD6DVdIvFFeRMw",
					"type": "function",
					"function": map[string]any{
						"name":      "test-tool",
						"arguments": `{"value":"Spark"}`,
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
			ToolChoice: &[]ai.ToolChoice{ai.ToolChoice("test-tool")}[0],
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 1)

		toolCall, ok := result.Content[0].(ai.ToolCallContent)
		require.True(t, ok)
		require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", toolCall.ToolCallID)
		require.Equal(t, "test-tool", toolCall.ToolName)
		require.Equal(t, `{"value":"Spark"}`, toolCall.Input)
	})

	t.Run("should parse annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "Based on the search results [doc1], I found information.",
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, result.Content, 2)

		textContent, ok := result.Content[0].(ai.TextContent)
		require.True(t, ok)
		require.Equal(t, "Based on the search results [doc1], I found information.", textContent.Text)

		sourceContent, ok := result.Content[1].(ai.SourceContent)
		require.True(t, ok)
		require.Equal(t, ai.SourceTypeURL, sourceContent.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourceContent.URL)
		require.Equal(t, "Document 1", sourceContent.Title)
		require.NotEmpty(t, sourceContent.ID)
	})

	t.Run("should return cached_tokens from prompt_tokens_details", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(1152), result.Usage.CacheReadTokens)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.NotNil(t, result.ProviderMetadata)

		openaiMeta, ok := result.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should clear out temperature, top_p, frequency_penalty, presence_penalty for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:           testPrompt,
			Temperature:      &[]float64{0.5}[0],
			TopP:             &[]float64{0.7}[0],
			FrequencyPenalty: &[]float64{0.2}[0],
			PresencePenalty:  &[]float64{0.3}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])

		// These should not be present
		require.Nil(t, call.body["temperature"])
		require.Nil(t, call.body["top_p"])
		require.Nil(t, call.body["frequency_penalty"])
		require.Nil(t, call.body["presence_penalty"])

		// Should have warnings
		require.Len(t, result.Warnings, 4)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "temperature is not supported for reasoning models")
	})

	t.Run("should convert maxOutputTokens to max_completion_tokens for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt:          testPrompt,
			MaxOutputTokens: &[]int64{1000}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(1000), call.body["max_completion_tokens"])
		require.Nil(t, call.body["max_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Equal(t, int64(15), result.Usage.InputTokens)
		require.Equal(t, int64(20), result.Usage.OutputTokens)
		require.Equal(t, int64(35), result.Usage.TotalTokens)
		require.Equal(t, int64(10), result.Usage.ReasoningTokens)
	})

	t.Run("should send max_completion_tokens extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"model": "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"maxCompletionTokens": 255,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o1-preview", call.body["model"])
		require.Equal(t, float64(255), call.body["max_completion_tokens"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send prediction extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"prediction": map[string]any{
						"type":    "content",
						"content": "Hello, World!",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		prediction := call.body["prediction"].(map[string]any)
		require.Equal(t, "content", prediction["type"])
		require.Equal(t, "Hello, World!", prediction["content"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["store"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send promptCacheKey extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"promptCacheKey": "test-cache-key-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-cache-key-123", call.body["prompt_cache_key"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send safetyIdentifier extension value", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"safetyIdentifier": "test-safety-identifier-123",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, "test-safety-identifier-123", call.body["safety_identifier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should remove temperature setting for search preview models", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-search-preview")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt:      testPrompt,
			Temperature: &[]float64{0.7}[0],
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-search-preview", call.body["model"])
		require.Nil(t, call.body["temperature"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "temperature", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "search preview models")
	})

	t.Run("should send serviceTier flex processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{
			"content": "",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o3-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using flex processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "flex processing is only available")
	})

	t.Run("should send serviceTier priority processing setting", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should show warning when using priority processing with unsupported model", func(t *testing.T) {
		t.Parallel()

		server := newMockServer()
		defer server.close()

		server.prepareJSONResponse(map[string]any{})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		result, err := model.Generate(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Nil(t, call.body["service_tier"])

		require.Len(t, result.Warnings, 1)
		require.Equal(t, ai.CallWarningTypeUnsupportedSetting, result.Warnings[0].Type)
		require.Equal(t, "serviceTier", result.Warnings[0].Setting)
		require.Contains(t, result.Warnings[0].Details, "priority processing is only available")
	})
}

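// streamingMockServer serves pre-built SSE chunks for streaming tests. Entries
// prefixed with "HEADER:name:value" are emitted as response headers rather than body data.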
type streamingMockServer struct {
	server *httptest.Server
	chunks []string
	calls  []mockCall
}

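// newStreamingMockServer starts a server that records each request, sends SSE
// headers, and writes the configured chunks, flushing after each one.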
func newStreamingMockServer() *streamingMockServer {
	sms := &streamingMockServer{
		calls: make([]mockCall, 0),
	}

	sms.server = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// Record the call
		call := mockCall{
			method:  r.Method,
			path:    r.URL.Path,
			headers: make(map[string]string),
		}

		for k, v := range r.Header {
			if len(v) > 0 {
				call.headers[k] = v[0]
			}
		}

		// Parse request body
		if r.Body != nil {
			var body map[string]any
			json.NewDecoder(r.Body).Decode(&body)
			call.body = body
		}

		sms.calls = append(sms.calls, call)

		// Set streaming headers
		w.Header().Set("Content-Type", "text/event-stream")
		w.Header().Set("Cache-Control", "no-cache")
		w.Header().Set("Connection", "keep-alive")

		// Add custom headers if any
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				parts := strings.SplitN(chunk[7:], ":", 2)
				if len(parts) == 2 {
					w.Header().Set(parts[0], parts[1])
				}
				continue
			}
		}

		w.WriteHeader(http.StatusOK)

		// Write chunks
		for _, chunk := range sms.chunks {
			if strings.HasPrefix(chunk, "HEADER:") {
				continue
			}
			w.Write([]byte(chunk))
			if f, ok := w.(http.Flusher); ok {
				f.Flush()
			}
		}
	}))

	return sms
}

func (sms *streamingMockServer) close() {
	sms.server.Close()
}

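// prepareStreamResponse assembles a complete SSE stream: an initial role chunk,
// one chunk per content string (plus optional annotations), a finish chunk,
// a usage chunk, and the terminating "data: [DONE]" marker.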
func (sms *streamingMockServer) prepareStreamResponse(opts map[string]any) {
	content := []string{}
	if c, ok := opts["content"].([]string); ok {
		content = c
	}

	usage := map[string]any{
		"prompt_tokens":     17,
		"total_tokens":      244,
		"completion_tokens": 227,
	}
	if u, ok := opts["usage"].(map[string]any); ok {
		usage = u
	}

	logprobs := map[string]any{}
	if l, ok := opts["logprobs"].(map[string]any); ok {
		logprobs = l
	}

	finishReason := "stop"
	if fr, ok := opts["finish_reason"].(string); ok {
		finishReason = fr
	}

	model := "gpt-3.5-turbo-0613"
	if m, ok := opts["model"].(string); ok {
		model = m
	}

	headers := map[string]string{}
	if h, ok := opts["headers"].(map[string]string); ok {
		headers = h
	}

	chunks := []string{}

	// Add custom headers
	for k, v := range headers {
		chunks = append(chunks, "HEADER:"+k+":"+v)
	}

	// Initial chunk with role
	initialChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index": 0,
				"delta": map[string]any{
					"role":    "assistant",
					"content": "",
				},
				"finish_reason": nil,
			},
		},
	}
	initialData, _ := json.Marshal(initialChunk)
	chunks = append(chunks, "data: "+string(initialData)+"\n\n")

	// Content chunks
	for i, text := range content {
		contentChunk := map[string]any{
			"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
			"object":             "chat.completion.chunk",
			"created":            1702657020,
			"model":              model,
			"system_fingerprint": nil,
			"choices": []map[string]any{
				{
					"index": 1,
					"delta": map[string]any{
						"content": text,
					},
					"finish_reason": nil,
				},
			},
		}
		contentData, _ := json.Marshal(contentChunk)
		chunks = append(chunks, "data: "+string(contentData)+"\n\n")

		// Add annotations if this is the last content chunk and we have annotations
		if i == len(content)-1 {
			if annotations, ok := opts["annotations"].([]map[string]any); ok {
				annotationChunk := map[string]any{
					"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
					"object":             "chat.completion.chunk",
					"created":            1702657020,
					"model":              model,
					"system_fingerprint": nil,
					"choices": []map[string]any{
						{
							"index": 1,
							"delta": map[string]any{
								"annotations": annotations,
							},
							"finish_reason": nil,
						},
					},
				}
				annotationData, _ := json.Marshal(annotationChunk)
				chunks = append(chunks, "data: "+string(annotationData)+"\n\n")
			}
		}
	}

	// Finish chunk
	finishChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": nil,
		"choices": []map[string]any{
			{
				"index":         0,
				"delta":         map[string]any{},
				"finish_reason": finishReason,
			},
		},
	}

	if len(logprobs) > 0 {
		finishChunk["choices"].([]map[string]any)[0]["logprobs"] = logprobs
	}

	finishData, _ := json.Marshal(finishChunk)
	chunks = append(chunks, "data: "+string(finishData)+"\n\n")

	// Usage chunk
	usageChunk := map[string]any{
		"id":                 "chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP",
		"object":             "chat.completion.chunk",
		"created":            1702657020,
		"model":              model,
		"system_fingerprint": "fp_3bc1b5746c",
		"choices":            []map[string]any{},
		"usage":              usage,
	}
	usageData, _ := json.Marshal(usageChunk)
	chunks = append(chunks, "data: "+string(usageData)+"\n\n")

	// Done
	chunks = append(chunks, "data: [DONE]\n\n")

	sms.chunks = chunks
}

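// prepareToolStreamResponse installs a fixed SSE sequence in which the model
// streams a "test-tool" call whose argument fragments concatenate to
// {"value":"Sparkle Day"}.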
func (sms *streamingMockServer) prepareToolStreamResponse() {
	chunks := []string{
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"call_O17Uplv4lJvD6DVdIvFFeRMw","type":"function","function":{"name":"test-tool","arguments":""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"value"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\":\""}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"Spark"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"le"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" Day"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"}"}}]},"logprobs":null,"finish_reason":null}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}` + "\n\n",
		`data: {"id":"chatcmpl-96aZqmeDpA9IPD6tACY8djkMsJCMP","object":"chat.completion.chunk","created":1711357598,"model":"gpt-3.5-turbo-0125","system_fingerprint":"fp_3bc1b5746c","choices":[],"usage":{"prompt_tokens":53,"completion_tokens":17,"total_tokens":70}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

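// prepareErrorStreamResponse installs an SSE sequence whose only payload is an
// OpenAI-style error object.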
func (sms *streamingMockServer) prepareErrorStreamResponse() {
	chunks := []string{
		`data: {"error":{"message": "The server had an error processing your request. Sorry about that! You can retry your request, or contact us through our help center at help.openai.com if you keep seeing this error.","type":"server_error","param":null,"code":null}}` + "\n\n",
		"data: [DONE]\n\n",
	}
	sms.chunks = chunks
}

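// collectStreamParts drains the stream into a slice, stopping at the first
// error or finish part. The error result is currently always nil.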
func collectStreamParts(stream ai.StreamResponse) ([]ai.StreamPart, error) {
	var parts []ai.StreamPart
	for part := range stream {
		parts = append(parts, part)
		if part.Type == ai.StreamPartTypeError || part.Type == ai.StreamPartTypeFinish {
			break
		}
	}
	return parts, nil
}

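// TestDoStream exercises the streaming path end to end against the mock SSE
// server: text deltas, tool-call deltas, annotations, error frames, usage
// accounting, and provider-specific request options.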
func TestDoStream(t *testing.T) {
	t.Parallel()

	t.Run("should stream text deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content":       []string{"Hello", ", ", "World!"},
			"finish_reason": "stop",
			"usage": map[string]any{
				"prompt_tokens":     17,
				"total_tokens":      244,
				"completion_tokens": 227,
			},
			"logprobs": testLogprobs,
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Verify stream structure
		require.GreaterOrEqual(t, len(parts), 4) // text-start, deltas, text-end, finish

		// Find text parts
		textStart, textEnd, finish := -1, -1, -1
		var deltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeTextStart:
				textStart = i
			case ai.StreamPartTypeTextDelta:
				deltas = append(deltas, part.Delta)
			case ai.StreamPartTypeTextEnd:
				textEnd = i
			case ai.StreamPartTypeFinish:
				finish = i
			}
		}

		require.NotEqual(t, -1, textStart)
		require.NotEqual(t, -1, textEnd)
		require.NotEqual(t, -1, finish)
		require.Equal(t, []string{"Hello", ", ", "World!"}, deltas)

		// Check finish part
		finishPart := parts[finish]
		require.Equal(t, ai.FinishReasonStop, finishPart.FinishReason)
		require.Equal(t, int64(17), finishPart.Usage.InputTokens)
		require.Equal(t, int64(227), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(244), finishPart.Usage.TotalTokens)
	})

	t.Run("should stream tool deltas", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareToolStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			Tools: []ai.Tool{
				ai.FunctionTool{
					Name: "test-tool",
					InputSchema: map[string]any{
						"type": "object",
						"properties": map[string]any{
							"value": map[string]any{
								"type": "string",
							},
						},
						"required":             []string{"value"},
						"additionalProperties": false,
						"$schema":              "http://json-schema.org/draft-07/schema#",
					},
				},
			},
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find tool-related parts
		toolInputStart, toolInputEnd, toolCall := -1, -1, -1
		var toolDeltas []string

		for i, part := range parts {
			switch part.Type {
			case ai.StreamPartTypeToolInputStart:
				toolInputStart = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
			case ai.StreamPartTypeToolInputDelta:
				toolDeltas = append(toolDeltas, part.Delta)
			case ai.StreamPartTypeToolInputEnd:
				toolInputEnd = i
			case ai.StreamPartTypeToolCall:
				toolCall = i
				require.Equal(t, "call_O17Uplv4lJvD6DVdIvFFeRMw", part.ID)
				require.Equal(t, "test-tool", part.ToolCallName)
				require.Equal(t, `{"value":"Sparkle Day"}`, part.ToolCallInput)
			}
		}

		require.NotEqual(t, -1, toolInputStart)
		require.NotEqual(t, -1, toolInputEnd)
		require.NotEqual(t, -1, toolCall)

		// Verify the tool deltas concatenate to the complete input
		fullInput := strings.Join(toolDeltas, "")
		require.Equal(t, `{"value":"Sparkle Day"}`, fullInput)
	})

	t.Run("should stream annotations/citations", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Based on search results"},
			"annotations": []map[string]any{
				{
					"type": "url_citation",
					"url_citation": map[string]any{
						"start_index": 24,
						"end_index":   29,
						"url":         "https://example.com/doc1.pdf",
						"title":       "Document 1",
					},
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find source part
		var sourcePart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeSource {
				sourcePart = &part
				break
			}
		}

		require.NotNil(t, sourcePart)
		require.Equal(t, ai.SourceTypeURL, sourcePart.SourceType)
		require.Equal(t, "https://example.com/doc1.pdf", sourcePart.URL)
		require.Equal(t, "Document 1", sourcePart.Title)
		require.NotEmpty(t, sourcePart.ID)
	})

	t.Run("should handle error stream parts", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareErrorStreamResponse()

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Should contain at least the error part (the collector stops there)
		require.NotEmpty(t, parts)

		// Find error part
		var errorPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeError {
				errorPart = &part
				break
			}
		}

		require.NotNil(t, errorPart)
		require.NotNil(t, errorPart.Error)
	})

	t.Run("should send request body", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "POST", call.method)
		require.Equal(t, "/chat/completions", call.path)
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should return cached tokens in providerMetadata", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"prompt_tokens_details": map[string]any{
					"cached_tokens": 1152,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(1152), finishPart.Usage.CacheReadTokens)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
	})

	t.Run("should return accepted_prediction_tokens and rejected_prediction_tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"accepted_prediction_tokens": 123,
					"rejected_prediction_tokens": 456,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.NotNil(t, finishPart.ProviderMetadata)

		openaiMeta, ok := finishPart.ProviderMetadata["openai"]
		require.True(t, ok)
		require.Equal(t, int64(123), openaiMeta["acceptedPredictionTokens"])
		require.Equal(t, int64(456), openaiMeta["rejectedPredictionTokens"])
	})

	t.Run("should send store extension setting", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"store": true,
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])
		require.Equal(t, true, call.body["store"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send metadata extension values", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-3.5-turbo")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"metadata": map[string]any{
						"custom": "value",
					},
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-3.5-turbo", call.body["model"])
		require.Equal(t, true, call.body["stream"])

		metadata := call.body["metadata"].(map[string]any)
		require.Equal(t, "value", metadata["custom"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier flex processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o3-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "flex",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "o3-mini", call.body["model"])
		require.Equal(t, "flex", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should send serviceTier priority processing setting in streaming", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("gpt-4o-mini")

		_, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
			ProviderOptions: ai.ProviderOptions{
				"openai": map[string]any{
					"serviceTier": "priority",
				},
			},
		})

		require.NoError(t, err)
		require.Len(t, server.calls, 1)

		call := server.calls[0]
		require.Equal(t, "gpt-4o-mini", call.body["model"])
		require.Equal(t, "priority", call.body["service_tier"])
		require.Equal(t, true, call.body["stream"])

		streamOptions := call.body["stream_options"].(map[string]any)
		require.Equal(t, true, streamOptions["include_usage"])

		messages := call.body["messages"].([]any)
		require.Len(t, messages, 1)

		message := messages[0].(map[string]any)
		require.Equal(t, "user", message["role"])
		require.Equal(t, "Hello", message["content"])
	})

	t.Run("should stream text delta for reasoning models", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find text parts
		var textDeltas []string
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeTextDelta {
				textDeltas = append(textDeltas, part.Delta)
			}
		}

		// Should contain the text content (the empty initial delta is not emitted)
		require.Equal(t, []string{"Hello, World!"}, textDeltas)
	})

	t.Run("should send reasoning tokens", func(t *testing.T) {
		t.Parallel()

		server := newStreamingMockServer()
		defer server.close()

		server.prepareStreamResponse(map[string]any{
			"content": []string{"Hello, World!"},
			"model":   "o1-preview",
			"usage": map[string]any{
				"prompt_tokens":     15,
				"completion_tokens": 20,
				"total_tokens":      35,
				"completion_tokens_details": map[string]any{
					"reasoning_tokens": 10,
				},
			},
		})

		provider := NewOpenAIProvider(
			WithOpenAIApiKey("test-api-key"),
			WithOpenAIBaseURL(server.server.URL),
		)
		model := provider.LanguageModel("o1-preview")

		stream, err := model.Stream(context.Background(), ai.Call{
			Prompt: testPrompt,
		})

		require.NoError(t, err)

		parts, err := collectStreamParts(stream)
		require.NoError(t, err)

		// Find finish part
		var finishPart *ai.StreamPart
		for _, part := range parts {
			if part.Type == ai.StreamPartTypeFinish {
				finishPart = &part
				break
			}
		}

		require.NotNil(t, finishPart)
		require.Equal(t, int64(15), finishPart.Usage.InputTokens)
		require.Equal(t, int64(20), finishPart.Usage.OutputTokens)
		require.Equal(t, int64(35), finishPart.Usage.TotalTokens)
		require.Equal(t, int64(10), finishPart.Usage.ReasoningTokens)
	})
}